[flang-commits] [flang] [flang] Canonicalize fir.array_coor by pulling in embox/rebox. (PR #92858)

via flang-commits flang-commits at lists.llvm.org
Mon May 20 22:15:50 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: Slava Zakharin (vzakhari)

<details>
<summary>Changes</summary>

In a simple case like this:
```
program test
  integer :: u(120, 2)
  u(1:120,1:2) = u(1:120,1:2) + 2
end program
```
Flang is creating a copy loop with fir.array_coor using
a result of fir.embox inserted before the loop. This results in split
address computations before and inside the loop, which show up
as many more arithmetic operations than required after converting
FIR to the LLVM dialect. Even though LLVM SROA/mem2reg are able
to optimize the temporary descriptor, and then LICM is able to hoist
the invariant computations, we seem to get a better mix of LLVM dialect
operations after FIR-to-LLVM codegen when the fir.embox is pulled
into the fir.array_coor. This may also slightly reduce the compilation
time taken by LLVM to optimize the generated LLVM IR, as well as the
time spent by FIR AliasAnalysis to reach the memory reference source.

This patch also includes one change in the CodeGen to fix what I believe
is a bug: the indices used in fir.array_coor with a slice always
start at 1, while the current code assumed they were based on the
(possibly non-default) lower bounds of the array.


---
Full diff: https://github.com/llvm/llvm-project/pull/92858.diff


5 Files Affected:

- (modified) flang/include/flang/Optimizer/Dialect/FIROps.td (+1) 
- (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+5-1) 
- (modified) flang/lib/Optimizer/Dialect/FIROps.cpp (+93) 
- (added) flang/test/Fir/array-coor-canonicalization.fir (+121) 
- (modified) flang/test/Fir/convert-to-llvm.fir (+2-2) 


``````````diff
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index d9c1149040066..7ffa0145072d6 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -1650,6 +1650,7 @@ def fir_ArrayCoorOp : fir_Op<"array_coor",
   }];
 
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 def fir_CoordinateOp : fir_Op<"coordinate_of", [NoMemoryEffect]> {
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 72172f63888e1..8d70ceb31bce1 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -2162,7 +2162,11 @@ struct XArrayCoorOpConversion
         if (normalSlice)
           step = integerCast(loc, rewriter, idxTy, operands[sliceOffset + 2]);
       }
-      auto idx = rewriter.create<mlir::LLVM::SubOp>(loc, idxTy, index, lb, nsw);
+      // The array_coor indices are always 1-based if slicing
+      // is in effect. The non-default lower bounds only
+      // apply to the indices of the slice itself.
+      auto idx = rewriter.create<mlir::LLVM::SubOp>(loc, idxTy, index,
+                                                    isSliced ? one : lb, nsw);
       mlir::Value diff =
           rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, idx, step, nsw);
       if (normalSlice) {
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 94113da9a46cf..67320d4e6e7eb 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -394,11 +394,17 @@ mlir::LogicalResult fir::ArrayCoorOp::verify() {
     } else {
       auto s = mlir::cast<fir::ShiftType>(shapeTy);
       shapeTyRank = s.getRank();
+      // TODO: it looks like PreCGRewrite and CodeGen can support
+      // fir.shift with plain array reference, so we may consider
+      // removing this check.
       if (!mlir::isa<fir::BaseBoxType>(getMemref().getType()))
         return emitOpError("shift can only be provided with fir.box memref");
     }
     if (arrDim && arrDim != shapeTyRank)
       return emitOpError("rank of dimension mismatched");
+    // TODO: support slicing that changes the number of dimensions,
+    // e.g. when array_coor represents an element access to array(:,1,:)
+    // slice: the shape is 3D and the number of indices is 2 in this case.
     if (shapeTyRank != getIndices().size())
       return emitOpError("number of indices do not match dim rank");
   }
@@ -417,6 +423,130 @@ mlir::LogicalResult fir::ArrayCoorOp::verify() {
   return mlir::success();
 }
 
