[flang-commits] [flang] [flang] Inline hlfir.matmul[_transpose]. (PR #122821)
Slava Zakharin via flang-commits
flang-commits at lists.llvm.org
Wed Jan 15 08:36:09 PST 2025
================
@@ -467,9 +474,431 @@ class CShiftAsElementalConversion
}
};
+template <typename Op>
+class MatmulConversion : public mlir::OpRewritePattern<Op> {
+public:
+ using mlir::OpRewritePattern<Op>::OpRewritePattern;
+
+ // Rewrite hlfir.matmul / hlfir.matmul_transpose into inline HLFIR.
+ // Two strategies are used:
+ // * an hlfir.elemental whose body computes one result element at a
+ // time (always for MATMUL(TRANSPOSE), optionally for MATMUL), or
+ // * an hlfir.eval_in_mem holding explicit loops over an in-memory
+ // result temporary (the default for MATMUL).
+ // Returns success() when the op was replaced, or a match failure
+ // when the contiguous lowering could not be generated.
+ llvm::LogicalResult
+ matchAndRewrite(Op matmul, mlir::PatternRewriter &rewriter) const override {
+ mlir::Location loc = matmul.getLoc();
+ fir::FirOpBuilder builder{rewriter, matmul.getOperation()};
+ hlfir::Entity lhs = hlfir::Entity{matmul.getLhs()};
+ hlfir::Entity rhs = hlfir::Entity{matmul.getRhs()};
+ // Compute the fir.shape of the result and the shared extent of the
+ // dimensions reduced by the inner product (see genResultShape).
+ mlir::Value resultShape, innerProductExtent;
+ std::tie(resultShape, innerProductExtent) =
+ genResultShape(loc, builder, lhs, rhs);
+
+ // NOTE(review): forceMatmulAsElemental is declared outside this hunk —
+ // presumably a pass/pattern option; confirm against the full file.
+ if (forceMatmulAsElemental || isMatmulTranspose) {
+ // Generate hlfir.elemental that produces the result of
+ // MATMUL/MATMUL(TRANSPOSE).
+ // Note that this implementation is very suboptimal for MATMUL,
+ // but is quite good for MATMUL(TRANSPOSE), e.g.:
+ // R(1:N) = R(1:N) + MATMUL(TRANSPOSE(X(1:N,1:N)), Y(1:N))
+ // Inlining MATMUL(TRANSPOSE) as hlfir.elemental may result
+ // in merging the inner product computation with the elemental
+ // addition. Note that the inner product computation will
+ // benefit from processing the lowermost dimensions of X and Y,
+ // which may be the best when they are contiguous.
+ //
+ // This is why we always inline MATMUL(TRANSPOSE) as an elemental.
+ // MATMUL is inlined below by default unless forceMatmulAsElemental.
+ hlfir::ExprType resultType =
+ mlir::cast<hlfir::ExprType>(matmul.getType());
+ hlfir::ElementalOp newOp = genElementalMatmul(
+ loc, builder, resultType, resultShape, lhs, rhs, innerProductExtent);
+ rewriter.replaceOp(matmul, newOp);
+ return mlir::success();
+ }
+
+ // Generate hlfir.eval_in_mem to mimic the MATMUL implementation
+ // from Fortran runtime. The implementation needs to operate
+ // with the result array as an in-memory object.
+ hlfir::EvaluateInMemoryOp evalOp =
+ builder.create<hlfir::EvaluateInMemoryOp>(
+ loc, mlir::cast<hlfir::ExprType>(matmul.getType()), resultShape);
+ // All the ops generated below go into the eval_in_mem body region.
+ builder.setInsertionPointToStart(&evalOp.getBody().front());
+
+ // Embox the raw array pointer to simplify designating it.
+ // TODO: this currently results in redundant lower bounds
+ // addition for the designator, but this should be fixed in
+ // hlfir::Entity::mayHaveNonDefaultLowerBounds().
+ mlir::Value resultArray = evalOp.getMemory();
+ mlir::Type arrayType = fir::dyn_cast_ptrEleTy(resultArray.getType());
+ resultArray = builder.createBox(loc, fir::BoxType::get(arrayType),
+ resultArray, resultShape, /*slice=*/nullptr,
+ /*lengths=*/{}, /*tdesc=*/nullptr);
+
+ // The contiguous MATMUL version is best for the cases
+ // where the input arrays and (maybe) the result are contiguous
+ // in their lowermost dimensions.
+ // Especially, when LLVM can recognize the contiguity
+ // and vectorize the loops properly.
+ // TODO: we need to recognize the cases when the contiguity
+ // is not statically obvious and try to generate an explicitly
+ // contiguous version under a dynamic check. The fallback
+ // implementation may use genElementalMatmul() with
+ // an hlfir.assign into the result of eval_in_mem.
+ mlir::LogicalResult rewriteResult =
+ genContiguousMatmul(loc, builder, hlfir::Entity{resultArray},
+ resultShape, lhs, rhs, innerProductExtent);
+
+ if (mlir::failed(rewriteResult)) {
+ // Erase the unclaimed eval_in_mem op.
+ rewriter.eraseOp(evalOp);
+ return rewriter.notifyMatchFailure(matmul,
+ "genContiguousMatmul() failed");
+ }
+
+ rewriter.replaceOp(matmul, evalOp);
+ return mlir::success();
+ }
+
+private:
+ static constexpr bool isMatmulTranspose =
+ std::is_same_v<Op, hlfir::MatmulTransposeOp>;
+
+ // Return a tuple of:
+ // * A fir.shape operation representing the shape of the result
+ // of a MATMUL/MATMUL(TRANSPOSE).
+ // * An extent of the dimensions of the input array
+ // that are processed during the inner product computation.
+ static std::tuple<mlir::Value, mlir::Value>
+ genResultShape(mlir::Location loc, fir::FirOpBuilder &builder,
+ hlfir::Entity input1, hlfir::Entity input2) {
+ mlir::Value input1Shape = hlfir::genShape(loc, builder, input1);
+ llvm::SmallVector<mlir::Value> input1Extents =
+ hlfir::getExplicitExtentsFromShape(input1Shape, builder);
+ // If only the extents were needed, drop the otherwise-dead
+ // shape-defining op instead of leaving it in the IR.
+ if (input1Shape.getUses().empty())
+ input1Shape.getDefiningOp()->erase();
+ mlir::Value input2Shape = hlfir::genShape(loc, builder, input2);
+ llvm::SmallVector<mlir::Value> input2Extents =
+ hlfir::getExplicitExtentsFromShape(input2Shape, builder);
+ if (input2Shape.getUses().empty())
+ input2Shape.getDefiningOp()->erase();
+
+ llvm::SmallVector<mlir::Value, 2> newExtents;
+ mlir::Value innerProduct1Extent, innerProduct2Extent;
+ // Case 1: vector * matrix — the result is a vector whose extent is
+ // the number of columns of the matrix operand.
+ if (input1Extents.size() == 1) {
+ assert(!isMatmulTranspose &&
+ "hlfir.matmul_transpose's first operand must be rank-2 array");
+ assert(input2Extents.size() == 2 &&
+ "hlfir.matmul second argument must be rank-2 array");
+ newExtents.push_back(input2Extents[1]);
+ innerProduct1Extent = input1Extents[0];
+ innerProduct2Extent = input2Extents[0];
+ } else {
+ // Case 2: matrix * vector — a rank-1 result; for the transposed
+ // form the result extent comes from the first operand's columns.
+ if (input2Extents.size() == 1) {
+ assert(input1Extents.size() == 2 &&
+ "hlfir.matmul first argument must be rank-2 array");
+ if constexpr (isMatmulTranspose)
+ newExtents.push_back(input1Extents[1]);
+ else
+ newExtents.push_back(input1Extents[0]);
+ } else {
+ // Case 3: matrix * matrix — a rank-2 result.
+ assert(input1Extents.size() == 2 && input2Extents.size() == 2 &&
+ "hlfir.matmul arguments must be rank-2 arrays");
+ if constexpr (isMatmulTranspose)
+ newExtents.push_back(input1Extents[1]);
+ else
+ newExtents.push_back(input1Extents[0]);
+
+ newExtents.push_back(input2Extents[1]);
+ }
+ // The reduced dimension of the first operand differs between the
+ // transposed and non-transposed forms.
+ if constexpr (isMatmulTranspose)
+ innerProduct1Extent = input1Extents[0];
+ else
+ innerProduct1Extent = input1Extents[1];
+
+ innerProduct2Extent = input2Extents[0];
+ }
+ // The inner product dimensions of the input arrays
+ // must match. Pick the best (e.g. constant) out of them
+ // so that the inner product loop bound can be used in
+ // optimizations.
+ llvm::SmallVector<mlir::Value> innerProductExtent =
+ fir::factory::deduceOptimalExtents({innerProduct1Extent},
+ {innerProduct2Extent});
+ return {builder.create<fir::ShapeOp>(loc, newExtents),
+ innerProductExtent[0]};
+ }
+
+ static mlir::Value castToProductType(mlir::Location loc,
+ fir::FirOpBuilder &builder,
+ mlir::Value value, mlir::Type type) {
+ if (mlir::isa<fir::LogicalType>(type))
+ return builder.createConvert(loc, builder.getIntegerType(1), value);
+
+ if (fir::isa_complex(type) && !fir::isa_complex(value.getType())) {
+ mlir::Value zeroCmplx = fir::factory::createZeroValue(builder, loc, type);
----------------
vzakhari wrote:
That is true in general, but in the case of `complex * real` it should be correct.
https://github.com/llvm/llvm-project/pull/122821
More information about the flang-commits
mailing list