[flang-commits] [flang] 43eb96c - [flang][hlfir] lower hlfir.matmul_transpose to runtime call

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Fri Mar 17 02:31:02 PDT 2023


Author: Tom Eccles
Date: 2023-03-17T09:30:04Z
New Revision: 43eb96cab8f734a740b86c553f48f96fa9d19edc

URL: https://github.com/llvm/llvm-project/commit/43eb96cab8f734a740b86c553f48f96fa9d19edc
DIFF: https://github.com/llvm/llvm-project/commit/43eb96cab8f734a740b86c553f48f96fa9d19edc.diff

LOG: [flang][hlfir] lower hlfir.matmul_transpose to runtime call

Depends on D145960

Reviewed By: jeanPerier

Differential Revision: https://reviews.llvm.org/D145961

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Builder/Runtime/Transformational.h
    flang/lib/Optimizer/Builder/IntrinsicCall.cpp
    flang/lib/Optimizer/Builder/Runtime/Transformational.cpp
    flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
    flang/test/HLFIR/mul_transpose.f90

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Builder/Runtime/Transformational.h b/flang/include/flang/Optimizer/Builder/Runtime/Transformational.h
index f084e2eb7ae3a..ae0a0979902f5 100644
--- a/flang/include/flang/Optimizer/Builder/Runtime/Transformational.h
+++ b/flang/include/flang/Optimizer/Builder/Runtime/Transformational.h
@@ -55,6 +55,10 @@ void genMatmul(fir::FirOpBuilder &builder, mlir::Location loc,
                mlir::Value matrixABox, mlir::Value matrixBBox,
                mlir::Value resultBox);
 
+void genMatmulTranspose(fir::FirOpBuilder &builder, mlir::Location loc,
+                        mlir::Value resultBox, mlir::Value matrixABox,
+                        mlir::Value matrixBBox); // result box comes first
+
 void genPack(fir::FirOpBuilder &builder, mlir::Location loc,
              mlir::Value resultBox, mlir::Value arrayBox, mlir::Value maskBox,
              mlir::Value vectorBox);