+// Pull in fir.embox and fir.rebox into fir.array_coor when possible.
+//
+// For example:
+//   %box = fir.embox %ref(%shape) [%slice] : ...
+//   %addr = fir.array_coor %box %i, %j : ...
+// is rewritten into:
+//   %addr = fir.array_coor %ref(%shape) [%slice] %i, %j : ...
+//
+// The rewrite must not change the element address computed by
+// fir.array_coor. When slicing is in effect, the array_coor indices
+// are 1-based; otherwise they are based on the lower bounds provided
+// by the array_coor's fir.shift/fir.shape_shift operand (1 if absent).
+// The rewrite is rejected whenever pulling in the shape or slice
+// would change how the indices are interpreted.
+struct SimplifyArrayCoorOp : public mlir::OpRewritePattern<fir::ArrayCoorOp> {
+  using mlir::OpRewritePattern<fir::ArrayCoorOp>::OpRewritePattern;
+  mlir::LogicalResult
+  matchAndRewrite(fir::ArrayCoorOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    mlir::Value memref = op.getMemref();
+    if (!mlir::isa<fir::BaseBoxType>(memref.getType()))
+      return mlir::failure();
+
+    mlir::Value boxedMemref, boxedShape, boxedSlice;
+    if (auto emboxOp =
+            mlir::dyn_cast_or_null<fir::EmboxOp>(memref.getDefiningOp())) {
+      boxedMemref = emboxOp.getMemref();
+      boxedShape = emboxOp.getShape();
+      boxedSlice = emboxOp.getSlice();
+      // If any of operands, that are not currently supported for migration
+      // to ArrayCoorOp, is present, don't rewrite.
+      if (!emboxOp.getTypeparams().empty() || emboxOp.getSourceBox() ||
+          emboxOp.getAccessMap())
+        return mlir::failure();
+    } else if (auto reboxOp = mlir::dyn_cast_or_null<fir::ReboxOp>(
+                   memref.getDefiningOp())) {
+      boxedMemref = reboxOp.getBox();
+      boxedShape = reboxOp.getShape();
+      boxedSlice = reboxOp.getSlice();
+      // A fir.rebox with fir.shape/fir.shape_shift reshapes the box
+      // (e.g. pointer bounds remapping); only a fir.shift (lower bounds
+      // change) can be represented by array_coor.
+      if (boxedShape && !mlir::isa<fir::ShiftType>(boxedShape.getType()))
+        return mlir::failure();
+    } else {
+      return mlir::failure();
+    }
+
+    // Slices changing the number of dimensions are not supported
+    // for array_coor yet.
+    unsigned origBoxRank;
+    if (mlir::isa<fir::BaseBoxType>(boxedMemref.getType()))
+      origBoxRank = fir::getBoxRank(boxedMemref.getType());
+    else if (auto arrTy = mlir::dyn_cast<fir::SequenceType>(
+                 fir::unwrapRefType(boxedMemref.getType())))
+      origBoxRank = arrTy.getDimension();
+    else
+      return mlir::failure();
+
+    if (fir::getBoxRank(memref.getType()) != origBoxRank)
+      return mlir::failure();
+
+    // Slices with substring are not supported by array_coor.
+    if (boxedSlice)
+      if (auto sliceOp =
+              mlir::dyn_cast_or_null<fir::SliceOp>(boxedSlice.getDefiningOp()))
+        if (!sliceOp.getSubstr().empty())
+          return mlir::failure();
+
+    // If embox/rebox and array_coor have conflicting shapes, do nothing.
+    mlir::Value opShape = op.getShape();
+    if (opShape && boxedShape && boxedShape != opShape)
+      return mlir::failure();
+    // A slice of the array_coor applies on top of the slice of the
+    // embox/rebox result, so the two cannot be folded into a single
+    // slice (not even when they are the same fir.slice value).
+    if (op.getSlice() && boxedSlice)
+      return mlir::failure();
+
+    // Returns true if the shape operand provides explicit lower bounds.
+    auto hasLowerBounds = [](mlir::Value shape) {
+      return shape &&
+             mlir::isa<fir::ShiftType, fir::ShapeShiftType>(shape.getType());
+    };
+    if (boxedSlice) {
+      // After pulling in the slice, the array_coor indices are 1-based,
+      // so they must not have been based on other lower bounds before.
+      if (hasLowerBounds(opShape))
+        return mlir::failure();
+    } else if (!opShape && hasLowerBounds(boxedShape)) {
+      // Without a slice, pulling in a shape with lower bounds would turn
+      // the 1-based array_coor indices into lower-bound-based ones.
+      return mlir::failure();
+    }
+
+    // TODO: temporarily avoid producing array_coor with the shape shift
+    // and plain array reference (it seems to be a limitation of
+    // ArrayCoorOp verifier).
+    if (!mlir::isa<fir::BaseBoxType>(boxedMemref.getType())) {
+      if (boxedShape) {
+        if (mlir::isa<fir::ShiftType>(boxedShape.getType()))
+          return mlir::failure();
+      } else if (opShape && mlir::isa<fir::ShiftType>(opShape.getType())) {
+        return mlir::failure();
+      }
+    }
+
+    rewriter.modifyOpInPlace(op, [&]() {
+      op.getMemrefMutable().assign(boxedMemref);
+      if (boxedShape)
+        op.getShapeMutable().assign(boxedShape);
+      if (boxedSlice)
+        op.getSliceMutable().assign(boxedSlice);
+    });
+    return mlir::success();
+  }
+};
+
+// Registers the canonicalization patterns of fir.array_coor.
+void fir::ArrayCoorOp::getCanonicalizationPatterns(
+    mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
+  patterns.add<SimplifyArrayCoorOp>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ArrayLoadOp
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/Fir/array-coor-canonicalization.fir b/flang/test/Fir/array-coor-canonicalization.fir
new file mode 100644
index 0000000000000..ec36b2e57b9f9
--- /dev/null
+++ b/flang/test/Fir/array-coor-canonicalization.fir
@@ -0,0 +1,121 @@
+// RUN: fir-opt --canonicalize %s | FileCheck %s
+// fir.embox with fir.shape + fir.slice is folded into fir.array_coor.
+// CHECK-LABEL:   func.func @_QPtest1(
+// CHECK-SAME:                        %[[VAL_0:.*]]: !fir.ref<!fir.array<120x2xi32>> {fir.bindc_name = "u"}) {
+// CHECK:           %[[VAL_6:.*]] = fir.shape
+// CHECK:           %[[VAL_7:.*]] = fir.declare %[[VAL_0]](%[[VAL_6]])
+// CHECK:           %[[VAL_8:.*]] = fir.slice
+// CHECK:           fir.do_loop
+// CHECK:             fir.do_loop
+// CHECK:               %[[VAL_11:.*]] = fir.array_coor %[[VAL_7]](%[[VAL_6]]) {{\[}}%[[VAL_8]]]
+func.func @_QPtest1(%arg0: !fir.ref<!fir.array<120x2xi32>> {fir.bindc_name = "u"}) {
+  %c1 = arith.constant 1 : index
+  %c2_i32 = arith.constant 2 : i32
+  %c2 = arith.constant 2 : index
+  %c120 = arith.constant 120 : index
+  %0 = fir.dummy_scope : !fir.dscope
+  %1 = fir.shape %c120, %c2 : (index, index) -> !fir.shape<2>
+  %2 = fir.declare %arg0(%1) dummy_scope %0 {uniq_name = "_QFtest1Eu"} : (!fir.ref<!fir.array<120x2xi32>>, !fir.shape<2>, !fir.dscope) -> !fir.ref<!fir.array<120x2xi32>>
+  %3 = fir.slice %c1, %c120, %c1, %c1, %c2, %c1 : (index, index, index, index, index, index) -> !fir.slice<2>
+  %4 = fir.embox %2(%1) [%3] : (!fir.ref<!fir.array<120x2xi32>>, !fir.shape<2>, !fir.slice<2>) -> !fir.box<!fir.array<120x2xi32>>
+  fir.do_loop %arg1 = %c1 to %c2 step %c1 unordered {
+    fir.do_loop %arg2 = %c1 to %c120 step %c1 unordered {
+      %5 = fir.array_coor %4 %arg2, %arg1 : (!fir.box<!fir.array<120x2xi32>>, index, index) -> !fir.ref<i32>
+      fir.store %c2_i32 to %5 : !fir.ref<i32>
+    }
+  }
+  return
+}
+// fir.rebox with fir.shift is folded when array_coor uses the same shift.
+// CHECK-LABEL:   func.func @_QPtest2(
+// CHECK-SAME:                        %[[VAL_0:.*]]: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "u"}) {
+// CHECK:           %[[VAL_8:.*]] = fir.shift
+// CHECK:           %[[VAL_9:.*]] = fir.declare %[[VAL_0]](%[[VAL_8]])
+// CHECK:           fir.do_loop
+// CHECK:             fir.do_loop
+// CHECK:               %[[VAL_17:.*]] = fir.array_coor %[[VAL_9]](%[[VAL_8]])
+func.func @_QPtest2(%arg0: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "u"}) {
+  %c9 = arith.constant 9 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c11 = arith.constant 11 : index
+  %c10 = arith.constant 10 : index
+  %c2_i32 = arith.constant 2 : i32
+  %0 = fir.dummy_scope : !fir.dscope
+  %1 = fir.shift %c10, %c11 : (index, index) -> !fir.shift<2>
+  %2 = fir.declare %arg0(%1) dummy_scope %0 {uniq_name = "_QFtest2Eu"} : (!fir.box<!fir.array<?x?xi32>>, !fir.shift<2>, !fir.dscope) -> !fir.box<!fir.array<?x?xi32>>
+  %3 = fir.rebox %2(%1) : (!fir.box<!fir.array<?x?xi32>>, !fir.shift<2>) -> !fir.box<!fir.array<?x?xi32>>
+  %4:3 = fir.box_dims %3, %c0 : (!fir.box<!fir.array<?x?xi32>>, index) -> (index, index, index)
+  %5:3 = fir.box_dims %3, %c1 : (!fir.box<!fir.array<?x?xi32>>, index) -> (index, index, index)
+  fir.do_loop %arg1 = %c1 to %5#1 step %c1 unordered {
+    fir.do_loop %arg2 = %c1 to %4#1 step %c1 unordered {
+      %6 = arith.addi %arg2, %c9 : index
+      %7 = arith.addi %arg1, %c10 : index
+      %8 = fir.array_coor %3(%1) %6, %7 : (!fir.box<!fir.array<?x?xi32>>, !fir.shift<2>, index, index) -> !fir.ref<i32>
+      fir.store %c2_i32 to %8 : !fir.ref<i32>
+    }
+  }
+  return
+}
+// Two chained fir.rebox ops (shift, then shift + slice) are both folded.
+// CHECK-LABEL:   func.func @_QPtest3(
+// CHECK-SAME:                        %[[VAL_0:.*]]: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "u"}) {
+// CHECK:           %[[VAL_10:.*]] = fir.shift
+// CHECK:           %[[VAL_11:.*]] = fir.declare %[[VAL_0]](%[[VAL_10]])
+// CHECK:           %[[VAL_12:.*]] = fir.slice
+// CHECK:           fir.do_loop
+// CHECK:             fir.do_loop
+// CHECK:               %[[VAL_15:.*]] = fir.array_coor %[[VAL_11]](%[[VAL_10]]) {{\[}}%[[VAL_12]]]
+func.func @_QPtest3(%arg0: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "u"}) {
+  %c2 = arith.constant 2 : index
+  %c12 = arith.constant 12 : index
+  %c11 = arith.constant 11 : index
+  %c111 = arith.constant 111 : index
+  %c1 = arith.constant 1 : index
+  %c120 = arith.constant 120 : index
+  %c10 = arith.constant 10 : index
+  %c2_i32 = arith.constant 2 : i32
+  %0 = fir.dummy_scope : !fir.dscope
+  %1 = fir.shift %c10, %c11 : (index, index) -> !fir.shift<2>
+  %2 = fir.declare %arg0(%1) dummy_scope %0 {uniq_name = "_QFtest3Eu"} : (!fir.box<!fir.array<?x?xi32>>, !fir.shift<2>, !fir.dscope) -> !fir.box<!fir.array<?x?xi32>>
+  %3 = fir.rebox %2(%1) : (!fir.box<!fir.array<?x?xi32>>, !fir.shift<2>) -> !fir.box<!fir.array<?x?xi32>>
+  %4 = fir.slice %c10, %c120, %c1, %c11, %c12, %c1 : (index, index, index, index, index, index) -> !fir.slice<2>
+  %5 = fir.rebox %3(%1) [%4] : (!fir.box<!fir.array<?x?xi32>>, !fir.shift<2>, !fir.slice<2>) -> !fir.box<!fir.array<111x2xi32>>
+  fir.do_loop %arg1 = %c1 to %c2 step %c1 unordered {
+    fir.do_loop %arg2 = %c1 to %c111 step %c1 unordered {
+      %6 = fir.array_coor %5 %arg2, %arg1 : (!fir.box<!fir.array<111x2xi32>>, index, index) -> !fir.ref<i32>
+      fir.store %c2_i32 to %6 : !fir.ref<i32>
+    }
+  }
+  return
+}
+
+// TODO: fir.array_coor with slices changing the number of dimensions
+// is not supported yet, so the fir.embox below must not be folded.
+// CHECK-LABEL:   func.func @_QPtest4() {
+// CHECK:           %[[VAL_3:.*]] = fir.alloca !fir.array<100x100x100xi32> {bindc_name = "u", uniq_name = "_QFtest4Eu"}
+// CHECK:           %[[VAL_4:.*]] = fir.shape
+// CHECK:           %[[VAL_5:.*]] = fir.declare %[[VAL_3]](%[[VAL_4]]) {uniq_name = "_QFtest4Eu"} : (!fir.ref<!fir.array<100x100x100xi32>>, !fir.shape<3>) -> !fir.ref<!fir.array<100x100x100xi32>>
+// CHECK:           %[[VAL_7:.*]] = fir.slice
+// CHECK:           %[[VAL_8:.*]] = fir.embox %[[VAL_5]](%[[VAL_4]]) {{\[}}%[[VAL_7]]] : (!fir.ref<!fir.array<100x100x100xi32>>, !fir.shape<3>, !fir.slice<3>) -> !fir.box<!fir.array<100x100xi32>>
+// CHECK:           fir.do_loop
+// CHECK:             fir.do_loop
+// CHECK:               %[[VAL_11:.*]] = fir.array_coor %[[VAL_8]]
+func.func @_QPtest4() {
+  %c1 = arith.constant 1 : index
+  %c2_i32 = arith.constant 2 : i32
+  %c100 = arith.constant 100 : index
+  %0 = fir.alloca !fir.array<100x100x100xi32> {bindc_name = "u", uniq_name = "_QFtest4Eu"}
+  %1 = fir.shape %c100, %c100, %c100 : (index, index, index) -> !fir.shape<3>
+  %2 = fir.declare %0(%1) {uniq_name = "_QFtest4Eu"} : (!fir.ref<!fir.array<100x100x100xi32>>, !fir.shape<3>) -> !fir.ref<!fir.array<100x100x100xi32>>
+  %3 = fir.undefined index
+  %4 = fir.slice %c1, %c100, %c1, %c1, %3, %3, %c1, %c100, %c1 : (index, index, index, index, index, index, index, index, index) -> !fir.slice<3>
+  %5 = fir.embox %2(%1) [%4] : (!fir.ref<!fir.array<100x100x100xi32>>, !fir.shape<3>, !fir.slice<3>) -> !fir.box<!fir.array<100x100xi32>>
+  fir.do_loop %arg0 = %c1 to %c100 step %c1 unordered {
+    fir.do_loop %arg1 = %c1 to %c100 step %c1 unordered {
+      %6 = fir.array_coor %5 %arg1, %arg0 : (!fir.box<!fir.array<100x100xi32>>, index, index) -> !fir.ref<i32>
+      fir.store %c2_i32 to %6 : !fir.ref<i32>
+    }
+  }
+  return
+}
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index 21323a5e657c9..80aebd5f8500d 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -2084,7 +2084,7 @@ func.func @ext_array_coor1(%arg0: !fir.ref<!fir.array<?xi32>>) {
 // CHECK:         %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK:         %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
 // CHECK:         %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
-// CHECK:         %[[IDX:.*]] = llvm.sub %[[C0]], %[[C0]] overflow<nsw> : i64
+// CHECK:         %[[IDX:.*]] = llvm.sub %[[C0]], %[[C1]] overflow<nsw> : i64
 // CHECK:         %[[DIFF0:.*]] = llvm.mul %[[IDX]], %[[C0]] overflow<nsw> : i64
 // CHECK:         %[[ADJ:.*]] = llvm.sub %[[C0]], %[[C0]]  overflow<nsw> : i64
 // CHECK:         %[[DIFF1:.*]] = llvm.add %[[DIFF0]], %[[ADJ]] overflow<nsw> : i64
@@ -2153,7 +2153,7 @@ func.func @ext_array_coor4(%arg0: !fir.ref<!fir.array<100xi32>>) {
 // CHECK:         %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
 // CHECK:         %[[C1_1:.*]] = llvm.mlir.constant(1 : i64) : i64
 // CHECK:         %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
-// CHECK:         %[[IDX:.*]] = llvm.sub %[[C1]], %[[C0]] overflow<nsw> : i64
+// CHECK:         %[[IDX:.*]] = llvm.sub %[[C1]], %[[C1_1]] overflow<nsw> : i64
 // CHECK:         %[[DIFF0:.*]] = llvm.mul %[[IDX]], %[[C1]] overflow<nsw> : i64
 // CHECK:         %[[ADJ:.*]] = llvm.sub %[[C10]], %[[C0]] overflow<nsw> : i64
 // CHECK:         %[[DIFF1:.*]] = llvm.add %[[DIFF0]], %[[ADJ]] overflow<nsw> : i64

``````````

</details>


https://github.com/llvm/llvm-project/pull/92858


More information about the flang-commits mailing list