[flang-commits] [flang] 6bcfab3 - [flang][hlfir] allow recursive intrinsic lowering

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Wed Jun 7 07:36:45 PDT 2023


Author: Tom Eccles
Date: 2023-06-07T14:36:14Z
New Revision: 6bcfab3161ea3a90b26dce3f85ba052933060b73

URL: https://github.com/llvm/llvm-project/commit/6bcfab3161ea3a90b26dce3f85ba052933060b73
DIFF: https://github.com/llvm/llvm-project/commit/6bcfab3161ea3a90b26dce3f85ba052933060b73.diff

LOG: [flang][hlfir] allow recursive intrinsic lowering

We need to allow recursive application of intrinsic lowering patterns,
otherwise we cannot lower nested calls of the same intrinsic e.g.
matmul(matmul(a, b), c).

matmul(matmul(a, b), matmul(c, d)) is still not supported: it requires
hlfir.associate of an hlfir.expr with more than one use (TODO).

Differential Revision: https://reviews.llvm.org/D152284

Added: 
    

Modified: 
    flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
    flang/test/HLFIR/matmul-lowering.fir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
index ac7eafc862fd6..e70ef497906a7 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
@@ -1,4 +1,4 @@
-//===- LowerHLFIRIntrinsics.cpp - Bufferize HLFIR  ------------------------===//
+//===- LowerHLFIRIntrinsics.cpp - Transformational intrinsics to FIR ------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -37,7 +37,22 @@ namespace {
 /// runtime calls
 template <class OP>
 class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> {
-  using mlir::OpRewritePattern<OP>::OpRewritePattern;
+public:
+  explicit HlfirIntrinsicConversion(mlir::MLIRContext *ctx)
+      : mlir::OpRewritePattern<OP>{ctx} {
+    // Allow bounded recursive application of this pattern. It is needed
+    // when intrinsics are chained together, e.g. matmul(matmul(a, b), c):
+    // converting the inner operation invalidates the outer operation,
+    // causing the pattern to apply recursively.
+    //
+    // This is safe because we always progress with each iteration. Circular
+    // applications of operations are not expressible in MLIR because we use
+    // an SSA form and one must come first. E.g.
+    //   %a = hlfir.matmul %b %d
+    //   %b = hlfir.matmul %a %d
+    // cannot be written.
+    mlir::OpRewritePattern<OP>::setHasBoundedRewriteRecursion(true);
+  }
 
 protected:
   struct IntrinsicArgument {

diff  --git a/flang/test/HLFIR/matmul-lowering.fir b/flang/test/HLFIR/matmul-lowering.fir
index d4819f18c62af..752afcee03a7e 100644
--- a/flang/test/HLFIR/matmul-lowering.fir
+++ b/flang/test/HLFIR/matmul-lowering.fir
@@ -43,3 +43,39 @@ func.func @_QPmatmul1(%arg0: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "lh
 // CHECK:         hlfir.destroy %[[ASEXPR]]
 // CHECK-NEXT:    return
 // CHECK-NEXT:  }
+
+// nested matmuls leading to recursive pattern application
+func.func @_QPtest(%arg0: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "a"}, %arg1: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "b"}, %arg2: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "c"}, %arg3: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "out"}) {
+  %c3 = arith.constant 3 : index
+  %c3_0 = arith.constant 3 : index
+  %0 = fir.shape %c3, %c3_0 : (index, index) -> !fir.shape<2>
+  %1:2 = hlfir.declare %arg0(%0) {uniq_name = "_QFtestEa"} : (!fir.ref<!fir.array<3x3xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<3x3xf32>>, !fir.ref<!fir.array<3x3xf32>>)
+  %c3_1 = arith.constant 3 : index
+  %c3_2 = arith.constant 3 : index
+  %2 = fir.shape %c3_1, %c3_2 : (index, index) -> !fir.shape<2>
+  %3:2 = hlfir.declare %arg1(%2) {uniq_name = "_QFtestEb"} : (!fir.ref<!fir.array<3x3xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<3x3xf32>>, !fir.ref<!fir.array<3x3xf32>>)
+  %c3_3 = arith.constant 3 : index
+  %c3_4 = arith.constant 3 : index
+  %4 = fir.shape %c3_3, %c3_4 : (index, index) -> !fir.shape<2>
+  %5:2 = hlfir.declare %arg2(%4) {uniq_name = "_QFtestEc"} : (!fir.ref<!fir.array<3x3xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<3x3xf32>>, !fir.ref<!fir.array<3x3xf32>>)
+  %c3_5 = arith.constant 3 : index
+  %c3_6 = arith.constant 3 : index
+  %6 = fir.shape %c3_5, %c3_6 : (index, index) -> !fir.shape<2>
+  %7:2 = hlfir.declare %arg3(%6) {uniq_name = "_QFtestEout"} : (!fir.ref<!fir.array<3x3xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<3x3xf32>>, !fir.ref<!fir.array<3x3xf32>>)
+  %8 = hlfir.matmul %1#0 %3#0 {fastmath = #arith.fastmath<contract>} : (!fir.ref<!fir.array<3x3xf32>>, !fir.ref<!fir.array<3x3xf32>>) -> !hlfir.expr<3x3xf32>
+  %9 = hlfir.matmul %8 %5#0 {fastmath = #arith.fastmath<contract>} : (!hlfir.expr<3x3xf32>, !fir.ref<!fir.array<3x3xf32>>) -> !hlfir.expr<3x3xf32>
+  hlfir.assign %9 to %7#0 : !hlfir.expr<3x3xf32>, !fir.ref<!fir.array<3x3xf32>>
+  hlfir.destroy %9 : !hlfir.expr<3x3xf32>
+  hlfir.destroy %8 : !hlfir.expr<3x3xf32>
+  return
+}
+// just check that we apply the patterns successfully. The details are checked above
+// CHECK-LABEL: func.func @_QPtest(
+// CHECK:           %arg0: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "a"},
+// CHECK-SAME:      %arg1: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "b"},
+// CHECK-SAME:      %arg2: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "c"},
+// CHECK-SAME:      %arg3: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "out"}) {
+// CHECK:         fir.call @_FortranAMatmul(
// CHECK:         fir.call @_FortranAMatmul(
+// CHECK:         return
+// CHECK-NEXT:  }


        


More information about the flang-commits mailing list