[flang-commits] [flang] [flang] Inline hlfir.matmul[_transpose]. (PR #122821)

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Tue Jan 14 10:10:18 PST 2025


================
@@ -467,9 +474,431 @@ class CShiftAsElementalConversion
   }
 };
 
+template <typename Op>
+class MatmulConversion : public mlir::OpRewritePattern<Op> {
+public:
+  using mlir::OpRewritePattern<Op>::OpRewritePattern;
+
+  llvm::LogicalResult
+  matchAndRewrite(Op matmul, mlir::PatternRewriter &rewriter) const override {
+    mlir::Location loc = matmul.getLoc();
+    fir::FirOpBuilder builder{rewriter, matmul.getOperation()};
+    hlfir::Entity lhs = hlfir::Entity{matmul.getLhs()};
+    hlfir::Entity rhs = hlfir::Entity{matmul.getRhs()};
+    mlir::Value resultShape, innerProductExtent;
+    std::tie(resultShape, innerProductExtent) =
+        genResultShape(loc, builder, lhs, rhs);
+
+    if (forceMatmulAsElemental || isMatmulTranspose) {
+      // Generate hlfir.elemental that produces the result of
+      // MATMUL/MATMUL(TRANSPOSE).
+      // Note that this implementation is very suboptimal for MATMUL,
+      // but is quite good for MATMUL(TRANSPOSE), e.g.:
+      //   R(1:N) = R(1:N) + MATMUL(TRANSPOSE(X(1:N,1:N)), Y(1:N))
+      // Inlining MATMUL(TRANSPOSE) as hlfir.elemental may result
+      // in merging the inner product computation with the elemental
+      // addition. Note that the inner product computation will
+      // benefit from processing the lowermost dimensions of X and Y,
+      // which may be the best when they are contiguous.
+      //
+      // This is why we always inline MATMUL(TRANSPOSE) as an elemental.
+      // MATMUL is inlined below by default unless forceMatmulAsElemental.
+      hlfir::ExprType resultType =
+          mlir::cast<hlfir::ExprType>(matmul.getType());
+      hlfir::ElementalOp newOp = genElementalMatmul(
+          loc, builder, resultType, resultShape, lhs, rhs, innerProductExtent);
+      rewriter.replaceOp(matmul, newOp);
+      return mlir::success();
+    }
+
+    // Generate hlfir.eval_in_mem to mimic the MATMUL implementation
+    // from Fortran runtime. The implementation needs to operate
+    // with the result array as an in-memory object.
+    hlfir::EvaluateInMemoryOp evalOp =
+        builder.create<hlfir::EvaluateInMemoryOp>(
+            loc, mlir::cast<hlfir::ExprType>(matmul.getType()), resultShape);
+    builder.setInsertionPointToStart(&evalOp.getBody().front());
+
+    // Embox the raw array pointer to simplify designating it.
+    // TODO: this currently results in redundant lower bounds
+    // addition for the designator, but this should be fixed in
+    // hlfir::Entity::mayHaveNonDefaultLowerBounds().
+    mlir::Value resultArray = evalOp.getMemory();
+    mlir::Type arrayType = fir::dyn_cast_ptrEleTy(resultArray.getType());
+    resultArray = builder.createBox(loc, fir::BoxType::get(arrayType),
+                                    resultArray, resultShape, /*slice=*/nullptr,
+                                    /*lengths=*/{}, /*tdesc=*/nullptr);
+
+    // The contiguous MATMUL version is best for the cases
+    // where the input arrays and (maybe) the result are contiguous
+    // in their lowermost dimensions.
+    // Especially, when LLVM can recognize the continuity
+    // and vectorize the loops properly.
+    // TODO: we need to recognize the cases when the continuity
+    // is not statically obvious and try to generate an explicitly
+    // continuous version under a dynamic check. The fallback
+    // implementation may use genElementalMatmul() with
+    // an hlfir.assign into the result of eval_in_mem.
+    mlir::LogicalResult rewriteResult =
+        genContiguousMatmul(loc, builder, hlfir::Entity{resultArray},
+                            resultShape, lhs, rhs, innerProductExtent);
----------------
vzakhari wrote:

Yep, it is the former :) I will update the comment.

https://github.com/llvm/llvm-project/pull/122821


More information about the flang-commits mailing list