[flang-commits] [flang] Reland "[flang] Inline hlfir.dot_product. (#123143)" (PR #123385)

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Fri Jan 17 10:23:32 PST 2025


https://github.com/vzakhari created https://github.com/llvm/llvm-project/pull/123385

This reverts commit afc43a7b626ae07f56e6534320e0b46d26070750 (i.e. it relands the original change)
and additionally fixes the declaration of hlfir::genExtentsVector().


From 0d1f9f9eb1aecf3ca7a359795d7cd4aa4c355393 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Fri, 17 Jan 2025 10:21:24 -0800
Subject: [PATCH] Reland "[flang] Inline hlfir.dot_product. (#123143)"

This reverts commit afc43a7b626ae07f56e6534320e0b46d26070750.
+Fixed declaration of hlfir::genExtentsVector().
---
 .../flang/Optimizer/Builder/HLFIRTools.h      |   6 +
 flang/lib/Optimizer/Builder/HLFIRTools.cpp    |  12 +
 .../Transforms/SimplifyHLFIRIntrinsics.cpp    | 279 ++++++++++--------
 .../simplify-hlfir-intrinsics-dotproduct.fir  | 144 +++++++++
 4 files changed, 326 insertions(+), 115 deletions(-)
 create mode 100644 flang/test/HLFIR/simplify-hlfir-intrinsics-dotproduct.fir

diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
index 6e85b8f4ddf86e..0684ad0f926ec9 100644
--- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h
+++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
@@ -513,6 +513,12 @@ genTypeAndKindConvert(mlir::Location loc, fir::FirOpBuilder &builder,
 Entity loadElementAt(mlir::Location loc, fir::FirOpBuilder &builder,
                      Entity entity, mlir::ValueRange oneBasedIndices);
 
+/// Return a vector of extents for the given entity.
+/// The function creates new operations, but tries to clean up
+/// after itself.
+llvm::SmallVector<mlir::Value, Fortran::common::maxRank>
+genExtentsVector(mlir::Location loc, fir::FirOpBuilder &builder, Entity entity);
+
 } // namespace hlfir
 
 #endif // FORTRAN_OPTIMIZER_BUILDER_HLFIRTOOLS_H
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 5e5d0bbd681326..f71adf123511d5 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -1421,3 +1421,15 @@ hlfir::Entity hlfir::loadElementAt(mlir::Location loc,
   return loadTrivialScalar(loc, builder,
                            getElementAt(loc, builder, entity, oneBasedIndices));
 }
+
+llvm::SmallVector<mlir::Value, Fortran::common::maxRank>
+hlfir::genExtentsVector(mlir::Location loc, fir::FirOpBuilder &builder,
+                        hlfir::Entity entity) {
+  entity = hlfir::derefPointersAndAllocatables(loc, builder, entity);
+  mlir::Value shape = hlfir::genShape(loc, builder, entity);
+  llvm::SmallVector<mlir::Value, Fortran::common::maxRank> extents =
+      hlfir::getExplicitExtentsFromShape(shape, builder);
+  if (shape.getUses().empty())
+    shape.getDefiningOp()->erase();
+  return extents;
+}
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index 0fe3620b7f1ae3..fe7ae0eeed3cc3 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -37,6 +37,79 @@ static llvm::cl::opt<bool> forceMatmulAsElemental(
 
 namespace {
 
+// Helper class to generate operations related to computing
+// product of values.
+class ProductFactory {
+public:
+  ProductFactory(mlir::Location loc, fir::FirOpBuilder &builder)
+      : loc(loc), builder(builder) {}
+
+  // Generate an update of the inner product value:
+  //   acc += v1 * v2, OR
+  //   acc += CONJ(v1) * v2, OR
+  //   acc ||= v1 && v2
+  //
+  // CONJ parameter specifies whether the first complex product argument
+  // needs to be conjugated.
+  template <bool CONJ = false>
+  mlir::Value genAccumulateProduct(mlir::Value acc, mlir::Value v1,
+                                   mlir::Value v2) {
+    mlir::Type resultType = acc.getType();
+    acc = castToProductType(acc, resultType);
+    v1 = castToProductType(v1, resultType);
+    v2 = castToProductType(v2, resultType);
+    mlir::Value result;
+    if (mlir::isa<mlir::FloatType>(resultType)) {
+      result = builder.create<mlir::arith::AddFOp>(
+          loc, acc, builder.create<mlir::arith::MulFOp>(loc, v1, v2));
+    } else if (mlir::isa<mlir::ComplexType>(resultType)) {
+      if constexpr (CONJ)
+        result = fir::IntrinsicLibrary{builder, loc}.genConjg(resultType, v1);
+      else
+        result = v1;
+
+      result = builder.create<fir::AddcOp>(
+          loc, acc, builder.create<fir::MulcOp>(loc, result, v2));
+    } else if (mlir::isa<mlir::IntegerType>(resultType)) {
+      result = builder.create<mlir::arith::AddIOp>(
+          loc, acc, builder.create<mlir::arith::MulIOp>(loc, v1, v2));
+    } else if (mlir::isa<fir::LogicalType>(resultType)) {
+      result = builder.create<mlir::arith::OrIOp>(
+          loc, acc, builder.create<mlir::arith::AndIOp>(loc, v1, v2));
+    } else {
+      llvm_unreachable("unsupported type");
+    }
+
+    return builder.createConvert(loc, resultType, result);
+  }
+
+private:
+  mlir::Location loc;
+  fir::FirOpBuilder &builder;
+
+  mlir::Value castToProductType(mlir::Value value, mlir::Type type) {
+    if (mlir::isa<fir::LogicalType>(type))
+      return builder.createConvert(loc, builder.getIntegerType(1), value);
+
+    // TODO: the multiplications/additions by/of zero resulting from
+    // complex * real are optimized by LLVM under -fno-signed-zeros
+    // -fno-honor-nans.
+    // We can make them disappear by default if we:
+    //   * either expand the complex multiplication into real
+    //     operations, OR
+    //   * set nnan nsz fast-math flags to the complex operations.
+    if (fir::isa_complex(type) && !fir::isa_complex(value.getType())) {
+      mlir::Value zeroCmplx = fir::factory::createZeroValue(builder, loc, type);
+      fir::factory::Complex helper(builder, loc);
+      mlir::Type partType = helper.getComplexPartType(type);
+      return helper.insertComplexPart(zeroCmplx,
+                                      castToProductType(value, partType),
+                                      /*isImagPart=*/false);
+    }
+    return builder.createConvert(loc, type, value);
+  }
+};
+
 class TransposeAsElementalConversion
     : public mlir::OpRewritePattern<hlfir::TransposeOp> {
 public:
@@ -90,11 +163,8 @@ class TransposeAsElementalConversion
   static mlir::Value genResultShape(mlir::Location loc,
                                     fir::FirOpBuilder &builder,
                                     hlfir::Entity array) {
-    mlir::Value inShape = hlfir::genShape(loc, builder, array);
-    llvm::SmallVector<mlir::Value> inExtents =
-        hlfir::getExplicitExtentsFromShape(inShape, builder);
-    if (inShape.getUses().empty())
-      inShape.getDefiningOp()->erase();
+    llvm::SmallVector<mlir::Value, 2> inExtents =
+        hlfir::genExtentsVector(loc, builder, array);
 
     // transpose indices
     assert(inExtents.size() == 2 && "checked in TransposeOp::validate");
@@ -137,7 +207,7 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
     mlir::Value resultShape, dimExtent;
     llvm::SmallVector<mlir::Value> arrayExtents;
     if (isTotalReduction)
-      arrayExtents = genArrayExtents(loc, builder, array);
+      arrayExtents = hlfir::genExtentsVector(loc, builder, array);
     else
       std::tie(resultShape, dimExtent) =
           genResultShapeForPartialReduction(loc, builder, array, dimVal);
@@ -163,7 +233,8 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
       // If DIM is not present, do total reduction.
 
       // Initial value for the reduction.
-      mlir::Value reductionInitValue = genInitValue(loc, builder, elementType);
+      mlir::Value reductionInitValue =
+          fir::factory::createZeroValue(builder, loc, elementType);
 
       // The reduction loop may be unordered if FastMathFlags::reassoc
       // transformations are allowed. The integer reduction is always
@@ -264,17 +335,6 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
   }
 
 private:
-  static llvm::SmallVector<mlir::Value>
-  genArrayExtents(mlir::Location loc, fir::FirOpBuilder &builder,
-                  hlfir::Entity array) {
-    mlir::Value inShape = hlfir::genShape(loc, builder, array);
-    llvm::SmallVector<mlir::Value> inExtents =
-        hlfir::getExplicitExtentsFromShape(inShape, builder);
-    if (inShape.getUses().empty())
-      inShape.getDefiningOp()->erase();
-    return inExtents;
-  }
-
   // Return fir.shape specifying the shape of the result
   // of a SUM reduction with DIM=dimVal. The second return value
   // is the extent of the DIM dimension.
@@ -283,7 +343,7 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
                                     fir::FirOpBuilder &builder,
                                     hlfir::Entity array, int64_t dimVal) {
     llvm::SmallVector<mlir::Value> inExtents =
-        genArrayExtents(loc, builder, array);
+        hlfir::genExtentsVector(loc, builder, array);
     assert(dimVal > 0 && dimVal <= static_cast<int64_t>(inExtents.size()) &&
            "DIM must be present and a positive constant not exceeding "
            "the array's rank");
@@ -293,26 +353,6 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
     return {builder.create<fir::ShapeOp>(loc, inExtents), dimExtent};
   }
 
-  // Generate the initial value for a SUM reduction with the given
-  // data type.
-  static mlir::Value genInitValue(mlir::Location loc,
-                                  fir::FirOpBuilder &builder,
-                                  mlir::Type elementType) {
-    if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
-      const llvm::fltSemantics &sem = ty.getFloatSemantics();
-      return builder.createRealConstant(loc, elementType,
-                                        llvm::APFloat::getZero(sem));
-    } else if (auto ty = mlir::dyn_cast<mlir::ComplexType>(elementType)) {
-      mlir::Value initValue = genInitValue(loc, builder, ty.getElementType());
-      return fir::factory::Complex{builder, loc}.createComplex(ty, initValue,
-                                                               initValue);
-    } else if (mlir::isa<mlir::IntegerType>(elementType)) {
-      return builder.createIntegerConstant(loc, elementType, 0);
-    }
-
-    llvm_unreachable("unsupported SUM reduction type");
-  }
-
   // Generate scalar addition of the two values (of the same data type).
   static mlir::Value genScalarAdd(mlir::Location loc,
                                   fir::FirOpBuilder &builder,
@@ -570,16 +610,10 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
   static std::tuple<mlir::Value, mlir::Value>
   genResultShape(mlir::Location loc, fir::FirOpBuilder &builder,
                  hlfir::Entity input1, hlfir::Entity input2) {
-    mlir::Value input1Shape = hlfir::genShape(loc, builder, input1);
-    llvm::SmallVector<mlir::Value> input1Extents =
-        hlfir::getExplicitExtentsFromShape(input1Shape, builder);
-    if (input1Shape.getUses().empty())
-      input1Shape.getDefiningOp()->erase();
-    mlir::Value input2Shape = hlfir::genShape(loc, builder, input2);
-    llvm::SmallVector<mlir::Value> input2Extents =
-        hlfir::getExplicitExtentsFromShape(input2Shape, builder);
-    if (input2Shape.getUses().empty())
-      input2Shape.getDefiningOp()->erase();
+    llvm::SmallVector<mlir::Value, 2> input1Extents =
+        hlfir::genExtentsVector(loc, builder, input1);
+    llvm::SmallVector<mlir::Value, 2> input2Extents =
+        hlfir::genExtentsVector(loc, builder, input2);
 
     llvm::SmallVector<mlir::Value, 2> newExtents;
     mlir::Value innerProduct1Extent, innerProduct2Extent;
@@ -627,60 +661,6 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
             innerProductExtent[0]};
   }
 
-  static mlir::Value castToProductType(mlir::Location loc,
-                                       fir::FirOpBuilder &builder,
-                                       mlir::Value value, mlir::Type type) {
-    if (mlir::isa<fir::LogicalType>(type))
-      return builder.createConvert(loc, builder.getIntegerType(1), value);
-
-    // TODO: the multiplications/additions by/of zero resulting from
-    // complex * real are optimized by LLVM under -fno-signed-zeros
-    // -fno-honor-nans.
-    // We can make them disappear by default if we:
-    //   * either expand the complex multiplication into real
-    //     operations, OR
-    //   * set nnan nsz fast-math flags to the complex operations.
-    if (fir::isa_complex(type) && !fir::isa_complex(value.getType())) {
-      mlir::Value zeroCmplx = fir::factory::createZeroValue(builder, loc, type);
-      fir::factory::Complex helper(builder, loc);
-      mlir::Type partType = helper.getComplexPartType(type);
-      return helper.insertComplexPart(
-          zeroCmplx, castToProductType(loc, builder, value, partType),
-          /*isImagPart=*/false);
-    }
-    return builder.createConvert(loc, type, value);
-  }
-
-  // Generate an update of the inner product value:
-  //   acc += v1 * v2, OR
-  //   acc ||= v1 && v2
-  static mlir::Value genAccumulateProduct(mlir::Location loc,
-                                          fir::FirOpBuilder &builder,
-                                          mlir::Type resultType,
-                                          mlir::Value acc, mlir::Value v1,
-                                          mlir::Value v2) {
-    acc = castToProductType(loc, builder, acc, resultType);
-    v1 = castToProductType(loc, builder, v1, resultType);
-    v2 = castToProductType(loc, builder, v2, resultType);
-    mlir::Value result;
-    if (mlir::isa<mlir::FloatType>(resultType))
-      result = builder.create<mlir::arith::AddFOp>(
-          loc, acc, builder.create<mlir::arith::MulFOp>(loc, v1, v2));
-    else if (mlir::isa<mlir::ComplexType>(resultType))
-      result = builder.create<fir::AddcOp>(
-          loc, acc, builder.create<fir::MulcOp>(loc, v1, v2));
-    else if (mlir::isa<mlir::IntegerType>(resultType))
-      result = builder.create<mlir::arith::AddIOp>(
-          loc, acc, builder.create<mlir::arith::MulIOp>(loc, v1, v2));
-    else if (mlir::isa<fir::LogicalType>(resultType))
-      result = builder.create<mlir::arith::OrIOp>(
-          loc, acc, builder.create<mlir::arith::AndIOp>(loc, v1, v2));
-    else
-      llvm_unreachable("unsupported type");
-
-    return builder.createConvert(loc, resultType, result);
-  }
-
   static mlir::LogicalResult
   genContiguousMatmul(mlir::Location loc, fir::FirOpBuilder &builder,
                       hlfir::Entity result, mlir::Value resultShape,
@@ -748,9 +728,9 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
             hlfir::loadElementAt(loc, builder, lhs, {I, K});
         hlfir::Entity rhsElementValue =
             hlfir::loadElementAt(loc, builder, rhs, {K, J});
-        mlir::Value productValue = genAccumulateProduct(
-            loc, builder, resultElementType, resultElementValue,
-            lhsElementValue, rhsElementValue);
+        mlir::Value productValue =
+            ProductFactory{loc, builder}.genAccumulateProduct(
+                resultElementValue, lhsElementValue, rhsElementValue);
         builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
         return {};
       };
@@ -785,9 +765,9 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
             hlfir::loadElementAt(loc, builder, lhs, {J, K});
         hlfir::Entity rhsElementValue =
             hlfir::loadElementAt(loc, builder, rhs, {K});
-        mlir::Value productValue = genAccumulateProduct(
-            loc, builder, resultElementType, resultElementValue,
-            lhsElementValue, rhsElementValue);
+        mlir::Value productValue =
+            ProductFactory{loc, builder}.genAccumulateProduct(
+                resultElementValue, lhsElementValue, rhsElementValue);
         builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
         return {};
       };
@@ -817,9 +797,9 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
             hlfir::loadElementAt(loc, builder, lhs, {K});
         hlfir::Entity rhsElementValue =
             hlfir::loadElementAt(loc, builder, rhs, {K, J});
-        mlir::Value productValue = genAccumulateProduct(
-            loc, builder, resultElementType, resultElementValue,
-            lhsElementValue, rhsElementValue);
+        mlir::Value productValue =
+            ProductFactory{loc, builder}.genAccumulateProduct(
+                resultElementValue, lhsElementValue, rhsElementValue);
         builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
         return {};
       };
