[flang-commits] [flang] [flang] Inline hlfir.dot_product. (PR #123143)

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Wed Jan 15 15:30:53 PST 2025


https://github.com/vzakhari created https://github.com/llvm/llvm-project/pull/123143

This gives good results for induct2, where dot_product is applied
to a vector of unknown size and a known 3-element vector:
the inlining ends up generating a 3-iteration loop, which
is then fully unrolled. With the late FIR simplification
this does not happen, even when LLVM inlines the simplified
intrinsic implementation, because the loop bounds
are not known.

This change follows the current approach of exposing
the loops so that worksharing can be applied to them later.


>From bfc3d7d9f33660d9cf7ddf958c48f0b87e088033 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Wed, 15 Jan 2025 12:00:20 -0800
Subject: [PATCH 1/2] [NFC] Clean-up product code generation.

---
 .../Transforms/SimplifyHLFIRIntrinsics.cpp    | 174 +++++++++---------
 1 file changed, 87 insertions(+), 87 deletions(-)

diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index 0fe3620b7f1ae3..f9dd6182cbde70 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -37,6 +37,79 @@ static llvm::cl::opt<bool> forceMatmulAsElemental(
 
 namespace {
 
+// Helper class that emits the IR accumulating a product of two scalar
+// values into an accumulator (used for inner products, e.g. in MATMUL).
+class ProductFactory {
+public:
+  ProductFactory(mlir::Location loc, fir::FirOpBuilder &builder)
+      : loc(loc), builder(builder) {}
+
+  // Emit an update of the inner product accumulator and return the result:
+  //   acc += v1 * v2          (integer, real and complex), OR
+  //   acc += CONJ(v1) * v2    (complex, when CONJ is true), OR
+  //   acc ||= v1 && v2        (logical)
+  // All three values are first converted to a working type derived from
+  // acc's type (i1 for logical; non-complex values widened to complex),
+  // and the result is converted back. CONJ only affects complex results.
+  template <bool CONJ = false>
+  mlir::Value genAccumulateProduct(mlir::Value acc, mlir::Value v1,
+                                   mlir::Value v2) {
+    mlir::Type resultType = acc.getType();
+    acc = castToProductType(acc, resultType);
+    v1 = castToProductType(v1, resultType);
+    v2 = castToProductType(v2, resultType);
+    mlir::Value result;
+    if (mlir::isa<mlir::FloatType>(resultType)) {
+      result = builder.create<mlir::arith::AddFOp>(
+          loc, acc, builder.create<mlir::arith::MulFOp>(loc, v1, v2));
+    } else if (mlir::isa<mlir::ComplexType>(resultType)) {
+      if constexpr (CONJ)
+        result = fir::IntrinsicLibrary{builder, loc}.genConjg(resultType, v1);
+      else
+        result = v1;
+
+      result = builder.create<fir::AddcOp>(
+          loc, acc, builder.create<fir::MulcOp>(loc, result, v2));
+    } else if (mlir::isa<mlir::IntegerType>(resultType)) {
+      result = builder.create<mlir::arith::AddIOp>(
+          loc, acc, builder.create<mlir::arith::MulIOp>(loc, v1, v2));
+    } else if (mlir::isa<fir::LogicalType>(resultType)) {
+      result = builder.create<mlir::arith::OrIOp>(
+          loc, acc, builder.create<mlir::arith::AndIOp>(loc, v1, v2));
+    } else {
+      llvm_unreachable("unsupported type");
+    }
+    // Convert back to acc's type (e.g. i1 back to the logical type).
+    return builder.createConvert(loc, resultType, result);
+  }
+
+private:
+  mlir::Location loc;
+  fir::FirOpBuilder &builder;
+  // Convert 'value' to the working type of a product whose result is 'type'.
+  mlir::Value castToProductType(mlir::Value value, mlir::Type type) {
+    if (mlir::isa<fir::LogicalType>(type)) // logicals are computed as i1
+      return builder.createConvert(loc, builder.getIntegerType(1), value);
+
+    // A non-complex value is widened to the complex value (value, 0).
+    // TODO: the multiplications/additions by/of zero resulting from
+    // complex * real are only optimized away by LLVM under
+    // -fno-signed-zeros -fno-honor-nans. We can make them disappear by
+    // default if we either:
+    //   * expand the complex multiplication into real operations, OR
+    //   * set nnan nsz fast-math flags on the complex operations.
+    if (fir::isa_complex(type) && !fir::isa_complex(value.getType())) {
+      mlir::Value zeroCmplx = fir::factory::createZeroValue(builder, loc, type);
+      fir::factory::Complex helper(builder, loc);
+      mlir::Type partType = helper.getComplexPartType(type);
+      return helper.insertComplexPart(zeroCmplx,
+                                      castToProductType(value, partType),
+                                      /*isImagPart=*/false);
+    }
+    return builder.createConvert(loc, type, value);
+  }
+};
+
 class TransposeAsElementalConversion
     : public mlir::OpRewritePattern<hlfir::TransposeOp> {
 public:
@@ -163,7 +236,8 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
       // If DIM is not present, do total reduction.
 
       // Initial value for the reduction.
-      mlir::Value reductionInitValue = genInitValue(loc, builder, elementType);
+      mlir::Value reductionInitValue =
+          fir::factory::createZeroValue(builder, loc, elementType);
 
       // The reduction loop may be unordered if FastMathFlags::reassoc
       // transformations are allowed. The integer reduction is always
@@ -293,26 +367,6 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
     return {builder.create<fir::ShapeOp>(loc, inExtents), dimExtent};
   }
 
-  // Generate the initial value for a SUM reduction with the given
-  // data type.
-  static mlir::Value genInitValue(mlir::Location loc,
-                                  fir::FirOpBuilder &builder,
-                                  mlir::Type elementType) {
-    if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
-      const llvm::fltSemantics &sem = ty.getFloatSemantics();
-      return builder.createRealConstant(loc, elementType,
-                                        llvm::APFloat::getZero(sem));
-    } else if (auto ty = mlir::dyn_cast<mlir::ComplexType>(elementType)) {
-      mlir::Value initValue = genInitValue(loc, builder, ty.getElementType());
-      return fir::factory::Complex{builder, loc}.createComplex(ty, initValue,
-                                                               initValue);
-    } else if (mlir::isa<mlir::IntegerType>(elementType)) {
-      return builder.createIntegerConstant(loc, elementType, 0);
-    }
-
-    llvm_unreachable("unsupported SUM reduction type");
-  }
-
   // Generate scalar addition of the two values (of the same data type).
   static mlir::Value genScalarAdd(mlir::Location loc,
                                   fir::FirOpBuilder &builder,
@@ -627,60 +681,6 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
             innerProductExtent[0]};
   }
 
-  static mlir::Value castToProductType(mlir::Location loc,
-                                       fir::FirOpBuilder &builder,
-                                       mlir::Value value, mlir::Type type) {
-    if (mlir::isa<fir::LogicalType>(type))
-      return builder.createConvert(loc, builder.getIntegerType(1), value);
-
-    // TODO: the multiplications/additions by/of zero resulting from
-    // complex * real are optimized by LLVM under -fno-signed-zeros
-    // -fno-honor-nans.
-    // We can make them disappear by default if we:
-    //   * either expand the complex multiplication into real
-    //     operations, OR
-    //   * set nnan nsz fast-math flags to the complex operations.
-    if (fir::isa_complex(type) && !fir::isa_complex(value.getType())) {
-      mlir::Value zeroCmplx = fir::factory::createZeroValue(builder, loc, type);
-      fir::factory::Complex helper(builder, loc);
-      mlir::Type partType = helper.getComplexPartType(type);
-      return helper.insertComplexPart(
-          zeroCmplx, castToProductType(loc, builder, value, partType),
-          /*isImagPart=*/false);
-    }
-    return builder.createConvert(loc, type, value);
-  }
-
-  // Generate an update of the inner product value:
-  //   acc += v1 * v2, OR
-  //   acc ||= v1 && v2
-  static mlir::Value genAccumulateProduct(mlir::Location loc,
-                                          fir::FirOpBuilder &builder,
-                                          mlir::Type resultType,
-                                          mlir::Value acc, mlir::Value v1,
-                                          mlir::Value v2) {
-    acc = castToProductType(loc, builder, acc, resultType);
-    v1 = castToProductType(loc, builder, v1, resultType);
-    v2 = castToProductType(loc, builder, v2, resultType);
-    mlir::Value result;
-    if (mlir::isa<mlir::FloatType>(resultType))
-      result = builder.create<mlir::arith::AddFOp>(
-          loc, acc, builder.create<mlir::arith::MulFOp>(loc, v1, v2));
-    else if (mlir::isa<mlir::ComplexType>(resultType))
-      result = builder.create<fir::AddcOp>(
-          loc, acc, builder.create<fir::MulcOp>(loc, v1, v2));
-    else if (mlir::isa<mlir::IntegerType>(resultType))
-      result = builder.create<mlir::arith::AddIOp>(
-          loc, acc, builder.create<mlir::arith::MulIOp>(loc, v1, v2));
-    else if (mlir::isa<fir::LogicalType>(resultType))
-      result = builder.create<mlir::arith::OrIOp>(
-          loc, acc, builder.create<mlir::arith::AndIOp>(loc, v1, v2));
-    else
-      llvm_unreachable("unsupported type");
-
-    return builder.createConvert(loc, resultType, result);
-  }
-
   static mlir::LogicalResult
   genContiguousMatmul(mlir::Location loc, fir::FirOpBuilder &builder,
                       hlfir::Entity result, mlir::Value resultShape,
@@ -748,9 +748,9 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
             hlfir::loadElementAt(loc, builder, lhs, {I, K});
         hlfir::Entity rhsElementValue =
             hlfir::loadElementAt(loc, builder, rhs, {K, J});
-        mlir::Value productValue = genAccumulateProduct(
-            loc, builder, resultElementType, resultElementValue,
-            lhsElementValue, rhsElementValue);
+        mlir::Value productValue =
+            ProductFactory{loc, builder}.genAccumulateProduct(
+                resultElementValue, lhsElementValue, rhsElementValue);
         builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
         return {};
       };
@@ -785,9 +785,9 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
             hlfir::loadElementAt(loc, builder, lhs, {J, K});
         hlfir::Entity rhsElementValue =
             hlfir::loadElementAt(loc, builder, rhs, {K});
-        mlir::Value productValue = genAccumulateProduct(
-            loc, builder, resultElementType, resultElementValue,
-            lhsElementValue, rhsElementValue);
+        mlir::Value productValue =
+            ProductFactory{loc, builder}.genAccumulateProduct(
+                resultElementValue, lhsElementValue, rhsElementValue);
         builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
         return {};
       };
