[flang-commits] [flang] 3a6f3b4 - [flang] lower hlfir.matmul into fir runtime call

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Thu Feb 16 07:32:15 PST 2023


Author: Tom Eccles
Date: 2023-02-16T15:30:46Z
New Revision: 3a6f3b44e434557d85f1049f5548de8e3d80bcb8

URL: https://github.com/llvm/llvm-project/commit/3a6f3b44e434557d85f1049f5548de8e3d80bcb8
DIFF: https://github.com/llvm/llvm-project/commit/3a6f3b44e434557d85f1049f5548de8e3d80bcb8.diff

LOG: [flang] lower hlfir.matmul into fir runtime call

We can't test lowering calls with hlfir.expr arguments yet because this
hits a not-yet-implemented error: "get shape from HLFIR expr without
producer holding the shape".

Differential Revision: https://reviews.llvm.org/D144098

Added: 
    flang/test/HLFIR/matmul-bufferization.fir

Modified: 
    flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
index 2ada1225dabf5..c2181c686ca0c 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
@@ -22,6 +22,7 @@
 #include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/HLFIR/Passes.h"
 #include "flang/Optimizer/Support/FIRContext.h"
@@ -29,6 +30,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include <optional>
 
 namespace hlfir {
@@ -518,7 +520,6 @@ class HlfirIntrinsicConversion : public mlir::OpConversionPattern<OP> {
                  const llvm::ArrayRef<IntrinsicArgument> &args,
                  mlir::ConversionPatternRewriter &rewriter,
                  const fir::IntrinsicArgumentLoweringRules *argLowering) const {
-    assert(args.size() == 3 && "Transformational intrinsics have 3 args");
     mlir::Location loc = op->getLoc();
     fir::KindMapping kindMapping{rewriter.getContext()};
     fir::FirOpBuilder builder{rewriter, kindMapping};
@@ -648,6 +649,39 @@ struct SumOpConversion : public HlfirIntrinsicConversion<hlfir::SumOp> {
   }
 };
 
+/// Convert hlfir.matmul into a Fortran runtime MATMUL call built by
+/// fir::genIntrinsicCall; the result is finalized by processReturnValue.
+struct MatmulOpConversion : public HlfirIntrinsicConversion<hlfir::MatmulOp> {
+  using HlfirIntrinsicConversion<hlfir::MatmulOp>::HlfirIntrinsicConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(hlfir::MatmulOp matmul, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    fir::KindMapping kindMapping{rewriter.getContext()};
+    fir::FirOpBuilder builder{rewriter, kindMapping};
+    mlir::Location loc = matmul->getLoc();
+    HLFIRListener listener{builder, rewriter};
+    builder.setListener(&listener);
+
+    // Both MATMUL operands are forwarded, each paired with its own type.
+    mlir::Value lhs = matmul.getLhs();
+    mlir::Value rhs = matmul.getRhs();
+    llvm::SmallVector<IntrinsicArgument, 2> inArgs{{lhs, lhs.getType()},
+                                                   {rhs, rhs.getType()}};
+    auto *argLowering = fir::getIntrinsicArgumentLowering("matmul");
+    llvm::SmallVector<fir::ExtendedValue, 2> args =
+        lowerArguments(matmul, inArgs, rewriter, argLowering);
+
+    // genIntrinsicCall is given the element type of the array result.
+    mlir::Type scalarResultType =
+        hlfir::getFortranElementType(matmul.getType());
+    auto [resultExv, mustBeFreed] =
+        fir::genIntrinsicCall(builder, loc, "matmul", scalarResultType, args);
+    processReturnValue(matmul, resultExv, mustBeFreed, builder, rewriter);
+    return mlir::success();
+  }
+};
+
 class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase<BufferizeHLFIR> {
 public:
   void runOnOperation() override {
@@ -661,12 +695,11 @@ class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase<BufferizeHLFIR> {
     auto module = this->getOperation();
     auto *context = &getContext();
     mlir::RewritePatternSet patterns(context);
-    patterns
-        .insert<ApplyOpConversion, AsExprOpConversion, AssignOpConversion,
-                AssociateOpConversion, ConcatOpConversion, DestroyOpConversion,
-                ElementalOpConversion, EndAssociateOpConversion,
-                NoReassocOpConversion, SetLengthOpConversion, SumOpConversion>(
-            context);
+    patterns.insert<
+        ApplyOpConversion, AsExprOpConversion, AssignOpConversion,
+        AssociateOpConversion, ConcatOpConversion, DestroyOpConversion,
+        ElementalOpConversion, EndAssociateOpConversion, MatmulOpConversion,
+        NoReassocOpConversion, SetLengthOpConversion, SumOpConversion>(context);
     mlir::ConversionTarget target(*context);
     target.addIllegalOp<hlfir::ApplyOp, hlfir::AssociateOp, hlfir::ElementalOp,
                         hlfir::EndAssociateOp, hlfir::SetLengthOp,

diff  --git a/flang/test/HLFIR/matmul-bufferization.fir b/flang/test/HLFIR/matmul-bufferization.fir
new file mode 100644
index 0000000000000..54da40cf303f9
--- /dev/null
+++ b/flang/test/HLFIR/matmul-bufferization.fir
@@ -0,0 +1,45 @@
+// Test that -bufferize-hlfir lowers hlfir.matmul to a _FortranAMatmul call.
+// RUN: fir-opt %s -bufferize-hlfir | FileCheck %s
+
+func.func @_QPmatmul1(%arg0: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "lhs"}, %arg1: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "rhs"}, %arg2: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "res"}) {
+  %0:2 = hlfir.declare %arg0 {uniq_name = "_QFmatmul1Elhs"} : (!fir.box<!fir.array<?x?xi32>>) -> (!fir.box<!fir.array<?x?xi32>>, !fir.box<!fir.array<?x?xi32>>)
+  %1:2 = hlfir.declare %arg2 {uniq_name = "_QFmatmul1Eres"} : (!fir.box<!fir.array<?x?xi32>>) -> (!fir.box<!fir.array<?x?xi32>>, !fir.box<!fir.array<?x?xi32>>)
+  %2:2 = hlfir.declare %arg1 {uniq_name = "_QFmatmul1Erhs"} : (!fir.box<!fir.array<?x?xi32>>) -> (!fir.box<!fir.array<?x?xi32>>, !fir.box<!fir.array<?x?xi32>>)
+  %3 = hlfir.matmul %0#0 %2#0 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?x?xi32>>, !fir.box<!fir.array<?x?xi32>>) -> !hlfir.expr<?x?xi32>
+  hlfir.assign %3 to %1#0 : !hlfir.expr<?x?xi32>, !fir.box<!fir.array<?x?xi32>>
+  hlfir.destroy %3 : !hlfir.expr<?x?xi32>
+  return
+}
+// CHECK-LABEL: func.func @_QPmatmul1(
+// CHECK:           %[[ARG0:.*]]: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "lhs"}
+// CHECK:           %[[ARG1:.*]]: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "rhs"}
+// CHECK:           %[[ARG2:.*]]: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "res"}
+// CHECK-DAG:     %[[LHS_VAR:.*]]:2 = hlfir.declare %[[ARG0]]
+// CHECK-DAG:     %[[RHS_VAR:.*]]:2 = hlfir.declare %[[ARG1]]
+// CHECK-DAG:     %[[RES_VAR:.*]]:2 = hlfir.declare %[[ARG2]]
+
+// CHECK-DAG:     %[[RET_BOX:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?x?xi32>>>
+// CHECK-DAG:     %[[RET_ADDR:.*]] = fir.zero_bits !fir.heap<!fir.array<?x?xi32>>
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[RET_SHAPE:.*]] = fir.shape %[[C0]], %[[C0]] : (index, index) -> !fir.shape<2>
+// CHECK-DAG:     %[[RET_EMBOX:.*]] = fir.embox %[[RET_ADDR]](%[[RET_SHAPE]])
+// CHECK-DAG:     fir.store %[[RET_EMBOX]] to %[[RET_BOX]]
+
+// CHECK:         %[[RET_ARG:.*]] = fir.convert %[[RET_BOX]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
+// CHECK-DAG:     %[[LHS_ARG:.*]] = fir.convert %[[LHS_VAR]]#1 : (!fir.box<!fir.array<?x?xi32>>) -> !fir.box<none>
+// CHECK-DAG:     %[[RHS_ARG:.*]] = fir.convert %[[RHS_VAR]]#1 : (!fir.box<!fir.array<?x?xi32>>) -> !fir.box<none>
+// CHECK:         %[[NONE:.*]] = fir.call @_FortranAMatmul(%[[RET_ARG]], %[[LHS_ARG]], %[[RHS_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]])
+
+// CHECK:         %[[RET:.*]] = fir.load %[[RET_BOX]]
+// CHECK-DAG:     %[[BOX_DIMS:.*]]:3 = fir.box_dims %[[RET]]
+// CHECK-DAG:     %[[ADDR:.*]] = fir.box_addr %[[RET]]
+// CHECK-NEXT:    %[[SHIFT:.*]] = fir.shape_shift %[[BOX_DIMS]]#0, %[[BOX_DIMS]]#1
+// TODO: fix alias analysis in hlfir.assign bufferization
+// CHECK-NEXT:    %[[TMP:.*]]:2 = hlfir.declare %[[ADDR]](%[[SHIFT]]) {uniq_name = ".tmp.intrinsic_result"}
+// CHECK:         %[[TUPLE0:.*]] = fir.undefined tuple<!fir.box<!fir.array<?x?xi32>>, i1>
+// CHECK:         %[[TUPLE1:.*]] = fir.insert_value %[[TUPLE0]], %[[TRUE:.*]], [1 : index]
+// CHECK:         %[[TUPLE2:.*]] = fir.insert_value %[[TUPLE1]], %[[TMP]]#0, [0 : index]
+// CHECK:         hlfir.assign %[[TMP]]#0 to %[[RES_VAR]]#0
+// CHECK:         fir.freemem %[[TMP]]#1
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }


        


More information about the flang-commits mailing list