@@ -885,9 +865,9 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
             hlfir::loadElementAt(loc, builder, lhs, lhsIndices);
         hlfir::Entity rhsElementValue =
             hlfir::loadElementAt(loc, builder, rhs, rhsIndices);
-        mlir::Value productValue = genAccumulateProduct(
-            loc, builder, resultElementType, reductionArgs[0], lhsElementValue,
-            rhsElementValue);
+        mlir::Value productValue =
+            ProductFactory{loc, builder}.genAccumulateProduct(
+                reductionArgs[0], lhsElementValue, rhsElementValue);
         return {productValue};
       };
       llvm::SmallVector<mlir::Value, 1> innerProductValue =
@@ -904,6 +884,73 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
   }
 };
 
+class DotProductConversion
+    : public mlir::OpRewritePattern<hlfir::DotProductOp> {
+public:
+  using mlir::OpRewritePattern<hlfir::DotProductOp>::OpRewritePattern;
+
+  llvm::LogicalResult
+  matchAndRewrite(hlfir::DotProductOp product,
+                  mlir::PatternRewriter &rewriter) const override {
+    hlfir::Entity op = hlfir::Entity{product};
+    if (!op.isScalar())
+      return rewriter.notifyMatchFailure(product, "produces non-scalar result");
+
+    mlir::Location loc = product.getLoc();
+    fir::FirOpBuilder builder{rewriter, product.getOperation()};
+    hlfir::Entity lhs = hlfir::Entity{product.getLhs()};
+    hlfir::Entity rhs = hlfir::Entity{product.getRhs()};
+    mlir::Type resultElementType = product.getType();
+    bool isUnordered = mlir::isa<mlir::IntegerType>(resultElementType) ||
+                       mlir::isa<fir::LogicalType>(resultElementType) ||
+                       static_cast<bool>(builder.getFastMathFlags() &
+                                         mlir::arith::FastMathFlags::reassoc);
+
+    mlir::Value extent = genProductExtent(loc, builder, lhs, rhs);
+
+    auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+                       mlir::ValueRange oneBasedIndices,
+                       mlir::ValueRange reductionArgs)
+        -> llvm::SmallVector<mlir::Value, 1> {
+      hlfir::Entity lhsElementValue =
+          hlfir::loadElementAt(loc, builder, lhs, oneBasedIndices);
+      hlfir::Entity rhsElementValue =
+          hlfir::loadElementAt(loc, builder, rhs, oneBasedIndices);
+      mlir::Value productValue =
+          ProductFactory{loc, builder}.genAccumulateProduct</*CONJ=*/true>(
+              reductionArgs[0], lhsElementValue, rhsElementValue);
+      return {productValue};
+    };
+
+    mlir::Value initValue =
+        fir::factory::createZeroValue(builder, loc, resultElementType);
+
+    llvm::SmallVector<mlir::Value, 1> result = hlfir::genLoopNestWithReductions(
+        loc, builder, {extent},
+        /*reductionInits=*/{initValue}, genBody, isUnordered);
+
+    rewriter.replaceOp(product, result[0]);
+    return mlir::success();
+  }
+
+private:
+  static mlir::Value genProductExtent(mlir::Location loc,
+                                      fir::FirOpBuilder &builder,
+                                      hlfir::Entity input1,
+                                      hlfir::Entity input2) {
+    llvm::SmallVector<mlir::Value, 1> input1Extents =
+        hlfir::genExtentsVector(loc, builder, input1);
+    llvm::SmallVector<mlir::Value, 1> input2Extents =
+        hlfir::genExtentsVector(loc, builder, input2);
+
+    assert(input1Extents.size() == 1 && input2Extents.size() == 1 &&
+           "hlfir.dot_product arguments must be vectors");
+    llvm::SmallVector<mlir::Value, 1> extent =
+        fir::factory::deduceOptimalExtents(input1Extents, input2Extents);
+    return extent[0];
+  }
+};
+
 class SimplifyHLFIRIntrinsics
     : public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
 public:
