[flang-commits] [flang] [Flang] Fix for replacing loop uses in LoopVersioning pass (PR #77899)

David Green via flang-commits flang-commits at lists.llvm.org
Fri Jan 12 01:48:16 PST 2024


https://github.com/davemgreen created https://github.com/llvm/llvm-project/pull/77899

The added test case has a versioned loop whose result is used inside an if block after the loop. The current code replaces uses of the loop with the new versioned If, but only those whose parent operation matches that of the If, so a use nested in another block keeps referring to the original loop. As far as I can see it should be safe to replace all the uses and then construct the result of the If from op.op.

>From b7fefc9aa3c202d1fb7d2ffaeaa621686347b197 Mon Sep 17 00:00:00 2001
From: David Green <david.green at arm.com>
Date: Fri, 12 Jan 2024 09:45:39 +0000
Subject: [PATCH] [Flang] Fix for replacing loop uses in LoopVersioning pass

The added test case has a versioned loop whose result is used inside an if
block after the loop. The current code replaces uses of the loop with the new
versioned If, but only those whose parent operation matches that of the If, so
a use nested in another block keeps referring to the original loop. As far as
I can see it should be safe to replace all the uses and then construct the
result of the If from op.op.
---
 .../Optimizer/Transforms/LoopVersioning.cpp   | 12 +--
 flang/test/Transforms/loop-versioning.fir     | 80 +++++++++++++++++++
 2 files changed, 81 insertions(+), 11 deletions(-)

diff --git a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
index 5d6a23b86d06ef..30a064ab8851f5 100644
--- a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
+++ b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
@@ -146,16 +146,6 @@ struct ArgsUsageInLoop {
 };
 } // namespace
 
