[Mlir-commits] [mlir] [mlir][scf] Allow 'ult'/'ugt' in uplift (PR #139911)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 14 07:37:28 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-scf

Author: None (darkbuck)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/139911.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp (+11-18) 
- (modified) mlir/test/Dialect/SCF/uplift-while.mlir (+64) 


``````````diff
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index ebe718ae4fb61..0fabaf6e63ee4 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -91,9 +91,10 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
 
   using Pred = arith::CmpIPredicate;
   Pred predicate = cmp.getPredicate();
-  if (predicate != Pred::slt && predicate != Pred::sgt)
+  if (predicate != Pred::slt && predicate != Pred::sgt &&
+      predicate != Pred::ult && predicate != Pred::ugt)
     return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
-      diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
+      diag << "Expected 'slt'/'ult' or 'sgt'/'ugt' predicate: " << *cmp;
     });
 
   BlockArgument inductionVar;
@@ -103,24 +104,16 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
   // Check if cmp has a suitable form. One of the arguments must be a `before`
   // block arg, other must be defined outside `scf.while` and will be treated
   // as upper bound.
-  for (bool reverse : {false, true}) {
-    auto expectedPred = reverse ? Pred::sgt : Pred::slt;
-    if (cmp.getPredicate() != expectedPred)
-      continue;
-
-    auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs();
-    auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs();
-
-    auto blockArg = dyn_cast<BlockArgument>(arg1);
-    if (!blockArg || blockArg.getOwner() != beforeBody)
-      continue;
-
-    if (!dom.properlyDominates(arg2, loop))
-      continue;
-
+  auto arg1 = cmp.getLhs();
+  auto arg2 = cmp.getRhs();
+  if (predicate == Pred::sgt || predicate == Pred::ugt)
+    std::swap(arg1, arg2);
+
+  auto blockArg = dyn_cast<BlockArgument>(arg1);
+  if (blockArg && blockArg.getOwner() == beforeBody &&
+      dom.properlyDominates(arg2, loop)) {
     inductionVar = blockArg;
     ub = arg2;
-    break;
   }
 
   if (!inductionVar)
diff --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir
index cbe2ce5076ad2..f11f5ab28d707 100644
--- a/mlir/test/Dialect/SCF/uplift-while.mlir
+++ b/mlir/test/Dialect/SCF/uplift-while.mlir
@@ -185,3 +185,67 @@ func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32)
 //       CHECK:     %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
 //       CHECK:     scf.yield %[[T1]], %[[T2]] : i32, f32
 //       CHECK:     return %[[RES]]#0, %[[RES]]#1 : i32, f32
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+    %1 = arith.cmpi ult, %arg3, %arg1 : index
+    scf.condition(%1) %arg3 : index
+  } do {
+  ^bb0(%arg3: index):
+    "test.test1"(%arg3) : (index) -> ()
+    %added = arith.addi %arg3, %arg2 : index
+    "test.test2"(%added) : (index) -> ()
+    scf.yield %added : index
+  }
+  return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+//       CHECK:     %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:     scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+//       CHECK:     "test.test1"(%[[I]]) : (index) -> ()
+//       CHECK:     %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index
+//       CHECK:     "test.test2"(%[[INC]]) : (index) -> ()
+//       CHECK:     %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+//       CHECK:     %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+//       CHECK:     %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+//       CHECK:     %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+//       CHECK:     %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+//       CHECK:     %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+//       CHECK:     %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+//       CHECK:     return %[[R7]] : index
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+    %1 = arith.cmpi ugt, %arg1, %arg3 : index
+    scf.condition(%1) %arg3 : index
+  } do {
+  ^bb0(%arg3: index):
+    "test.test1"(%arg3) : (index) -> ()
+    %added = arith.addi %arg3, %arg2 : index
+    "test.test2"(%added) : (index) -> ()
+    scf.yield %added : index
+  }
+  return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+//       CHECK:     %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:     scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+//       CHECK:     "test.test1"(%[[I]]) : (index) -> ()
+//       CHECK:     %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index
+//       CHECK:     "test.test2"(%[[INC]]) : (index) -> ()
+//       CHECK:     %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+//       CHECK:     %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+//       CHECK:     %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+//       CHECK:     %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+//       CHECK:     %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+//       CHECK:     %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+//       CHECK:     %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+//       CHECK:     return %[[R7]] : index

``````````

</details>


https://github.com/llvm/llvm-project/pull/139911


More information about the Mlir-commits mailing list