diff  --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index bfc36b8a8f90f..b933603484581 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -269,6 +269,8 @@ struct IntrinsicLibrary {
   template <typename Shift>
   mlir::Value genMask(mlir::Type, llvm::ArrayRef<mlir::Value>);
   fir::ExtendedValue genMatmul(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
+  fir::ExtendedValue genMatmulTranspose(mlir::Type,
+                                        llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genMaxloc(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genMaxval(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genMerge(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
@@ -679,6 +681,10 @@ static constexpr IntrinsicHandler handlers[]{
      &I::genMatmul,
      {{{"matrix_a", asAddr}, {"matrix_b", asAddr}}},
      /*isElemental=*/false},
+    {"matmul_transpose",
+     &I::genMatmulTranspose,
+     {{{"matrix_a", asAddr}, {"matrix_b", asAddr}}},
+     /*isElemental=*/false},
     {"max", &I::genExtremum<Extremum::Max, ExtremumBehavior::MinMaxss>},
     {"maxloc",
      &I::genMaxloc,
@@ -4015,6 +4021,33 @@ IntrinsicLibrary::genMatmul(mlir::Type resultType,
   return readAndAddCleanUp(resultMutableBox, resultType, "MATMUL");
 }
 
+// MATMUL_TRANSPOSE
+fir::ExtendedValue
+IntrinsicLibrary::genMatmulTranspose(mlir::Type resultType,
+                                     llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 2);
+
+  // Box both operands, in argument order, before extracting their bases.
+  fir::BoxValue lhsBox = builder.createBox(loc, args[0]);
+  fir::BoxValue rhsBox = builder.createBox(loc, args[1]);
+  mlir::Value lhs = fir::getBase(lhsBox);
+  mlir::Value rhs = fir::getBase(rhsBox);
+  const bool hasRankOneOperand = lhsBox.rank() == 1 || rhsBox.rank() == 1;
+  const unsigned rank = hasRankOneOperand ? 1 : 2;
+
+  // The result goes through a temporary mutable fir.box handed to the runtime.
+  mlir::Type resultSeqTy = builder.getVarLenSeqTy(resultType, rank);
+  fir::MutableBoxValue resultBox =
+      fir::factory::createTempMutableBox(builder, loc, resultSeqTy);
+  mlir::Value resultBoxRef =
+      fir::factory::getMutableIRBox(builder, loc, resultBox);
+  // The runtime call allocates the result storage itself.
+  fir::runtime::genMatmulTranspose(builder, loc, resultBoxRef, lhs, rhs);
+  // Load the result back out of the mutable fir.box and register it with the
+  // temporaries that the StatementContext finalizes.
+  return readAndAddCleanUp(resultBox, resultType, "MATMUL_TRANSPOSE");
+}
+
 // MERGE
 fir::ExtendedValue
 IntrinsicLibrary::genMerge(mlir::Type,

diff  --git a/flang/lib/Optimizer/Builder/Runtime/Transformational.cpp b/flang/lib/Optimizer/Builder/Runtime/Transformational.cpp
index c51f38672d25f..3ae097743e3ad 100644
--- a/flang/lib/Optimizer/Builder/Runtime/Transformational.cpp
+++ b/flang/lib/Optimizer/Builder/Runtime/Transformational.cpp
@@ -13,6 +13,7 @@
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
 #include "flang/Optimizer/Builder/Todo.h"
+#include "flang/Runtime/matmul-transpose.h"
 #include "flang/Runtime/matmul.h"
 #include "flang/Runtime/transformational.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -351,6 +352,23 @@ void fir::runtime::genMatmul(fir::FirOpBuilder &builder, mlir::Location loc,
   builder.create<fir::CallOp>(loc, func, args);
 }
 
+/// Emit the fir.call to the MatmulTranspose runtime entry point.
+void fir::runtime::genMatmulTranspose(fir::FirOpBuilder &builder,
+                                      mlir::Location loc, mlir::Value resultBox,
+                                      mlir::Value matrixABox,
+                                      mlir::Value matrixBBox) {
+  mlir::func::FuncOp callee =
+      fir::runtime::getRuntimeFunc<mkRTKey(MatmulTranspose)>(loc, builder);
+  mlir::FunctionType calleeTy = callee.getFunctionType();
+  mlir::Value file = fir::factory::locationToFilename(builder, loc);
+  // The line number is passed as the fifth call argument, hence input 4.
+  mlir::Value line =
+      fir::factory::locationToLineNo(builder, loc, calleeTy.getInput(4));
+  llvm::SmallVector<mlir::Value> callArgs = fir::runtime::createArguments(
+      builder, loc, calleeTy, resultBox, matrixABox, matrixBBox, file, line);
+  builder.create<fir::CallOp>(loc, callee, callArgs);
+}
+
 /// Generate call to Pack intrinsic runtime routine.
 void fir::runtime::genPack(fir::FirOpBuilder &builder, mlir::Location loc,
                            mlir::Value resultBox, mlir::Value arrayBox,

diff  --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
index 59e5b3b53c6ac..95192a3e79d49 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
@@ -257,6 +257,39 @@ class TransposeOpConversion
   }
 };
 
+struct MatmulTransposeOpConversion
+    : public HlfirIntrinsicConversion<hlfir::MatmulTransposeOp> {
+  using HlfirIntrinsicConversion<
+      hlfir::MatmulTransposeOp>::HlfirIntrinsicConversion;
+  /// Lower hlfir.matmul_transpose through the "matmul_transpose" intrinsic.
+  mlir::LogicalResult
+  matchAndRewrite(hlfir::MatmulTransposeOp multranspose,
+                  mlir::PatternRewriter &rewriter) const override {
+    fir::KindMapping kindMapping{rewriter.getContext()};
+    fir::FirOpBuilder builder{rewriter, kindMapping};
+    mlir::Location loc = multranspose->getLoc();
+    // Each operand is forwarded together with its own type.
+    mlir::Value lhs = multranspose.getLhs();
+    mlir::Value rhs = multranspose.getRhs();
+    llvm::SmallVector<IntrinsicArgument, 2> inArgs;
+    inArgs.push_back({lhs, lhs.getType()});
+    inArgs.push_back({rhs, rhs.getType()});
+    // Use the argument lowering rules registered under the name called below.
+    auto *argLowering = fir::getIntrinsicArgumentLowering("matmul_transpose");
+    llvm::SmallVector<fir::ExtendedValue, 2> args =
+        lowerArguments(multranspose, inArgs, rewriter, argLowering);
+
+    mlir::Type scalarResultType =
+        hlfir::getFortranElementType(multranspose.getType());
+
+    auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(
+        builder, loc, "matmul_transpose", scalarResultType, args);
+
+    processReturnValue(multranspose, resultExv, mustBeFreed, builder, rewriter);
+    return mlir::success();
+  }
+};
+
 class LowerHLFIRIntrinsics
     : public hlfir::impl::LowerHLFIRIntrinsicsBase<LowerHLFIRIntrinsics> {
 public:
@@ -271,13 +304,14 @@ class LowerHLFIRIntrinsics
     mlir::ModuleOp module = this->getOperation();
     mlir::MLIRContext *context = &getContext();
     mlir::RewritePatternSet patterns(context);
-    patterns.insert<MatmulOpConversion, SumOpConversion, TransposeOpConversion>(
-        context);
+    patterns.insert<MatmulOpConversion, MatmulTransposeOpConversion,
+                    SumOpConversion, TransposeOpConversion>(context);
     mlir::ConversionTarget target(*context);
     target.addLegalDialect<mlir::BuiltinDialect, mlir::arith::ArithDialect,
                            mlir::func::FuncDialect, fir::FIROpsDialect,
                            hlfir::hlfirDialect>();
-    target.addIllegalOp<hlfir::MatmulOp, hlfir::SumOp, hlfir::TransposeOp>();
+    target.addIllegalOp<hlfir::MatmulOp, hlfir::MatmulTransposeOp, hlfir::SumOp,
+                        hlfir::TransposeOp>();
     target.markUnknownOpDynamicallyLegal(
         [](mlir::Operation *) { return true; });
     if (mlir::failed(

diff  --git a/flang/test/HLFIR/mul_transpose.f90 b/flang/test/HLFIR/mul_transpose.f90
index ab3742f88043e..c79c742efc441 100644
--- a/flang/test/HLFIR/mul_transpose.f90
+++ b/flang/test/HLFIR/mul_transpose.f90
@@ -1,6 +1,7 @@
 ! RUN: bbc -emit-fir -hlfir %s -o - | FileCheck --check-prefix CHECK-BASE --check-prefix CHECK-ALL %s
 ! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --canonicalize | FileCheck --check-prefix CHECK-CANONICAL --check-prefix CHECK-ALL %s
 ! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --lower-hlfir-intrinsics | FileCheck --check-prefix CHECK-LOWERING --check-prefix CHECK-ALL %s
+! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --canonicalize | fir-opt --lower-hlfir-intrinsics | FileCheck --check-prefix CHECK-LOWERING-OPT --check-prefix CHECK-ALL %s
 ! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --lower-hlfir-intrinsics | fir-opt --bufferize-hlfir | FileCheck --check-prefix CHECK-BUFFERING --check-prefix CHECK-ALL %s
 
 ! Test passing a hlfir.expr from one intrinsic to another
@@ -56,6 +57,20 @@ subroutine mul_transpose(a, b, res)
 ! CHECK-LOWERING-NEXT:  hlfir.destroy %[[MUL_EXPR]]
 ! CHECK-LOWERING-NEXT:  hlfir.destroy %[[TRANSPOSE_EXPR]]
 
+! CHECK-LOWERING-OPT:   %[[LHS_BOX:.*]] = fir.embox %[[A_DECL]]#1(%{{.*}})
+! CHECK-LOWERING-OPT:   %[[B_BOX:.*]] = fir.embox %[[B_DECL]]#1(%{{.*}})
+! CHECK-LOWERING-OPT:   %[[MUL_CONV_RES:.*]] = fir.convert %[[MUL_RES_BOX:.*]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<none>>
+! CHECK-LOWERING-OPT:   %[[LHS_CONV:.*]] = fir.convert %[[LHS_BOX]] : (!fir.box<!fir.array<2x1xf32>>) -> !fir.box<none>
+! CHECK-LOWERING-OPT:   %[[B_BOX_CONV:.*]] = fir.convert %[[B_BOX]] : (!fir.box<!fir.array<2x2xf32>>) -> !fir.box<none>
+! CHECK-LOWERING-OPT:   fir.call @_FortranAMatmulTranspose(%[[MUL_CONV_RES]], %[[LHS_CONV]], %[[B_BOX_CONV]], %[[LOC_STR2:.*]], %[[LOC_N2:.*]])
+! CHECK-LOWERING-OPT:   %[[MUL_RES_LD:.*]] = fir.load %[[MUL_RES_BOX:.*]]
+! CHECK-LOWERING-OPT:   %[[MUL_RES_ADDR:.*]] = fir.box_addr %[[MUL_RES_LD]]
+! CHECK-LOWERING-OPT:   %[[MUL_RES_VAR:.*]]:2 = hlfir.declare %[[MUL_RES_ADDR]]({{.*}}) {uniq_name = ".tmp.intrinsic_result"}
+! CHECK-LOWERING-OPT:   %[[TRUE2:.*]] = arith.constant true
+! CHECK-LOWERING-OPT:   %[[MUL_EXPR:.*]] = hlfir.as_expr %[[MUL_RES_VAR]]#0 move %[[TRUE2]] : (!fir.box<!fir.array<?x?xf32>>, i1) -> !hlfir.expr<?x?xf32>
+! CHECK-LOWERING-OPT:   hlfir.assign %[[MUL_EXPR]] to %[[RES_DECL]]#0 : !hlfir.expr<?x?xf32>, !fir.ref<!fir.array<1x2xf32>>
+! CHECK-LOWERING-OPT:   hlfir.destroy %[[MUL_EXPR]]
+
 ! [argument handling unchanged]
 ! CHECK-BUFFERING:      fir.call @_FortranATranspose(
 ! CHECK-BUFFERING:      %[[TRANSPOSE_RES_LD:.*]] = fir.load %[[TRANSPOSE_RES_BOX:.*]]


        


More information about the flang-commits mailing list