@@ -817,9 +817,9 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
             hlfir::loadElementAt(loc, builder, lhs, {K});
         hlfir::Entity rhsElementValue =
             hlfir::loadElementAt(loc, builder, rhs, {K, J});
-        mlir::Value productValue = genAccumulateProduct(
-            loc, builder, resultElementType, resultElementValue,
-            lhsElementValue, rhsElementValue);
+        mlir::Value productValue =
+            ProductFactory{loc, builder}.genAccumulateProduct(
+                resultElementValue, lhsElementValue, rhsElementValue);
         builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
         return {};
       };
@@ -885,9 +885,9 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
             hlfir::loadElementAt(loc, builder, lhs, lhsIndices);
         hlfir::Entity rhsElementValue =
             hlfir::loadElementAt(loc, builder, rhs, rhsIndices);
-        mlir::Value productValue = genAccumulateProduct(
-            loc, builder, resultElementType, reductionArgs[0], lhsElementValue,
-            rhsElementValue);
+        mlir::Value productValue =
+            ProductFactory{loc, builder}.genAccumulateProduct(
+                reductionArgs[0], lhsElementValue, rhsElementValue);
         return {productValue};
       };
       llvm::SmallVector<mlir::Value, 1> innerProductValue =

>From 2986f4ae1e7c4bdacac3f0ba64f40b5e21a82353 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Wed, 15 Jan 2025 12:00:46 -0800
Subject: [PATCH 2/2] [flang] Inline hlfir.dot_product.

