[flang-commits] [flang] [flang] Inline hlfir.matmul[_transpose]. (PR #122821)
Slava Zakharin via flang-commits
flang-commits at lists.llvm.org
Tue Jan 14 10:10:18 PST 2025
================
@@ -467,9 +474,431 @@ class CShiftAsElementalConversion
}
};
+template <typename Op>
+class MatmulConversion : public mlir::OpRewritePattern<Op> {
+public:
+ using mlir::OpRewritePattern<Op>::OpRewritePattern;
+
+ llvm::LogicalResult
+ matchAndRewrite(Op matmul, mlir::PatternRewriter &rewriter) const override {
+ mlir::Location loc = matmul.getLoc();
+ fir::FirOpBuilder builder{rewriter, matmul.getOperation()};
+ hlfir::Entity lhs = hlfir::Entity{matmul.getLhs()};
+ hlfir::Entity rhs = hlfir::Entity{matmul.getRhs()};
+ mlir::Value resultShape, innerProductExtent;
+ std::tie(resultShape, innerProductExtent) =
+ genResultShape(loc, builder, lhs, rhs);
+
+ if (forceMatmulAsElemental || isMatmulTranspose) {
+ // Generate hlfir.elemental that produces the result of
+ // MATMUL/MATMUL(TRANSPOSE).
+ // Note that this implementation is very suboptimal for MATMUL,
+ // but is quite good for MATMUL(TRANSPOSE), e.g.:
+ // R(1:N) = R(1:N) + MATMUL(TRANSPOSE(X(1:N,1:N)), Y(1:N))
+ // Inlining MATMUL(TRANSPOSE) as hlfir.elemental may result
+ // in merging the inner product computation with the elemental
+ // addition. Note that the inner product computation will
+ // benefit from processing the lowermost dimensions of X and Y,
+ // which may be the best when they are contiguous.
+ //
+ // This is why we always inline MATMUL(TRANSPOSE) as an elemental.
+ // MATMUL is inlined below by default unless forceMatmulAsElemental.
+ hlfir::ExprType resultType =
+ mlir::cast<hlfir::ExprType>(matmul.getType());
+ hlfir::ElementalOp newOp = genElementalMatmul(
+ loc, builder, resultType, resultShape, lhs, rhs, innerProductExtent);
+ rewriter.replaceOp(matmul, newOp);
+ return mlir::success();
+ }
+
+ // Generate hlfir.eval_in_mem to mimic the MATMUL implementation
+ // from Fortran runtime. The implementation needs to operate
+ // with the result array as an in-memory object.
+ hlfir::EvaluateInMemoryOp evalOp =
+ builder.create<hlfir::EvaluateInMemoryOp>(
+ loc, mlir::cast<hlfir::ExprType>(matmul.getType()), resultShape);
+ builder.setInsertionPointToStart(&evalOp.getBody().front());
+
+ // Embox the raw array pointer to simplify designating it.
+ // TODO: this currently results in redundant lower bounds
+ // addition for the designator, but this should be fixed in
+ // hlfir::Entity::mayHaveNonDefaultLowerBounds().
+ mlir::Value resultArray = evalOp.getMemory();
+ mlir::Type arrayType = fir::dyn_cast_ptrEleTy(resultArray.getType());
+ resultArray = builder.createBox(loc, fir::BoxType::get(arrayType),
+ resultArray, resultShape, /*slice=*/nullptr,
+ /*lengths=*/{}, /*tdesc=*/nullptr);
+
+ // The contiguous MATMUL version is best for the cases
+ // where the input arrays and (maybe) the result are contiguous
+ // in their lowermost dimensions.
+ // Especially when LLVM can recognize the contiguity
+ // and vectorize the loops properly.
+ // TODO: we need to recognize the cases when the contiguity
+ // is not statically obvious and try to generate an explicitly
+ // contiguous version under a dynamic check. The fallback
+ // implementation may use genElementalMatmul() with
+ // an hlfir.assign into the result of eval_in_mem.
+ mlir::LogicalResult rewriteResult =
+ genContiguousMatmul(loc, builder, hlfir::Entity{resultArray},
+ resultShape, lhs, rhs, innerProductExtent);
----------------
vzakhari wrote:
Yep, it is the former :) I will update the comment.
https://github.com/llvm/llvm-project/pull/122821
More information about the flang-commits
mailing list