[flang-commits] [flang] [Flang] Fix for replacing loop uses in LoopVersioning pass (PR #77899)

via flang-commits flang-commits at lists.llvm.org
Fri Jan 12 01:48:42 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: David Green (davemgreen)

<details>
<summary>Changes</summary>

The added test case has a versioned loop whose result is used inside an if block after the loop. The current code replaces uses of the loop results with the results of the new fir.if only when the user's parent operation matches the fir.if's parent operation, so a use nested inside another if block is missed. As far as I can see it should be safe to replace all the uses, then construct the result of the fir.if's else branch from the results of op.op.

---
Full diff: https://github.com/llvm/llvm-project/pull/77899.diff


2 Files Affected:

- (modified) flang/lib/Optimizer/Transforms/LoopVersioning.cpp (+1-11) 
- (modified) flang/test/Transforms/loop-versioning.fir (+80) 


``````````diff
diff --git a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
index 5d6a23b86d06ef..30a064ab8851f5 100644
--- a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
+++ b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
@@ -146,16 +146,6 @@ struct ArgsUsageInLoop {
 };
 } // namespace
 
-/// @c replaceOuterUses - replace uses outside of @c op with result of @c
-/// outerOp
-static void replaceOuterUses(mlir::Operation *op, mlir::Operation *outerOp) {
-  const mlir::Operation *outerParent = outerOp->getParentOp();
-  op->replaceUsesWithIf(outerOp, [&](mlir::OpOperand &operand) {
-    mlir::Operation *owner = operand.getOwner();
-    return outerParent == owner->getParentOp();
-  });
-}
-
 static fir::SequenceType getAsSequenceType(mlir::Value *v) {
   mlir::Type argTy = fir::unwrapPassByRefType(fir::unwrapRefType(v->getType()));
   return argTy.dyn_cast<fir::SequenceType>();
@@ -544,7 +534,7 @@ void LoopVersioningPass::runOnOperation() {
 
     // Add the original loop in the else-side of the if operation.
     builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-    replaceOuterUses(op.op, ifOp);
+    op.op->replaceAllUsesWith(ifOp); // Also covers uses in nested regions.
     op.op->remove();
     builder.insert(op.op);
     // Rely on "cloned loop has results, so original loop also has results".
diff --git a/flang/test/Transforms/loop-versioning.fir b/flang/test/Transforms/loop-versioning.fir
index 6313bc2ac36a78..2a94ec213fd992 100644
--- a/flang/test/Transforms/loop-versioning.fir
+++ b/flang/test/Transforms/loop-versioning.fir
@@ -1478,4 +1478,84 @@ func.func @sum1drebox(%arg0: !fir.box<!fir.array<?xf64>> {fir.bindc_name = "a"},
 // CHECK-NOT: fir.if
 
 
+// Check a use of the loop result in a different block (%12 = fir.do_loop is used inside the fir.if %14 block)
+func.func @minloc(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "x"}, %arg1: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "mask"}) -> f32 {
+  %c2147483647_i32 = arith.constant 2147483647 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c5_i32 = arith.constant 5 : i32
+  %c5 = arith.constant 5 : index
+  %c1 = arith.constant 1 : index
+  %0 = fir.alloca i32
+  %1 = fir.alloca !fir.array<1xi32>
+  %2 = fir.declare %arg1 {uniq_name = "_QFtestEmask"} : (!fir.box<!fir.array<?xi32>>) -> !fir.box<!fir.array<?xi32>>
+  %3 = fir.rebox %2 : (!fir.box<!fir.array<?xi32>>) -> !fir.box<!fir.array<?xi32>>
+  %4 = fir.alloca f32 {bindc_name = "test", uniq_name = "_QFtestEtest"}
+  %5 = fir.declare %4 {uniq_name = "_QFtestEtest"} : (!fir.ref<f32>) -> !fir.ref<f32>
+  %6 = fir.declare %arg0 {uniq_name = "_QFtestEx"} : (!fir.box<!fir.array<?xi32>>) -> !fir.box<!fir.array<?xi32>>
+  %7 = fir.rebox %6 : (!fir.box<!fir.array<?xi32>>) -> !fir.box<!fir.array<?xi32>>
+  %8 = fir.shape %c1 : (index) -> !fir.shape<1>
+  %9 = fir.array_coor %1(%8) %c1 : (!fir.ref<!fir.array<1xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+  fir.store %c0_i32 to %9 : !fir.ref<i32>
+  fir.store %c0_i32 to %0 : !fir.ref<i32>
+  %10:3 = fir.box_dims %7, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
+  %11 = arith.subi %10#1, %c1 : index
+  %12 = fir.do_loop %arg2 = %c0 to %11 step %c1 iter_args(%arg3 = %c2147483647_i32) -> (i32) {
+    %18 = arith.addi %arg2, %c1 : index
+    %19 = fir.array_coor %3 %18 : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
+    %20 = fir.load %19 : !fir.ref<i32>
+    %21 = arith.cmpi sge, %20, %c5_i32 : i32
+    %22 = fir.if %21 -> (i32) {
+      fir.store %c1_i32 to %0 : !fir.ref<i32>
+      %23 = arith.subi %10#0, %c1 : index
+      %24 = arith.addi %18, %23 : index
+      %25 = fir.array_coor %7 %24 : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
+      %26 = fir.load %25 : !fir.ref<i32>
+      %27 = arith.cmpi slt, %26, %arg3 : i32
+      %28 = fir.if %27 -> (i32) {
+        %29 = fir.convert %18 : (index) -> i32
+        fir.store %29 to %9 : !fir.ref<i32>
+        fir.result %26 : i32
+      } else {
+        fir.result %arg3 : i32
+      }
+      fir.result %28 : i32
+    } else {
+      fir.result %arg3 : i32
+    }
+    fir.result %22 : i32
+  }
+  %13 = fir.load %0 : !fir.ref<i32>
+  %14 = arith.cmpi eq, %13, %c1_i32 : i32
+  fir.if %14 {
+    %18 = arith.cmpi eq, %12, %c2147483647_i32 : i32
+    fir.if %18 {
+      %19 = fir.array_coor %1(%8) %c0 : (!fir.ref<!fir.array<1xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+      fir.store %c1_i32 to %19 : !fir.ref<i32>
+    }
+  }
+  %15 = fir.slice %c5, %c5, %c1 : (index, index, index) -> !fir.slice<1>
+  %16 = fir.rebox %7 [%15] : (!fir.box<!fir.array<?xi32>>, !fir.slice<1>) -> !fir.box<!fir.array<1xi32>>
+  fir.do_loop %arg2 = %c1 to %c1 step %c1 unordered {
+    %18 = fir.array_coor %1(%8) %arg2 : (!fir.ref<!fir.array<1xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+    %19 = fir.load %18 : !fir.ref<i32>
+    %20 = fir.array_coor %16 %arg2 : (!fir.box<!fir.array<1xi32>>, index) -> !fir.ref<i32>
+    fir.store %19 to %20 : !fir.ref<i32>
+  }
+  %17 = fir.load %5 : !fir.ref<f32>
+  return %17 : f32
+}
+// CHECK-LABEL: func @minloc
+// CHECK: %[[V17:.*]] = fir.if %{{.*}} -> (i32) {
+// CHECK: %[[V27:.*]] = fir.do_loop
+// CHECK: fir.result %[[V27]] : i32
+// CHECK: } else {
+// CHECK: %[[V23:.*]] = fir.do_loop
+// CHECK: fir.result %[[V23]] : i32
+// CHECK: fir.if %{{.*}} {
+// CHECK: {{.*}} = arith.cmpi eq, %[[V17]], %c2147483647_i32
+
+
+
 } // End module

``````````

</details>


https://github.com/llvm/llvm-project/pull/77899


More information about the flang-commits mailing list