This gives good results for induct2, where dot_product is applied
to a vector of unknown size and a known 3-element vector:
the inlining ends up generating a 3-iteration loop, which
is then fully unrolled. With the late FIR simplification
this does not happen, even when LLVM inlines the simplified
intrinsic implementation, because the loop bounds
are not known.

This change follows the current approach of exposing
the loops so that worksharing can be applied to them later.
---
 .../Transforms/SimplifyHLFIRIntrinsics.cpp    |  75 +++++++++
 .../simplify-hlfir-intrinsics-dotproduct.fir  | 144 ++++++++++++++++++
 2 files changed, 219 insertions(+)
 create mode 100644 flang/test/HLFIR/simplify-hlfir-intrinsics-dotproduct.fir

diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index f9dd6182cbde70..c4fb8cc56c75ef 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -904,6 +904,79 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
   }
 };
 
+class DotProductConversion
+    : public mlir::OpRewritePattern<hlfir::DotProductOp> {
+public:
+  using mlir::OpRewritePattern<hlfir::DotProductOp>::OpRewritePattern;
+  // Inline a scalar hlfir.dot_product as an explicit reduction loop.
+  llvm::LogicalResult
+  matchAndRewrite(hlfir::DotProductOp product,
+                  mlir::PatternRewriter &rewriter) const override {
+    hlfir::Entity op = hlfir::Entity{product};
+    if (!op.isScalar())
+      return rewriter.notifyMatchFailure(product, "produces non-scalar result");
+    // The loop may be unordered for integer/logical results or under reassoc.
+    mlir::Location loc = product.getLoc();
+    fir::FirOpBuilder builder{rewriter, product.getOperation()};
+    hlfir::Entity lhs = hlfir::Entity{product.getLhs()};
+    hlfir::Entity rhs = hlfir::Entity{product.getRhs()};
+    mlir::Type resultElementType = product.getType();
+    bool isUnordered = mlir::isa<mlir::IntegerType>(resultElementType) ||
+                       mlir::isa<fir::LogicalType>(resultElementType) ||
+                       static_cast<bool>(builder.getFastMathFlags() &
+                                         mlir::arith::FastMathFlags::reassoc);
+    // Trip count: the common extent of the two rank-1 operands.
+    mlir::Value extent = genProductExtent(loc, builder, lhs, rhs);
+    // Body: accumulate lhs(i)*rhs(i); lhs is conjugated for complex results.
+    auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+                       mlir::ValueRange oneBasedIndices,
+                       mlir::ValueRange reductionArgs)
+        -> llvm::SmallVector<mlir::Value, 1> {
+      hlfir::Entity lhsElementValue =
+          hlfir::loadElementAt(loc, builder, lhs, oneBasedIndices);
+      hlfir::Entity rhsElementValue =
+          hlfir::loadElementAt(loc, builder, rhs, oneBasedIndices);
+      mlir::Value productValue =
+          ProductFactory{loc, builder}.genAccumulateProduct</*CONJ=*/true>(
+              reductionArgs[0], lhsElementValue, rhsElementValue);
+      return {productValue};
+    };
+    // The accumulator starts at zero (.FALSE. for logical results).
+    mlir::Value initValue =
+        fir::factory::createZeroValue(builder, loc, resultElementType);
+    // Build one loop that carries the accumulator as a reduction value.
+    llvm::SmallVector<mlir::Value, 1> result = hlfir::genLoopNestWithReductions(
+        loc, builder, {extent},
+        /*reductionInits=*/{initValue}, genBody, isUnordered);
+
+    rewriter.replaceOp(product, result[0]);
+    return mlir::success();
+  }
+
+private:
+  static mlir::Value genProductExtent(mlir::Location loc,
+                                      fir::FirOpBuilder &builder,
+                                      hlfir::Entity input1,
+                                      hlfir::Entity input2) {
+    mlir::Value input1Shape = hlfir::genShape(loc, builder, input1);
+    llvm::SmallVector<mlir::Value, 1> input1Extents =
+        hlfir::getExplicitExtentsFromShape(input1Shape, builder);
+    if (input1Shape.getUses().empty()) // drop the temporary shape op
+      input1Shape.getDefiningOp()->erase();
+    mlir::Value input2Shape = hlfir::genShape(loc, builder, input2);
+    llvm::SmallVector<mlir::Value, 1> input2Extents =
+        hlfir::getExplicitExtentsFromShape(input2Shape, builder);
+    if (input2Shape.getUses().empty()) // drop the temporary shape op
+      input2Shape.getDefiningOp()->erase();
+    // Combine both operands' extents so a constant one can be preferred.
+    assert(input1Extents.size() == 1 && input2Extents.size() == 1 &&
+           "hlfir.dot_product arguments must be vectors");
+    llvm::SmallVector<mlir::Value, 1> extent =
+        fir::factory::deduceOptimalExtents(input1Extents, input2Extents);
+    return extent[0];
+  }
+};
+
 class SimplifyHLFIRIntrinsics
     : public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
 public:
@@ -939,6 +1012,8 @@ class SimplifyHLFIRIntrinsics
     if (forceMatmulAsElemental || this->allowNewSideEffects)
       patterns.insert<MatmulConversion<hlfir::MatmulOp>>(context);
 
+    patterns.insert<DotProductConversion>(context);
+
     if (mlir::failed(mlir::applyPatternsGreedily(
             getOperation(), std::move(patterns), config))) {
       mlir::emitError(getOperation()->getLoc(),
diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-dotproduct.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-dotproduct.fir
new file mode 100644
index 00000000000000..f59b1422dbc849
--- /dev/null
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-dotproduct.fir
@@ -0,0 +1,144 @@
+// Test hlfir.dot_product simplification to a reduction loop:
+// RUN: fir-opt --simplify-hlfir-intrinsics %s | FileCheck %s
+
+func.func @dot_product_integer(%arg0: !hlfir.expr<?xi16>, %arg1: !hlfir.expr<?xi32>) -> i32 {
+  %res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<?xi16>, !hlfir.expr<?xi32>) -> i32
+  return %res : i32
+}
+// CHECK-LABEL:   func.func @dot_product_integer(
+// CHECK-SAME:                                   %[[VAL_0:.*]]: !hlfir.expr<?xi16>,
+// CHECK-SAME:                                   %[[VAL_1:.*]]: !hlfir.expr<?xi32>) -> i32 {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?xi16>) -> !fir.shape<1>
+// CHECK:           %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<1>) -> index
+// CHECK:           %[[VAL_6:.*]] = fir.do_loop %[[VAL_7:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] unordered iter_args(%[[VAL_8:.*]] = %[[VAL_3]]) -> (i32) {
+// CHECK:             %[[VAL_9:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_7]] : (!hlfir.expr<?xi16>, index) -> i16
+// CHECK:             %[[VAL_10:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_7]] : (!hlfir.expr<?xi32>, index) -> i32
+// CHECK:             %[[VAL_11:.*]] = fir.convert %[[VAL_9]] : (i16) -> i32
+// CHECK:             %[[VAL_12:.*]] = arith.muli %[[VAL_11]], %[[VAL_10]] : i32
+// CHECK:             %[[VAL_13:.*]] = arith.addi %[[VAL_8]], %[[VAL_12]] : i32
+// CHECK:             fir.result %[[VAL_13]] : i32
+// CHECK:           }
+// CHECK:           return %[[VAL_6]] : i32
+// CHECK:         }
+
+func.func @dot_product_real(%arg0: !hlfir.expr<?xf32>, %arg1: !hlfir.expr<?xf16>) -> f32 {
+  %res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<?xf32>, !hlfir.expr<?xf16>) -> f32
+  return %res : f32
+}
+// CHECK-LABEL:   func.func @dot_product_real(
+// CHECK-SAME:                                %[[VAL_0:.*]]: !hlfir.expr<?xf32>,
+// CHECK-SAME:                                %[[VAL_1:.*]]: !hlfir.expr<?xf16>) -> f32 {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?xf32>) -> !fir.shape<1>
+// CHECK:           %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<1>) -> index
+// CHECK:           %[[VAL_6:.*]] = fir.do_loop %[[VAL_7:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_8:.*]] = %[[VAL_3]]) -> (f32) {
+// CHECK:             %[[VAL_9:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_7]] : (!hlfir.expr<?xf32>, index) -> f32
+// CHECK:             %[[VAL_10:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_7]] : (!hlfir.expr<?xf16>, index) -> f16
+// CHECK:             %[[VAL_11:.*]] = fir.convert %[[VAL_10]] : (f16) -> f32
+// CHECK:             %[[VAL_12:.*]] = arith.mulf %[[VAL_9]], %[[VAL_11]] : f32
+// CHECK:             %[[VAL_13:.*]] = arith.addf %[[VAL_8]], %[[VAL_12]] : f32
+// CHECK:             fir.result %[[VAL_13]] : f32
+// CHECK:           }
+// CHECK:           return %[[VAL_6]] : f32
+// CHECK:         }
+
+func.func @dot_product_complex(%arg0: !hlfir.expr<?xcomplex<f32>>, %arg1: !hlfir.expr<?xcomplex<f16>>) -> complex<f32> {
+  %res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<?xcomplex<f32>>, !hlfir.expr<?xcomplex<f16>>) -> complex<f32>
+  return %res : complex<f32>
+}
+// CHECK-LABEL:   func.func @dot_product_complex(
+// CHECK-SAME:                                   %[[VAL_0:.*]]: !hlfir.expr<?xcomplex<f32>>,
+// CHECK-SAME:                                   %[[VAL_1:.*]]: !hlfir.expr<?xcomplex<f16>>) -> complex<f32> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?xcomplex<f32>>) -> !fir.shape<1>
+// CHECK:           %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<1>) -> index
+// CHECK:           %[[VAL_6:.*]] = fir.undefined complex<f32>
+// CHECK:           %[[VAL_7:.*]] = fir.insert_value %[[VAL_6]], %[[VAL_3]], [0 : index] : (complex<f32>, f32) -> complex<f32>
+// CHECK:           %[[VAL_8:.*]] = fir.insert_value %[[VAL_7]], %[[VAL_3]], [1 : index] : (complex<f32>, f32) -> complex<f32>
+// CHECK:           %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (complex<f32>) {
+// CHECK:             %[[VAL_12:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_10]] : (!hlfir.expr<?xcomplex<f32>>, index) -> complex<f32>
+// CHECK:             %[[VAL_13:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_10]] : (!hlfir.expr<?xcomplex<f16>>, index) -> complex<f16>
+// CHECK:             %[[VAL_14:.*]] = fir.convert %[[VAL_13]] : (complex<f16>) -> complex<f32>
+// CHECK:             %[[VAL_15:.*]] = fir.extract_value %[[VAL_12]], [1 : index] : (complex<f32>) -> f32
+// CHECK:             %[[VAL_16:.*]] = arith.negf %[[VAL_15]] : f32
+// CHECK:             %[[VAL_17:.*]] = fir.insert_value %[[VAL_12]], %[[VAL_16]], [1 : index] : (complex<f32>, f32) -> complex<f32>
+// CHECK:             %[[VAL_18:.*]] = fir.mulc %[[VAL_17]], %[[VAL_14]] : complex<f32>
+// CHECK:             %[[VAL_19:.*]] = fir.addc %[[VAL_11]], %[[VAL_18]] : complex<f32>
+// CHECK:             fir.result %[[VAL_19]] : complex<f32>
+// CHECK:           }
+// CHECK:           return %[[VAL_9]] : complex<f32>
+// CHECK:         }
+
+func.func @dot_product_real_complex(%arg0: !hlfir.expr<?xf32>, %arg1: !hlfir.expr<?xcomplex<f16>>) -> complex<f32> {
+  %res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<?xf32>, !hlfir.expr<?xcomplex<f16>>) -> complex<f32>
+  return %res : complex<f32>
+}
+// CHECK-LABEL:   func.func @dot_product_real_complex(
+// CHECK-SAME:                                        %[[VAL_0:.*]]: !hlfir.expr<?xf32>,
+// CHECK-SAME:                                        %[[VAL_1:.*]]: !hlfir.expr<?xcomplex<f16>>) -> complex<f32> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?xf32>) -> !fir.shape<1>
+// CHECK:           %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<1>) -> index
+// CHECK:           %[[VAL_6:.*]] = fir.undefined complex<f32>
+// CHECK:           %[[VAL_7:.*]] = fir.insert_value %[[VAL_6]], %[[VAL_3]], [0 : index] : (complex<f32>, f32) -> complex<f32>
+// CHECK:           %[[VAL_8:.*]] = fir.insert_value %[[VAL_7]], %[[VAL_3]], [1 : index] : (complex<f32>, f32) -> complex<f32>
+// CHECK:           %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (complex<f32>) {
+// CHECK:             %[[VAL_12:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_10]] : (!hlfir.expr<?xf32>, index) -> f32
+// CHECK:             %[[VAL_13:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_10]] : (!hlfir.expr<?xcomplex<f16>>, index) -> complex<f16>
+// CHECK:             %[[VAL_14:.*]] = fir.undefined complex<f32>
+// CHECK:             %[[VAL_15:.*]] = fir.insert_value %[[VAL_14]], %[[VAL_3]], [0 : index] : (complex<f32>, f32) -> complex<f32>
+// CHECK:             %[[VAL_16:.*]] = fir.insert_value %[[VAL_15]], %[[VAL_3]], [1 : index] : (complex<f32>, f32) -> complex<f32>
+// CHECK:             %[[VAL_17:.*]] = fir.insert_value %[[VAL_16]], %[[VAL_12]], [0 : index] : (complex<f32>, f32) -> complex<f32>
+// CHECK:             %[[VAL_18:.*]] = fir.convert %[[VAL_13]] : (complex<f16>) -> complex<f32>
+// CHECK:             %[[VAL_19:.*]] = fir.extract_value %[[VAL_17]], [1 : index] : (complex<f32>) -> f32
+// CHECK:             %[[VAL_20:.*]] = arith.negf %[[VAL_19]] : f32
+// CHECK:             %[[VAL_21:.*]] = fir.insert_value %[[VAL_17]], %[[VAL_20]], [1 : index] : (complex<f32>, f32) -> complex<f32>
+// CHECK:             %[[VAL_22:.*]] = fir.mulc %[[VAL_21]], %[[VAL_18]] : complex<f32>
+// CHECK:             %[[VAL_23:.*]] = fir.addc %[[VAL_11]], %[[VAL_22]] : complex<f32>
+// CHECK:             fir.result %[[VAL_23]] : complex<f32>
+// CHECK:           }
+// CHECK:           return %[[VAL_9]] : complex<f32>
+// CHECK:         }
+
+func.func @dot_product_logical(%arg0: !hlfir.expr<?x!fir.logical<1>>, %arg1: !hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<4> {
+  %res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<?x!fir.logical<1>>, !hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<4>
+  return %res : !fir.logical<4>
+}
+// CHECK-LABEL:   func.func @dot_product_logical(
+// CHECK-SAME:                                   %[[VAL_0:.*]]: !hlfir.expr<?x!fir.logical<1>>,
+// CHECK-SAME:                                   %[[VAL_1:.*]]: !hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<4> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant false
+// CHECK:           %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x!fir.logical<1>>) -> !fir.shape<1>
+// CHECK:           %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<1>) -> index
+// CHECK:           %[[VAL_6:.*]] = fir.convert %[[VAL_3]] : (i1) -> !fir.logical<4>
+// CHECK:           %[[VAL_7:.*]] = fir.do_loop %[[VAL_8:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] unordered iter_args(%[[VAL_9:.*]] = %[[VAL_6]]) -> (!fir.logical<4>) {
+// CHECK:             %[[VAL_10:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_8]] : (!hlfir.expr<?x!fir.logical<1>>, index) -> !fir.logical<1>
+// CHECK:             %[[VAL_11:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_8]] : (!hlfir.expr<?x!fir.logical<4>>, index) -> !fir.logical<4>
+// CHECK:             %[[VAL_12:.*]] = fir.convert %[[VAL_9]] : (!fir.logical<4>) -> i1
+// CHECK:             %[[VAL_13:.*]] = fir.convert %[[VAL_10]] : (!fir.logical<1>) -> i1
+// CHECK:             %[[VAL_14:.*]] = fir.convert %[[VAL_11]] : (!fir.logical<4>) -> i1
+// CHECK:             %[[VAL_15:.*]] = arith.andi %[[VAL_13]], %[[VAL_14]] : i1
+// CHECK:             %[[VAL_16:.*]] = arith.ori %[[VAL_12]], %[[VAL_15]] : i1
+// CHECK:             %[[VAL_17:.*]] = fir.convert %[[VAL_16]] : (i1) -> !fir.logical<4>
+// CHECK:             fir.result %[[VAL_17]] : !fir.logical<4>
+// CHECK:           }
+// CHECK:           return %[[VAL_7]] : !fir.logical<4>
+// CHECK:         }
+
+func.func @dot_product_known_dim(%arg0: !hlfir.expr<10xf32>, %arg1: !hlfir.expr<?xi16>) -> f32 {
+  %res1 = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<10xf32>, !hlfir.expr<?xi16>) -> f32
+  %res2 = hlfir.dot_product %arg1 %arg0 : (!hlfir.expr<?xi16>, !hlfir.expr<10xf32>) -> f32
+  %res = arith.addf %res1, %res2 : f32
+  return %res : f32
+}
+// CHECK-LABEL:   func.func @dot_product_known_dim(
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 10 : index
+// CHECK:           fir.do_loop %{{.*}} = %[[VAL_2]] to %[[VAL_4]] step %[[VAL_2]]
+// CHECK:           fir.do_loop %{{.*}} = %[[VAL_2]] to %[[VAL_4]] step %[[VAL_2]]



More information about the flang-commits mailing list