@@ -939,6 +986,8 @@ class SimplifyHLFIRIntrinsics
     if (forceMatmulAsElemental || this->allowNewSideEffects)
       patterns.insert<MatmulConversion<hlfir::MatmulOp>>(context);
 
+    patterns.insert<DotProductConversion>(context);
+
     if (mlir::failed(mlir::applyPatternsGreedily(
             getOperation(), std::move(patterns), config))) {
       mlir::emitError(getOperation()->getLoc(),
diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-dotproduct.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-dotproduct.fir
new file mode 100644
index 00000000000000..f59b1422dbc849
--- /dev/null
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-dotproduct.fir
@@ -0,0 +1,144 @@
+// Test hlfir.dot_product simplification to a reduction loop:
+// RUN: fir-opt --simplify-hlfir-intrinsics %s | FileCheck %s
+
+func.func @dot_product_integer(%arg0: !hlfir.expr<?xi16>, %arg1: !hlfir.expr<?xi32>) -> i32 {
+  %res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<?xi16>, !hlfir.expr<?xi32>) -> i32
+  return %res : i32
+}
+// CHECK-LABEL:   func.func @dot_product_integer(
+// CHECK-SAME:                                   %[[VAL_0:.*]]: !hlfir.expr<?xi16>,
+// CHECK-SAME:                                   %[[VAL_1:.*]]: !hlfir.expr<?xi32>) -> i32 {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?xi16>) -> !fir.shape<1>
+// CHECK:           %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<1>) -> index
+// CHECK:           %[[VAL_6:.*]] = fir.do_loop %[[VAL_7:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] unordered iter_args(%[[VAL_8:.*]] = %[[VAL_3]]) -> (i32) {
+// CHECK:             %[[VAL_9:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_7]] : (!hlfir.expr<?xi16>, index) -> i16
+// CHECK:             %[[VAL_10:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_7]] : (!hlfir.expr<?xi32>, index) -> i32
+// CHECK:             %[[VAL_11:.*]] = fir.convert %[[VAL_9]] : (i16) -> i32
+// CHECK:             %[[VAL_12:.*]] = arith.muli %[[VAL_11]], %[[VAL_10]] : i32
+// CHECK:             %[[VAL_13:.*]] = arith.addi %[[VAL_8]], %[[VAL_12]] : i32
+// CHECK:             fir.result %[[VAL_13]] : i32
+// CHECK:           }
+// CHECK:           return %[[VAL_6]] : i32
+// CHECK:         }
+
+func.func @dot_product_real(%arg0: !hlfir.expr<?xf32>, %arg1: !hlfir.expr<?xf16>) -> f32 {
+  %res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<?xf32>, !hlfir.expr<?xf16>) -> f32
+  return %res : f32
+}
+// CHECK-LABEL:   func.func @dot_product_real(
+// CHECK-SAME:                                %[[VAL_0:.*]]: !hlfir.expr<?xf32>,
+// CHECK-SAME:                                %[[VAL_1:.*]]: !hlfir.expr<?xf16>) -> f32 {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?xf32>) -> !fir.shape<1>
+// CHECK:           %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<1>) -> index
+// CHECK:           %[[VAL_6:.*]] = fir.do_loop %[[VAL_7:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_8:.*]] = %[[VAL_3]]) -> (f32) {
+// CHECK:             %[[VAL_9:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_7]] : (!hlfir.expr<?xf32>, index) -> f32
+// CHECK:             %[[VAL_10:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_7]] : (!hlfir.expr<?xf16>, index) -> f16
+// CHECK:             %[[VAL_11:.*]] = fir.convert %[[VAL_10]] : (f16) -> f32
+// CHECK:             %[[VAL_12:.*]] = arith.mulf %[[VAL_9]], %[[VAL_11]] : f32
+// CHECK:             %[[VAL_13:.*]] = arith.addf %[[VAL_8]], %[[VAL_12]] : f32
+// CHECK:             fir.result %[[VAL_13]] : f32
+// CHECK:           }
+// CHECK:           return %[[VAL_6]] : f32
+// CHECK:         }
+
+func.func @dot_product_complex(%arg0: !hlfir.expr<?xcomplex<f32>>, %arg1: !hlfir.expr<?xcomplex<f16>>) -> complex<f32> {
+  %res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<?xcomplex<f32>>, !hlfir.expr<?xcomplex<f16>>) -> complex<f32>
+  return %res : complex<f32>
+}
+// CHECK-LABEL:   func.func @dot_product_complex(
+// CHECK-SAME:                                   %[[VAL_0:.*]]: !hlfir.expr<?xcomplex<f32>>,
+// CHECK-SAME:                                   %[[VAL_1:.*]]: !hlfir.expr<?xcomplex<f16>>) -> complex<f32> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?xcomplex<f32>>) -> !fir.shape<1>
+// CHECK:           %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<1>) -> index
+// CHECK:           %[[VAL_6:.*]] = fir.undefined complex<f32>
+// CHECK:           %[[VAL_7:.*]] = fir.insert_value %[[VAL_6]], %[[VAL_3]], [0 : index] : (complex<f32>, f32) -> complex<f32>
+// CHECK:           %[[VAL_8:.*]] = fir.insert_value %[[VAL_7]], %[[VAL_3]], [1 : index] : (complex<f32>, f32) -> complex<f32>
+// CHECK:           %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (complex<f32>) {
+// CHECK:             %[[VAL_12:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_10]] : (!hlfir.expr<?xcomplex<f32>>, index) -> complex<f32>
+// CHECK:             %[[VAL_13:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_10]] : (!hlfir.expr<?xcomplex<f16>>, index) -> complex<f16>
+// CHECK:             %[[VAL_14:.*]] = fir.convert %[[VAL_13]] : (complex<f16>) -> complex<f32>
+// CHECK:             %[[VAL_15:.*]] = fir.extract_value %[[VAL_12]], [1 : index] : (complex<f32>) -> f32
+// CHECK:             %[[VAL_16:.*]] = arith.negf %[[VAL_15]] : f32
+// CHECK:             %[[VAL_17:.*]] = fir.insert_value %[[VAL_12]], %[[VAL_16]], [1 : index] : (complex<f32>, f32) -> complex<f32>
+// CHECK:             %[[VAL_18:.*]] = fir.mulc %[[VAL_17]], %[[VAL_14]] : complex<f32>
+// CHECK:             %[[VAL_19:.*]] = fir.addc %[[VAL_11]], %[[VAL_18]] : complex<f32>
+// CHECK:             fir.result %[[VAL_19]] : complex<f32>
+// CHECK:           }
+// CHECK:           return %[[VAL_9]] : complex<f32>
+// CHECK:         }
+
+func.func @dot_product_real_complex(%arg0: !hlfir.expr<?xf32>, %arg1: !hlfir.expr<?xcomplex<f16>>) -> complex<f32> {
+  %res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<?xf32>, !hlfir.expr<?xcomplex<f16>>) -> complex<f32>
+  return %res : complex<f32>
+}
+// CHECK-LABEL:   func.func @dot_product_real_complex(
+// CHECK-SAME:                                        %[[VAL_0:.*]]: !hlfir.expr<?xf32>,
+// CHECK-SAME:                                        %[[VAL_1:.*]]: !hlfir.expr<?xcomplex<f16>>) -> complex<f32> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?xf32>) -> !fir.shape<1>
+// CHECK:           %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<1>) -> index
+// CHECK:           %[[VAL_6:.*]] = fir.undefined complex<f32>
+// CHECK:           %[[VAL_7:.*]] = fir.insert_value %[[VAL_6]], %[[VAL_3]], [0 : index] : (complex<f32>, f32) -> complex<f32>
+// CHECK:           %[[VAL_8:.*]] = fir.insert_value %[[VAL_7]], %[[VAL_3]], [1 : index] : (complex<f32>, f32) -> complex<f32>
+// CHECK:           %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (complex<f32>) {
+// CHECK:             %[[VAL_12:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_10]] : (!hlfir.expr<?xf32>, index) -> f32
+// CHECK:             %[[VAL_13:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_10]] : (!hlfir.expr<?xcomplex<f16>>, index) -> complex<f16>
+// CHECK:             %[[VAL_14:.*]] = fir.undefined complex<f32>
+// CHECK:             %[[VAL_15:.*]] = fir.insert_value %[[VAL_14]], %[[VAL_3]], [0 : index] : (complex<f32>, f32) -> complex<f32>
+// CHECK:             %[[VAL_16:.*]] = fir.insert_value %[[VAL_15]], %[[VAL_3]], [1 : index] : (complex<f32>, f32) -> complex<f32>
+// CHECK:             %[[VAL_17:.*]] = fir.insert_value %[[VAL_16]], %[[VAL_12]], [0 : index] : (complex<f32>, f32) -> complex<f32>
+// CHECK:             %[[VAL_18:.*]] = fir.convert %[[VAL_13]] : (complex<f16>) -> complex<f32>
+// CHECK:             %[[VAL_19:.*]] = fir.extract_value %[[VAL_17]], [1 : index] : (complex<f32>) -> f32
+// CHECK:             %[[VAL_20:.*]] = arith.negf %[[VAL_19]] : f32
+// CHECK:             %[[VAL_21:.*]] = fir.insert_value %[[VAL_17]], %[[VAL_20]], [1 : index] : (complex<f32>, f32) -> complex<f32>
+// CHECK:             %[[VAL_22:.*]] = fir.mulc %[[VAL_21]], %[[VAL_18]] : complex<f32>
+// CHECK:             %[[VAL_23:.*]] = fir.addc %[[VAL_11]], %[[VAL_22]] : complex<f32>
+// CHECK:             fir.result %[[VAL_23]] : complex<f32>
+// CHECK:           }
+// CHECK:           return %[[VAL_9]] : complex<f32>
+// CHECK:         }
+
+func.func @dot_product_logical(%arg0: !hlfir.expr<?x!fir.logical<1>>, %arg1: !hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<4> {
+  %res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<?x!fir.logical<1>>, !hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<4>
+  return %res : !fir.logical<4>
+}
+// CHECK-LABEL:   func.func @dot_product_logical(
+// CHECK-SAME:                                   %[[VAL_0:.*]]: !hlfir.expr<?x!fir.logical<1>>,
+// CHECK-SAME:                                   %[[VAL_1:.*]]: !hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<4> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant false
+// CHECK:           %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x!fir.logical<1>>) -> !fir.shape<1>
+// CHECK:           %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<1>) -> index
+// CHECK:           %[[VAL_6:.*]] = fir.convert %[[VAL_3]] : (i1) -> !fir.logical<4>
+// CHECK:           %[[VAL_7:.*]] = fir.do_loop %[[VAL_8:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] unordered iter_args(%[[VAL_9:.*]] = %[[VAL_6]]) -> (!fir.logical<4>) {
+// CHECK:             %[[VAL_10:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_8]] : (!hlfir.expr<?x!fir.logical<1>>, index) -> !fir.logical<1>
+// CHECK:             %[[VAL_11:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_8]] : (!hlfir.expr<?x!fir.logical<4>>, index) -> !fir.logical<4>
+// CHECK:             %[[VAL_12:.*]] = fir.convert %[[VAL_9]] : (!fir.logical<4>) -> i1
+// CHECK:             %[[VAL_13:.*]] = fir.convert %[[VAL_10]] : (!fir.logical<1>) -> i1
+// CHECK:             %[[VAL_14:.*]] = fir.convert %[[VAL_11]] : (!fir.logical<4>) -> i1
+// CHECK:             %[[VAL_15:.*]] = arith.andi %[[VAL_13]], %[[VAL_14]] : i1
+// CHECK:             %[[VAL_16:.*]] = arith.ori %[[VAL_12]], %[[VAL_15]] : i1
+// CHECK:             %[[VAL_17:.*]] = fir.convert %[[VAL_16]] : (i1) -> !fir.logical<4>
+// CHECK:             fir.result %[[VAL_17]] : !fir.logical<4>
+// CHECK:           }
+// CHECK:           return %[[VAL_7]] : !fir.logical<4>
+// CHECK:         }
+
+func.func @dot_product_known_dim(%arg0: !hlfir.expr<10xf32>, %arg1: !hlfir.expr<?xi16>) -> f32 {
+  %res1 = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<10xf32>, !hlfir.expr<?xi16>) -> f32
+  %res2 = hlfir.dot_product %arg1 %arg0 : (!hlfir.expr<?xi16>, !hlfir.expr<10xf32>) -> f32
+  %res = arith.addf %res1, %res2 : f32
+  return %res : f32
+}
+// CHECK-LABEL:   func.func @dot_product_known_dim(
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 10 : index
+// CHECK:           fir.do_loop %{{.*}} = %[[VAL_2]] to %[[VAL_4]] step %[[VAL_2]]
+// CHECK:           fir.do_loop %{{.*}} = %[[VAL_2]] to %[[VAL_4]] step %[[VAL_2]]



More information about the flang-commits mailing list