[flang-commits] [flang] [flang] Inline hlfir.dot_product. (PR #123143)

via flang-commits flang-commits at lists.llvm.org
Thu Jan 16 05:56:37 PST 2025


================
@@ -904,6 +904,79 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
   }
 };
 
+class DotProductConversion
+    : public mlir::OpRewritePattern<hlfir::DotProductOp> {
+public:
+  using mlir::OpRewritePattern<hlfir::DotProductOp>::OpRewritePattern;
+
+  llvm::LogicalResult
+  matchAndRewrite(hlfir::DotProductOp product,
+                  mlir::PatternRewriter &rewriter) const override {
+    hlfir::Entity op = hlfir::Entity{product};
+    if (!op.isScalar())
+      return rewriter.notifyMatchFailure(product, "produces non-scalar result");
+
+    mlir::Location loc = product.getLoc();
+    fir::FirOpBuilder builder{rewriter, product.getOperation()};
+    hlfir::Entity lhs = hlfir::Entity{product.getLhs()};
+    hlfir::Entity rhs = hlfir::Entity{product.getRhs()};
+    mlir::Type resultElementType = product.getType();
+    bool isUnordered = mlir::isa<mlir::IntegerType>(resultElementType) ||
+                       mlir::isa<fir::LogicalType>(resultElementType) ||
+                       static_cast<bool>(builder.getFastMathFlags() &
+                                         mlir::arith::FastMathFlags::reassoc);
+
+    mlir::Value extent = genProductExtent(loc, builder, lhs, rhs);
+
+    auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+                       mlir::ValueRange oneBasedIndices,
+                       mlir::ValueRange reductionArgs)
+        -> llvm::SmallVector<mlir::Value, 1> {
+      hlfir::Entity lhsElementValue =
+          hlfir::loadElementAt(loc, builder, lhs, oneBasedIndices);
+      hlfir::Entity rhsElementValue =
+          hlfir::loadElementAt(loc, builder, rhs, oneBasedIndices);
+      mlir::Value productValue =
+          ProductFactory{loc, builder}.genAccumulateProduct</*CONJ=*/true>(
+              reductionArgs[0], lhsElementValue, rhsElementValue);
+      return {productValue};
+    };
+
+    mlir::Value initValue =
+        fir::factory::createZeroValue(builder, loc, resultElementType);
+
+    llvm::SmallVector<mlir::Value, 1> result = hlfir::genLoopNestWithReductions(
+        loc, builder, {extent},
+        /*reductionInits=*/{initValue}, genBody, isUnordered);
+
+    rewriter.replaceOp(product, result[0]);
+    return mlir::success();
+  }
+
+private:
+  static mlir::Value genProductExtent(mlir::Location loc,
+                                      fir::FirOpBuilder &builder,
+                                      hlfir::Entity input1,
+                                      hlfir::Entity input2) {
+    mlir::Value input1Shape = hlfir::genShape(loc, builder, input1);
+    llvm::SmallVector<mlir::Value, 1> input1Extents =
+        hlfir::getExplicitExtentsFromShape(input1Shape, builder);
+    if (input1Shape.getUses().empty())
+      input1Shape.getDefiningOp()->erase();
----------------
jeanPerier wrote:

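Maybe this pattern (generating the shape, extracting its explicit extents, and erasing the shape op if it ends up unused) should be placed in an hlfir helper: the same will be needed for many intrinsics, and the op erasure is a bit distracting here.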

https://github.com/llvm/llvm-project/pull/123143


More information about the flang-commits mailing list