-/// @c replaceOuterUses - replace uses outside of @c op with result of @c
-/// outerOp
-static void replaceOuterUses(mlir::Operation *op, mlir::Operation *outerOp) {
-  const mlir::Operation *outerParent = outerOp->getParentOp();
-  op->replaceUsesWithIf(outerOp, [&](mlir::OpOperand &operand) {
-    mlir::Operation *owner = operand.getOwner();
-    return outerParent == owner->getParentOp();
-  });
-}
-
 static fir::SequenceType getAsSequenceType(mlir::Value *v) {
   mlir::Type argTy = fir::unwrapPassByRefType(fir::unwrapRefType(v->getType()));
   return argTy.dyn_cast<fir::SequenceType>();
@@ -544,7 +534,7 @@ void LoopVersioningPass::runOnOperation() {
 
     // Add the original loop in the else-side of the if operation.
     builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-    replaceOuterUses(op.op, ifOp);
+    op.op->replaceAllUsesWith(ifOp);
     op.op->remove();
     builder.insert(op.op);
     // Rely on "cloned loop has results, so original loop also has results".
diff --git a/flang/test/Transforms/loop-versioning.fir b/flang/test/Transforms/loop-versioning.fir
index 6313bc2ac36a78..2a94ec213fd992 100644
--- a/flang/test/Transforms/loop-versioning.fir
+++ b/flang/test/Transforms/loop-versioning.fir
@@ -1478,4 +1478,84 @@ func.func @sum1drebox(%arg0: !fir.box<!fir.array<?xf64>> {fir.bindc_name = "a"},
 // CHECK-NOT: fir.if
 
 
+// Check for a use in a different block (%12 = do_loop is used inside the if %14 block)
+func.func @minloc(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "x"}, %arg1: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "mask"}) -> f32 {
+  %c2147483647_i32 = arith.constant 2147483647 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c5_i32 = arith.constant 5 : i32
+  %c5 = arith.constant 5 : index
+  %c1 = arith.constant 1 : index
+  %0 = fir.alloca i32
+  %1 = fir.alloca !fir.array<1xi32>
+  %2 = fir.declare %arg1 {uniq_name = "_QFtestEmask"} : (!fir.box<!fir.array<?xi32>>) -> !fir.box<!fir.array<?xi32>>
+  %3 = fir.rebox %2 : (!fir.box<!fir.array<?xi32>>) -> !fir.box<!fir.array<?xi32>>
+  %4 = fir.alloca f32 {bindc_name = "test", uniq_name = "_QFtestEtest"}
+  %5 = fir.declare %4 {uniq_name = "_QFtestEtest"} : (!fir.ref<f32>) -> !fir.ref<f32>
+  %6 = fir.declare %arg0 {uniq_name = "_QFtestEx"} : (!fir.box<!fir.array<?xi32>>) -> !fir.box<!fir.array<?xi32>>
+  %7 = fir.rebox %6 : (!fir.box<!fir.array<?xi32>>) -> !fir.box<!fir.array<?xi32>>
+  %8 = fir.shape %c1 : (index) -> !fir.shape<1>
+  %9 = fir.array_coor %1(%8) %c1 : (!fir.ref<!fir.array<1xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+  fir.store %c0_i32 to %9 : !fir.ref<i32>
+  fir.store %c0_i32 to %0 : !fir.ref<i32>
+  %10:3 = fir.box_dims %7, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
+  %11 = arith.subi %10#1, %c1 : index
+  %12 = fir.do_loop %arg2 = %c0 to %11 step %c1 iter_args(%arg3 = %c2147483647_i32) -> (i32) {
+    %18 = arith.addi %arg2, %c1 : index
+    %19 = fir.array_coor %3 %18 : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
+    %20 = fir.load %19 : !fir.ref<i32>
+    %21 = arith.cmpi sge, %20, %c5_i32 : i32
+    %22 = fir.if %21 -> (i32) {
+      fir.store %c1_i32 to %0 : !fir.ref<i32>
+      %23 = arith.subi %10#0, %c1 : index
+      %24 = arith.addi %18, %23 : index
+      %25 = fir.array_coor %7 %24 : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
+      %26 = fir.load %25 : !fir.ref<i32>
+      %27 = arith.cmpi slt, %26, %arg3 : i32
+      %28 = fir.if %27 -> (i32) {
+        %29 = fir.convert %18 : (index) -> i32
+        fir.store %29 to %9 : !fir.ref<i32>
+        fir.result %26 : i32
+      } else {
+        fir.result %arg3 : i32
+      }
+      fir.result %28 : i32
+    } else {
+      fir.result %arg3 : i32
+    }
+    fir.result %22 : i32
+  }
+  %13 = fir.load %0 : !fir.ref<i32>
+  %14 = arith.cmpi eq, %13, %c1_i32 : i32
+  fir.if %14 {
+    %18 = arith.cmpi eq, %12, %c2147483647_i32 : i32
+    fir.if %18 {
+      %19 = fir.array_coor %1(%8) %c0 : (!fir.ref<!fir.array<1xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+      fir.store %c1_i32 to %19 : !fir.ref<i32>
+    }
+  }
+  %15 = fir.slice %c5, %c5, %c1 : (index, index, index) -> !fir.slice<1>
+  %16 = fir.rebox %7 [%15] : (!fir.box<!fir.array<?xi32>>, !fir.slice<1>) -> !fir.box<!fir.array<1xi32>>
+  fir.do_loop %arg2 = %c1 to %c1 step %c1 unordered {
+    %18 = fir.array_coor %1(%8) %arg2 : (!fir.ref<!fir.array<1xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+    %19 = fir.load %18 : !fir.ref<i32>
+    %20 = fir.array_coor %16 %arg2 : (!fir.box<!fir.array<1xi32>>, index) -> !fir.ref<i32>
+    fir.store %19 to %20 : !fir.ref<i32>
+  }
+  %17 = fir.load %5 : !fir.ref<f32>
+  return %17 : f32
+}
+// CHECK-LABEL: func @minloc
+// CHECK: %[[V17:.*]] = fir.if %{{.*}} -> (i32) {
+// CHECK: %[[V27:.*]] = fir.do_loop
+// CHECK: fir.result %[[V27]] : i32
+// CHECK: } else {
+// CHECK: %[[V23:.*]] = fir.do_loop
+// CHECK: fir.result %[[V23]] : i32
+// CHECK: fir.if %{{.*}} {
+// CHECK: {{.*}} = arith.cmpi eq, %[[V17]], %c2147483647_i32
+
+
+
 } // End module



More information about the flang-commits mailing list