[flang-commits] [flang] 43eb96c - [flang][hlfir] lower hlfir.matmul_transpose to runtime call
Tom Eccles via flang-commits
flang-commits at lists.llvm.org
Fri Mar 17 02:31:02 PDT 2023
Author: Tom Eccles
Date: 2023-03-17T09:30:04Z
New Revision: 43eb96cab8f734a740b86c553f48f96fa9d19edc
URL: https://github.com/llvm/llvm-project/commit/43eb96cab8f734a740b86c553f48f96fa9d19edc
DIFF: https://github.com/llvm/llvm-project/commit/43eb96cab8f734a740b86c553f48f96fa9d19edc.diff
LOG: [flang][hlfir] lower hlfir.matmul_transpose to runtime call
Depends on D145960
Reviewed By: jeanPerier
Differential Revision: https://reviews.llvm.org/D145961
Added:
Modified:
flang/include/flang/Optimizer/Builder/Runtime/Transformational.h
flang/lib/Optimizer/Builder/IntrinsicCall.cpp
flang/lib/Optimizer/Builder/Runtime/Transformational.cpp
flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
flang/test/HLFIR/mul_transpose.f90
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Builder/Runtime/Transformational.h b/flang/include/flang/Optimizer/Builder/Runtime/Transformational.h
index f084e2eb7ae3a..ae0a0979902f5 100644
--- a/flang/include/flang/Optimizer/Builder/Runtime/Transformational.h
+++ b/flang/include/flang/Optimizer/Builder/Runtime/Transformational.h
@@ -55,6 +55,10 @@ void genMatmul(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value matrixABox, mlir::Value matrixBBox,
mlir::Value resultBox);
+void genMatmulTranspose(fir::FirOpBuilder &builder, mlir::Location loc,
+ mlir::Value resultBox, mlir::Value matrixABox,
+ mlir::Value matrixBBox); // Names follow the definition: result box first.
+
void genPack(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value resultBox, mlir::Value arrayBox, mlir::Value maskBox,
mlir::Value vectorBox);
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index bfc36b8a8f90f..b933603484581 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -269,6 +269,8 @@ struct IntrinsicLibrary {
template <typename Shift>
mlir::Value genMask(mlir::Type, llvm::ArrayRef<mlir::Value>);
fir::ExtendedValue genMatmul(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
+ fir::ExtendedValue genMatmulTranspose(mlir::Type,
+ llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genMaxloc(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genMaxval(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genMerge(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
@@ -679,6 +681,10 @@ static constexpr IntrinsicHandler handlers[]{
&I::genMatmul,
{{{"matrix_a", asAddr}, {"matrix_b", asAddr}}},
/*isElemental=*/false},
+ {"matmul_transpose",
+ &I::genMatmulTranspose,
+ {{{"matrix_a", asAddr}, {"matrix_b", asAddr}}},
+ /*isElemental=*/false},
{"max", &I::genExtremum<Extremum::Max, ExtremumBehavior::MinMaxss>},
{"maxloc",
&I::genMaxloc,
@@ -4015,6 +4021,33 @@ IntrinsicLibrary::genMatmul(mlir::Type resultType,
return readAndAddCleanUp(resultMutableBox, resultType, "MATMUL");
}
+// MATMUL_TRANSPOSE - lowered to a call to the MatmulTranspose runtime routine.
+fir::ExtendedValue
+IntrinsicLibrary::genMatmulTranspose(mlir::Type resultType,
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 2); // matrix_a and matrix_b
+
+ // Box both required operands; the result is rank 1 if either operand is rank 1, otherwise rank 2.
+ fir::BoxValue matrixTmpA = builder.createBox(loc, args[0]);
+ mlir::Value matrixA = fir::getBase(matrixTmpA);
+ fir::BoxValue matrixTmpB = builder.createBox(loc, args[1]);
+ mlir::Value matrixB = fir::getBase(matrixTmpB);
+ unsigned resultRank =
+ (matrixTmpA.rank() == 1 || matrixTmpB.rank() == 1) ? 1 : 2;
+
+ // Create mutable fir.box to be passed to the runtime for the result.
+ mlir::Type resultArrayType = builder.getVarLenSeqTy(resultType, resultRank);
+ fir::MutableBoxValue resultMutableBox =
+ fir::factory::createTempMutableBox(builder, loc, resultArrayType);
+ mlir::Value resultIrBox =
+ fir::factory::getMutableIRBox(builder, loc, resultMutableBox);
+ // Call the runtime, which allocates the result; the result box is passed first, then A and B.
+ fir::runtime::genMatmulTranspose(builder, loc, resultIrBox, matrixA, matrixB);
+ // Read result from mutable fir.box and add it to the list of temps to be
+ // finalized by the StatementContext.
+ return readAndAddCleanUp(resultMutableBox, resultType, "MATMUL_TRANSPOSE");
+}
+
// MERGE
fir::ExtendedValue
IntrinsicLibrary::genMerge(mlir::Type,
diff --git a/flang/lib/Optimizer/Builder/Runtime/Transformational.cpp b/flang/lib/Optimizer/Builder/Runtime/Transformational.cpp
index c51f38672d25f..3ae097743e3ad 100644
--- a/flang/lib/Optimizer/Builder/Runtime/Transformational.cpp
+++ b/flang/lib/Optimizer/Builder/Runtime/Transformational.cpp
@@ -13,6 +13,7 @@
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
#include "flang/Optimizer/Builder/Todo.h"
+#include "flang/Runtime/matmul-transpose.h"
#include "flang/Runtime/matmul.h"
#include "flang/Runtime/transformational.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -351,6 +352,23 @@ void fir::runtime::genMatmul(fir::FirOpBuilder &builder, mlir::Location loc,
builder.create<fir::CallOp>(loc, func, args);
}
+/// Generate a call to the MatmulTranspose runtime routine; \p resultBox is passed first, then both operand boxes.
+void fir::runtime::genMatmulTranspose(fir::FirOpBuilder &builder,
+ mlir::Location loc, mlir::Value resultBox,
+ mlir::Value matrixABox,
+ mlir::Value matrixBBox) {
+ auto func =
+ fir::runtime::getRuntimeFunc<mkRTKey(MatmulTranspose)>(loc, builder);
+ auto fTy = func.getFunctionType();
+ auto sourceFile = fir::factory::locationToFilename(builder, loc);
+ auto sourceLine =
+ fir::factory::locationToLineNo(builder, loc, fTy.getInput(4)); // input 4 = source line, the fifth argument below
+ auto args =
+ fir::runtime::createArguments(builder, loc, fTy, resultBox, matrixABox,
+ matrixBBox, sourceFile, sourceLine);
+ builder.create<fir::CallOp>(loc, func, args);
+}
+
/// Generate call to Pack intrinsic runtime routine.
void fir::runtime::genPack(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value resultBox, mlir::Value arrayBox,
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
index 59e5b3b53c6ac..95192a3e79d49 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
@@ -257,6 +257,39 @@ class TransposeOpConversion
}
};
+struct MatmulTransposeOpConversion // Rewrites hlfir.matmul_transpose into the "matmul_transpose" intrinsic call.
+ : public HlfirIntrinsicConversion<hlfir::MatmulTransposeOp> {
+ using HlfirIntrinsicConversion<
+ hlfir::MatmulTransposeOp>::HlfirIntrinsicConversion;
+
+ mlir::LogicalResult
+ matchAndRewrite(hlfir::MatmulTransposeOp multranspose,
+ mlir::PatternRewriter &rewriter) const override {
+ fir::KindMapping kindMapping{rewriter.getContext()};
+ fir::FirOpBuilder builder{rewriter, kindMapping};
+ const mlir::Location &loc = multranspose->getLoc();
+
+ mlir::Value lhs = multranspose.getLhs();
+ mlir::Value rhs = multranspose.getRhs();
+ llvm::SmallVector<IntrinsicArgument, 2> inArgs;
+ inArgs.push_back({lhs, lhs.getType()});
+ inArgs.push_back({rhs, rhs.getType()});
+
+ auto *argLowering = fir::getIntrinsicArgumentLowering("matmul_transpose"); // same intrinsic name as the call generated below
+ llvm::SmallVector<fir::ExtendedValue, 2> args =
+ lowerArguments(multranspose, inArgs, rewriter, argLowering);
+
+ mlir::Type scalarResultType =
+ hlfir::getFortranElementType(multranspose.getType()); // the intrinsic call is given the element type, not the array type
+
+ auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(
+ builder, loc, "matmul_transpose", scalarResultType, args);
+
+ processReturnValue(multranspose, resultExv, mustBeFreed, builder, rewriter);
+ return mlir::success();
+ }
+};
+
class LowerHLFIRIntrinsics
: public hlfir::impl::LowerHLFIRIntrinsicsBase<LowerHLFIRIntrinsics> {
public:
@@ -271,13 +304,14 @@ class LowerHLFIRIntrinsics
mlir::ModuleOp module = this->getOperation();
mlir::MLIRContext *context = &getContext();
mlir::RewritePatternSet patterns(context);
- patterns.insert<MatmulOpConversion, SumOpConversion, TransposeOpConversion>(
- context);
+ patterns.insert<MatmulOpConversion, MatmulTransposeOpConversion,
+ SumOpConversion, TransposeOpConversion>(context);
mlir::ConversionTarget target(*context);
target.addLegalDialect<mlir::BuiltinDialect, mlir::arith::ArithDialect,
mlir::func::FuncDialect, fir::FIROpsDialect,
hlfir::hlfirDialect>();
- target.addIllegalOp<hlfir::MatmulOp, hlfir::SumOp, hlfir::TransposeOp>();
+ target.addIllegalOp<hlfir::MatmulOp, hlfir::MatmulTransposeOp, hlfir::SumOp,
+ hlfir::TransposeOp>();
target.markUnknownOpDynamicallyLegal(
[](mlir::Operation *) { return true; });
if (mlir::failed(
diff --git a/flang/test/HLFIR/mul_transpose.f90 b/flang/test/HLFIR/mul_transpose.f90
index ab3742f88043e..c79c742efc441 100644
--- a/flang/test/HLFIR/mul_transpose.f90
+++ b/flang/test/HLFIR/mul_transpose.f90
@@ -1,6 +1,7 @@
! RUN: bbc -emit-fir -hlfir %s -o - | FileCheck --check-prefix CHECK-BASE --check-prefix CHECK-ALL %s
! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --canonicalize | FileCheck --check-prefix CHECK-CANONICAL --check-prefix CHECK-ALL %s
! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --lower-hlfir-intrinsics | FileCheck --check-prefix CHECK-LOWERING --check-prefix CHECK-ALL %s
+! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --canonicalize | fir-opt --lower-hlfir-intrinsics | FileCheck --check-prefix CHECK-LOWERING-OPT --check-prefix CHECK-ALL %s
! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --lower-hlfir-intrinsics | fir-opt --bufferize-hlfir | FileCheck --check-prefix CHECK-BUFFERING --check-prefix CHECK-ALL %s
! Test passing a hlfir.expr from one intrinsic to another
@@ -56,6 +57,20 @@ subroutine mul_transpose(a, b, res)
! CHECK-LOWERING-NEXT: hlfir.destroy %[[MUL_EXPR]]
! CHECK-LOWERING-NEXT: hlfir.destroy %[[TRANSPOSE_EXPR]]
+! CHECK-LOWERING-OPT: %[[LHS_BOX:.*]] = fir.embox %[[A_DECL]]#1(%{{.*}})
+! CHECK-LOWERING-OPT: %[[B_BOX:.*]] = fir.embox %[[B_DECL]]#1(%{{.*}})
+! CHECK-LOWERING-OPT: %[[MUL_CONV_RES:.*]] = fir.convert %[[MUL_RES_BOX:.*]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<none>>
+! CHECK-LOWERING-OPT: %[[LHS_CONV:.*]] = fir.convert %[[LHS_BOX]] : (!fir.box<!fir.array<2x1xf32>>) -> !fir.box<none>
+! CHECK-LOWERING-OPT: %[[B_BOX_CONV:.*]] = fir.convert %[[B_BOX]] : (!fir.box<!fir.array<2x2xf32>>) -> !fir.box<none>
+! CHECK-LOWERING-OPT: fir.call @_FortranAMatmulTranspose(%[[MUL_CONV_RES]], %[[LHS_CONV]], %[[B_BOX_CONV]], %[[LOC_STR2:.*]], %[[LOC_N2:.*]])
+! CHECK-LOWERING-OPT: %[[MUL_RES_LD:.*]] = fir.load %[[MUL_RES_BOX:.*]]
+! CHECK-LOWERING-OPT: %[[MUL_RES_ADDR:.*]] = fir.box_addr %[[MUL_RES_LD]]
+! CHECK-LOWERING-OPT: %[[MUL_RES_VAR:.*]]:2 = hlfir.declare %[[MUL_RES_ADDR]]({{.*}}) {uniq_name = ".tmp.intrinsic_result"}
+! CHECK-LOWERING-OPT: %[[TRUE2:.*]] = arith.constant true
+! CHECK-LOWERING-OPT: %[[MUL_EXPR:.*]] = hlfir.as_expr %[[MUL_RES_VAR]]#0 move %[[TRUE2]] : (!fir.box<!fir.array<?x?xf32>>, i1) -> !hlfir.expr<?x?xf32>
+! CHECK-LOWERING-OPT: hlfir.assign %[[MUL_EXPR]] to %[[RES_DECL]]#0 : !hlfir.expr<?x?xf32>, !fir.ref<!fir.array<1x2xf32>>
+! CHECK-LOWERING-OPT: hlfir.destroy %[[MUL_EXPR]]
+
! [argument handling unchanged]
! CHECK-BUFFERING: fir.call @_FortranATranspose(
! CHECK-BUFFERING: %[[TRANSPOSE_RES_LD:.*]] = fir.load %[[TRANSPOSE_RES_BOX:.*]]
More information about the flang-commits
mailing list