[Mlir-commits] [mlir] [MLIR][SCF] Sink scf.if from scf.while before region into after region. (PR #165216)

Ming Yan llvmlistbot at llvm.org
Tue Nov 18 19:03:28 PST 2025


https://github.com/NexMing updated https://github.com/llvm/llvm-project/pull/165216

>From ecffe33bf9ddd95db18922ba60973dc5a62c81eb Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Mon, 27 Oct 2025 16:35:03 +0800
Subject: [PATCH 1/6] [MLIR][SCF] Sink scf.if from scf.while before region into
 after region.

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp         | 134 +++++++++++++++++++++++-
 mlir/test/Dialect/SCF/canonicalize.mlir |  37 +++++++
 2 files changed, 170 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 9bd13f3236cfc..eaf17546e9281 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3546,6 +3546,137 @@ LogicalResult scf::WhileOp::verify() {
 }
 
 namespace {
+/// Sink an scf.if that directly precedes the scf.condition op of a while
+/// before region and tests the same condition: the then block moves into the
+/// after region, and before-region values it uses become extra while results.
+///
+/// scf.while (..) : (...) -> ... {
+///  %additional_used_values = ...
+///  %cond = ...
+///  ...
+///  %res = scf.if %cond -> (...) {
+///    use(%additional_used_values)
+///    ... // then block
+///    scf.yield %then_value
+///  } else {
+///    scf.yield %else_value
+///  }
+///  scf.condition(%cond) %res, ...
+/// } do {
+/// ^bb0(%res_arg, ...):
+///    use(%res_arg)
+///    ...
+///
+/// becomes
+/// scf.while (..) : (...) -> ... {
+///  %additional_used_values = ...
+///  %cond = ...
+///  ...
+///  scf.condition(%cond) %else_value, ..., %additional_used_values
+/// } do {
+/// ^bb0(%res_arg, ..., %additional_args):
+///    use(%additional_args)
+///    ... // if then block
+///    use(%then_value)
+///    ...
+struct WhileMoveIfDown : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    auto conditionOp =
+        cast<scf::ConditionOp>(op.getBeforeBody()->getTerminator());
+    auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
+
+    // Require the ifOp to sit directly before the conditionOp and to test the
+    // same condition. Its else block may only yield: the ifOp is erased below,
+    // so any else-block ops would be silently dropped from the exit path.
+    // TODO: support else blocks with content.
+    if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
+        (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
+      return failure();
+
+    assert(ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
+                                 *ifOp->user_begin() == conditionOp) &&
+                                    "ifOp has unexpected uses");
+
+    Location loc = op.getLoc();
+
+    // Else-values feed scf.condition; then-values replace after-region args.
+    // NOTE(review): a then-value from the before region won't dominate those.
+    for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
+      auto it = llvm::find(ifOp->getResults(), arg);
+      if (it != ifOp->getResults().end()) {
+        size_t ifOpIdx = it.getIndex();
+        Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
+        Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
+
+        rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue);
+        rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue);
+      }
+    }
+
+    SmallVector<Value> additionalUsedValues;
+    auto isValueUsedInsideIf = [&](Value val) {
+      return llvm::any_of(val.getUsers(), [&](Operation *user) {
+        return ifOp.getThenRegion().isAncestor(user->getParentRegion());
+      });
+    };
+
+    // Collect before-region values that the then region uses.
+    for (Operation *it = ifOp->getPrevNode(); it != nullptr;
+         it = it->getPrevNode())
+      llvm::copy_if(it->getResults(), std::back_inserter(additionalUsedValues),
+                    isValueUsedInsideIf);
+
+    llvm::copy_if(op.getBeforeArguments(),
+                  std::back_inserter(additionalUsedValues),
+                  isValueUsedInsideIf);
+
+    // Create a new whileOp whose results also carry the additional values.
+    auto additionalValueTypes = llvm::map_to_vector(
+        additionalUsedValues, [](Value val) { return val.getType(); });
+    size_t additionalValueSize = additionalUsedValues.size();
+    SmallVector<Type> newResultTypes(op.getResultTypes());
+    newResultTypes.append(additionalValueTypes);
+
+    auto newWhileOp =
+        scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
+
+    newWhileOp.getBefore().takeBody(op.getBefore());
+    newWhileOp.getAfter().takeBody(op.getAfter());
+    newWhileOp.getAfter().addArguments(
+        additionalValueTypes, SmallVector<Location>(additionalValueSize, loc));
+
+    SmallVector<Value> conditionArgs = conditionOp.getArgs();
+    llvm::append_range(conditionArgs, additionalUsedValues);
+
+    // Update conditionOp inside new whileOp before region.
+    rewriter.setInsertionPoint(conditionOp);
+    rewriter.replaceOpWithNewOp<scf::ConditionOp>(
+        conditionOp, conditionOp.getCondition(), conditionArgs);
+
+    // Replace uses of additional used values inside the ifOp then region with
+    // the whileOp after region arguments.
+    rewriter.replaceUsesWithIf(
+        additionalUsedValues,
+        newWhileOp.getAfterArguments().take_back(additionalValueSize),
+        [&](OpOperand &use) {
+          return ifOp.getThenRegion().isAncestor(
+              use.getOwner()->getParentRegion());
+        });
+
+    // Inline ifOp then region into new whileOp after region.
+    rewriter.eraseOp(ifOp.thenYield());
+    rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(),
+                               newWhileOp.getAfterBody()->begin());
+    rewriter.eraseOp(ifOp);
+    rewriter.replaceOp(op,
+                       newWhileOp->getResults().drop_back(additionalValueSize));
+    return success();
+  }
+};
+
 /// Replace uses of the condition within the do block with true, since otherwise
 /// the block would not be evaluated.
 ///
@@ -4258,7 +4389,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<RemoveLoopInvariantArgsFromBeforeBlock,
               RemoveLoopInvariantValueYielded, WhileConditionTruth,
               WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
-              WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
+              WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 2bec63672e783..cfae3b34305de 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -974,6 +974,43 @@ func.func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
 
 // -----
 
+// CHECK-LABEL: @while_move_if_down
+func.func @while_move_if_down() -> i32 {
+  %0 = scf.while () : () -> (i32) {
+    %additional_used_value = "test.get_some_value1" () : () -> (i32)
+    %else_value = "test.get_some_value2" () : () -> (i32)
+    %condition = "test.condition"() : () -> i1
+    %res = scf.if %condition -> (i32) {
+      "test.use1" (%additional_used_value) : (i32) -> ()
+      %then_value = "test.get_some_value3" () : () -> (i32)
+      scf.yield %then_value : i32
+    } else {
+      scf.yield %else_value : i32
+    }
+    scf.condition(%condition) %res : i32
+  } do {
+  ^bb0(%res_arg: i32):
+    "test.use2" (%res_arg) : (i32) -> ()
+    scf.yield
+  }
+  return %0 : i32
+}
+// CHECK-NEXT:      %[[WHILE_0:.*]]:2 = scf.while : () -> (i32, i32) {
+// CHECK-NEXT:        %[[VAL_0:.*]] = "test.get_some_value1"() : () -> i32
+// CHECK-NEXT:        %[[VAL_1:.*]] = "test.get_some_value2"() : () -> i32
+// CHECK-NEXT:        %[[VAL_2:.*]] = "test.condition"() : () -> i1
+// CHECK-NEXT:        scf.condition(%[[VAL_2]]) %[[VAL_1]], %[[VAL_0]] : i32, i32
+// CHECK-NEXT:      } do {
+// CHECK-NEXT:      ^bb0(%[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32):
+// CHECK-NEXT:        "test.use1"(%[[VAL_4]]) : (i32) -> ()
+// CHECK-NEXT:        %[[VAL_5:.*]] = "test.get_some_value3"() : () -> i32
+// CHECK-NEXT:        "test.use2"(%[[VAL_5]]) : (i32) -> ()
+// CHECK-NEXT:        scf.yield
+// CHECK-NEXT:      }
+// CHECK-NEXT:      return %[[VAL_6:.*]]#0 : i32
+
+// -----
+
 // CHECK-LABEL: @while_cond_true
 func.func @while_cond_true() -> i1 {
   %0 = scf.while () : () -> i1 {

>From 19a042eab11e8ebe9251201c0fb09c3ab6c1ce7f Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Tue, 4 Nov 2025 11:33:48 +0800
Subject: [PATCH 2/6] Move the pattern into `populateUpliftWhileToForPatterns`.

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 134 +-----------------
 .../SCF/Transforms/UpliftWhileToFor.cpp       | 134 +++++++++++++++++-
 mlir/test/Dialect/SCF/canonicalize.mlir       |  37 -----
 mlir/test/Dialect/SCF/uplift-while.mlir       |  31 ++++
 4 files changed, 165 insertions(+), 171 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index eaf17546e9281..9bd13f3236cfc 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3546,137 +3546,6 @@ LogicalResult scf::WhileOp::verify() {
 }
 
 namespace {
-/// Move an scf.if op that is directly before the scf.condition op in the while
-/// before region, and whose condition matches the condition of the
-/// scf.condition op, down into the while after region.
-///
-/// scf.while (..) : (...) -> ... {
-///  %additional_used_values = ...
-///  %cond = ...
-///  ...
-///  %res = scf.if %cond -> (...) {
-///    use(%additional_used_values)
-///    ... // then block
-///    scf.yield %then_value
-///  } else {
-///    scf.yield %else_value
-///  }
-///  scf.condition(%cond) %res, ...
-/// } do {
-/// ^bb0(%res_arg, ...):
-///    use(%res_arg)
-///    ...
-///
-/// becomes
-/// scf.while (..) : (...) -> ... {
-///  %additional_used_values = ...
-///  %cond = ...
-///  ...
-///  scf.condition(%cond) %else_value, ..., %additional_used_values
-/// } do {
-/// ^bb0(%res_arg ..., %additional_args): :
-///    use(%additional_args)
-///    ... // if then block
-///    use(%then_value)
-///    ...
-struct WhileMoveIfDown : public OpRewritePattern<WhileOp> {
-  using OpRewritePattern<WhileOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(WhileOp op,
-                                PatternRewriter &rewriter) const override {
-    auto conditionOp =
-        cast<scf::ConditionOp>(op.getBeforeBody()->getTerminator());
-    auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
-
-    // Check that the ifOp is directly before the conditionOp and that it
-    // matches the condition of the conditionOp. Also ensure that the ifOp has
-    // no else block with content, as that would complicate the transformation.
-    // TODO: support else blocks with content.
-    if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
-        (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
-      return failure();
-
-    assert(ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
-                                 *ifOp->user_begin() == conditionOp) &&
-                                    "ifOp has unexpected uses");
-
-    Location loc = op.getLoc();
-
-    // Replace uses of ifOp results in the conditionOp with the yielded values
-    // from the ifOp branches.
-    for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
-      auto it = llvm::find(ifOp->getResults(), arg);
-      if (it != ifOp->getResults().end()) {
-        size_t ifOpIdx = it.getIndex();
-        Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
-        Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
-
-        rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue);
-        rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue);
-      }
-    }
-
-    SmallVector<Value> additionalUsedValues;
-    auto isValueUsedInsideIf = [&](Value val) {
-      return llvm::any_of(val.getUsers(), [&](Operation *user) {
-        return ifOp.getThenRegion().isAncestor(user->getParentRegion());
-      });
-    };
-
-    // Collect additional used values from before region.
-    for (Operation *it = ifOp->getPrevNode(); it != nullptr;
-         it = it->getPrevNode())
-      llvm::copy_if(it->getResults(), std::back_inserter(additionalUsedValues),
-                    isValueUsedInsideIf);
-
-    llvm::copy_if(op.getBeforeArguments(),
-                  std::back_inserter(additionalUsedValues),
-                  isValueUsedInsideIf);
-
-    // Create new whileOp with additional used values as results.
-    auto additionalValueTypes = llvm::map_to_vector(
-        additionalUsedValues, [](Value val) { return val.getType(); });
-    size_t additionalValueSize = additionalUsedValues.size();
-    SmallVector<Type> newResultTypes(op.getResultTypes());
-    newResultTypes.append(additionalValueTypes);
-
-    auto newWhileOp =
-        scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
-
-    newWhileOp.getBefore().takeBody(op.getBefore());
-    newWhileOp.getAfter().takeBody(op.getAfter());
-    newWhileOp.getAfter().addArguments(
-        additionalValueTypes, SmallVector<Location>(additionalValueSize, loc));
-
-    SmallVector<Value> conditionArgs = conditionOp.getArgs();
-    llvm::append_range(conditionArgs, additionalUsedValues);
-
-    // Update conditionOp inside new whileOp before region.
-    rewriter.setInsertionPoint(conditionOp);
-    rewriter.replaceOpWithNewOp<scf::ConditionOp>(
-        conditionOp, conditionOp.getCondition(), conditionArgs);
-
-    // Replace uses of additional used values inside the ifOp then region with
-    // the whileOp after region arguments.
-    rewriter.replaceUsesWithIf(
-        additionalUsedValues,
-        newWhileOp.getAfterArguments().take_back(additionalValueSize),
-        [&](OpOperand &use) {
-          return ifOp.getThenRegion().isAncestor(
-              use.getOwner()->getParentRegion());
-        });
-
-    // Inline ifOp then region into new whileOp after region.
-    rewriter.eraseOp(ifOp.thenYield());
-    rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(),
-                               newWhileOp.getAfterBody()->begin());
-    rewriter.eraseOp(ifOp);
-    rewriter.replaceOp(op,
-                       newWhileOp->getResults().drop_back(additionalValueSize));
-    return success();
-  }
-};
-
 /// Replace uses of the condition within the do block with true, since otherwise
 /// the block would not be evaluated.
 ///
@@ -4389,8 +4258,7 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<RemoveLoopInvariantArgsFromBeforeBlock,
               RemoveLoopInvariantValueYielded, WhileConditionTruth,
               WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
-              WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
-      context);
+              WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index ec1044aaa42ac..3021367e596f7 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -19,6 +19,137 @@
 using namespace mlir;
 
 namespace {
+/// Sink an scf.if that directly precedes the scf.condition op of a while
+/// before region and tests the same condition: the then block moves into the
+/// after region, and before-region values it uses become extra while results.
+///
+/// scf.while (..) : (...) -> ... {
+///  %additional_used_values = ...
+///  %cond = ...
+///  ...
+///  %res = scf.if %cond -> (...) {
+///    use(%additional_used_values)
+///    ... // then block
+///    scf.yield %then_value
+///  } else {
+///    scf.yield %else_value
+///  }
+///  scf.condition(%cond) %res, ...
+/// } do {
+/// ^bb0(%res_arg, ...):
+///    use(%res_arg)
+///    ...
+///
+/// becomes
+/// scf.while (..) : (...) -> ... {
+///  %additional_used_values = ...
+///  %cond = ...
+///  ...
+///  scf.condition(%cond) %else_value, ..., %additional_used_values
+/// } do {
+/// ^bb0(%res_arg, ..., %additional_args):
+///    use(%additional_args)
+///    ... // if then block
+///    use(%then_value)
+///    ...
+struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
+  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(scf::WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    auto conditionOp =
+        cast<scf::ConditionOp>(op.getBeforeBody()->getTerminator());
+    auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
+
+    // Require the ifOp to sit directly before the conditionOp and to test the
+    // same condition. Its else block may only yield: the ifOp is erased below,
+    // so any else-block ops would be silently dropped from the exit path.
+    // TODO: support else blocks with content.
+    if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
+        (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
+      return failure();
+
+    assert(ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
+                                 *ifOp->user_begin() == conditionOp) &&
+                                    "ifOp has unexpected uses");
+
+    Location loc = op.getLoc();
+
+    // Else-values feed scf.condition; then-values replace after-region args.
+    // NOTE(review): a then-value from the before region won't dominate those.
+    for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
+      auto it = llvm::find(ifOp->getResults(), arg);
+      if (it != ifOp->getResults().end()) {
+        size_t ifOpIdx = it.getIndex();
+        Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
+        Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
+
+        rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue);
+        rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue);
+      }
+    }
+
+    SmallVector<Value> additionalUsedValues;
+    auto isValueUsedInsideIf = [&](Value val) {
+      return llvm::any_of(val.getUsers(), [&](Operation *user) {
+        return ifOp.getThenRegion().isAncestor(user->getParentRegion());
+      });
+    };
+
+    // Collect before-region values that the then region uses.
+    for (Operation *it = ifOp->getPrevNode(); it != nullptr;
+         it = it->getPrevNode())
+      llvm::copy_if(it->getResults(), std::back_inserter(additionalUsedValues),
+                    isValueUsedInsideIf);
+
+    llvm::copy_if(op.getBeforeArguments(),
+                  std::back_inserter(additionalUsedValues),
+                  isValueUsedInsideIf);
+
+    // Create a new whileOp whose results also carry the additional values.
+    auto additionalValueTypes = llvm::map_to_vector(
+        additionalUsedValues, [](Value val) { return val.getType(); });
+    size_t additionalValueSize = additionalUsedValues.size();
+    SmallVector<Type> newResultTypes(op.getResultTypes());
+    newResultTypes.append(additionalValueTypes);
+
+    auto newWhileOp =
+        scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
+
+    newWhileOp.getBefore().takeBody(op.getBefore());
+    newWhileOp.getAfter().takeBody(op.getAfter());
+    newWhileOp.getAfter().addArguments(
+        additionalValueTypes, SmallVector<Location>(additionalValueSize, loc));
+
+    SmallVector<Value> conditionArgs = conditionOp.getArgs();
+    llvm::append_range(conditionArgs, additionalUsedValues);
+
+    // Update conditionOp inside new whileOp before region.
+    rewriter.setInsertionPoint(conditionOp);
+    rewriter.replaceOpWithNewOp<scf::ConditionOp>(
+        conditionOp, conditionOp.getCondition(), conditionArgs);
+
+    // Replace uses of additional used values inside the ifOp then region with
+    // the whileOp after region arguments.
+    rewriter.replaceUsesWithIf(
+        additionalUsedValues,
+        newWhileOp.getAfterArguments().take_back(additionalValueSize),
+        [&](OpOperand &use) {
+          return ifOp.getThenRegion().isAncestor(
+              use.getOwner()->getParentRegion());
+        });
+
+    // Inline ifOp then region into new whileOp after region.
+    rewriter.eraseOp(ifOp.thenYield());
+    rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(),
+                               newWhileOp.getAfterBody()->begin());
+    rewriter.eraseOp(ifOp);
+    rewriter.replaceOp(op,
+                       newWhileOp->getResults().drop_back(additionalValueSize));
+    return success();
+  }
+};
+
 struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -267,5 +398,6 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
 }
 
 void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) {
-  patterns.add<UpliftWhileOp>(patterns.getContext());
+  patterns.add<WhileMoveIfDown, UpliftWhileOp>(patterns.getContext());
+  scf::WhileOp::getCanonicalizationPatterns(patterns, patterns.getContext());
 }
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index cfae3b34305de..2bec63672e783 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -974,43 +974,6 @@ func.func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
 
 // -----
 
-// CHECK-LABEL: @while_move_if_down
-func.func @while_move_if_down() -> i32 {
-  %0 = scf.while () : () -> (i32) {
-    %additional_used_value = "test.get_some_value1" () : () -> (i32)
-    %else_value = "test.get_some_value2" () : () -> (i32)
-    %condition = "test.condition"() : () -> i1
-    %res = scf.if %condition -> (i32) {
-      "test.use1" (%additional_used_value) : (i32) -> ()
-      %then_value = "test.get_some_value3" () : () -> (i32)
-      scf.yield %then_value : i32
-    } else {
-      scf.yield %else_value : i32
-    }
-    scf.condition(%condition) %res : i32
-  } do {
-  ^bb0(%res_arg: i32):
-    "test.use2" (%res_arg) : (i32) -> ()
-    scf.yield
-  }
-  return %0 : i32
-}
-// CHECK-NEXT:      %[[WHILE_0:.*]]:2 = scf.while : () -> (i32, i32) {
-// CHECK-NEXT:        %[[VAL_0:.*]] = "test.get_some_value1"() : () -> i32
-// CHECK-NEXT:        %[[VAL_1:.*]] = "test.get_some_value2"() : () -> i32
-// CHECK-NEXT:        %[[VAL_2:.*]] = "test.condition"() : () -> i1
-// CHECK-NEXT:        scf.condition(%[[VAL_2]]) %[[VAL_1]], %[[VAL_0]] : i32, i32
-// CHECK-NEXT:      } do {
-// CHECK-NEXT:      ^bb0(%[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32):
-// CHECK-NEXT:        "test.use1"(%[[VAL_4]]) : (i32) -> ()
-// CHECK-NEXT:        %[[VAL_5:.*]] = "test.get_some_value3"() : () -> i32
-// CHECK-NEXT:        "test.use2"(%[[VAL_5]]) : (i32) -> ()
-// CHECK-NEXT:        scf.yield
-// CHECK-NEXT:      }
-// CHECK-NEXT:      return %[[VAL_6:.*]]#0 : i32
-
-// -----
-
 // CHECK-LABEL: @while_cond_true
 func.func @while_cond_true() -> i1 {
   %0 = scf.while () : () -> i1 {
diff --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir
index cbe2ce5076ad2..736112824c515 100644
--- a/mlir/test/Dialect/SCF/uplift-while.mlir
+++ b/mlir/test/Dialect/SCF/uplift-while.mlir
@@ -185,3 +185,34 @@ func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32)
 //       CHECK:     %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
 //       CHECK:     scf.yield %[[T1]], %[[T2]] : i32, f32
 //       CHECK:     return %[[RES]]#0, %[[RES]]#1 : i32, f32
+
+// -----
+
+func.func @uplift_while(%low: index, %upper: index, %val : i32) -> i32 {
+  %c1 = arith.constant 1 : index
+  %1:2 = scf.while (%iv = %low, %iter = %val) : (index, i32) -> (index, i32) {
+    %2 = arith.cmpi slt, %iv, %upper : index
+    %3:2 = scf.if %2 -> (index, i32) {
+      %4 = "test.test"(%iter) : (i32) -> i32
+      %5 = arith.addi %iv, %c1 : index
+      scf.yield %5, %4 : index, i32
+    } else {
+      scf.yield %iv, %iter : index, i32
+    }
+    scf.condition(%2) %3#0, %3#1 : index, i32
+  } do {
+  ^bb0(%arg0: index, %arg1: i32):
+    scf.yield %arg0, %arg1 : index, i32
+  }
+  return %1#1 : i32
+}
+
+// CHECK-LABEL:   func.func @uplift_while(
+// CHECK-SAME:      %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: i32) -> i32 {
+// CHECK:           %[[CONSTANT_0:.*]] = arith.constant 1 : index
+// CHECK:           %[[FOR_0:.*]] = scf.for %[[VAL_0:.*]] = %[[ARG0]] to %[[ARG1]] step %[[CONSTANT_0]] iter_args(%[[VAL_1:.*]] = %[[ARG2]]) -> (i32) {
+// CHECK:             %[[VAL_2:.*]] = "test.test"(%[[VAL_1]]) : (i32) -> i32
+// CHECK:             scf.yield %[[VAL_2]] : i32
+// CHECK:           }
+// CHECK:           return %[[FOR_0]] : i32
+// CHECK:         }

>From d76de865a22eb53d385c5aabeb21303a933021ab Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Wed, 5 Nov 2025 11:33:44 +0800
Subject: [PATCH 3/6] Use `visitUsedValuesDefinedAbove` to simplify the code.

---
 .../SCF/Transforms/UpliftWhileToFor.cpp       | 23 ++++++-------------
 1 file changed, 7 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index 3021367e596f7..0218339ee321a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/RegionUtils.h"
 
 using namespace mlir;
 
@@ -89,22 +90,12 @@ struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
       }
     }
 
-    SmallVector<Value> additionalUsedValues;
-    auto isValueUsedInsideIf = [&](Value val) {
-      return llvm::any_of(val.getUsers(), [&](Operation *user) {
-        return ifOp.getThenRegion().isAncestor(user->getParentRegion());
-      });
-    };
-
     // Collect additional used values from before region.
-    for (Operation *it = ifOp->getPrevNode(); it != nullptr;
-         it = it->getPrevNode())
-      llvm::copy_if(it->getResults(), std::back_inserter(additionalUsedValues),
-                    isValueUsedInsideIf);
-
-    llvm::copy_if(op.getBeforeArguments(),
-                  std::back_inserter(additionalUsedValues),
-                  isValueUsedInsideIf);
+    SetVector<Value> additionalUsedValues;
+    visitUsedValuesDefinedAbove(ifOp.getThenRegion(), [&](OpOperand *operand) {
+      if (op.getBefore().isAncestor(operand->get().getParentRegion()))
+        additionalUsedValues.insert(operand->get());
+    });
 
     // Create new whileOp with additional used values as results.
     auto additionalValueTypes = llvm::map_to_vector(
@@ -132,7 +123,7 @@ struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
     // Replace uses of additional used values inside the ifOp then region with
     // the whileOp after region arguments.
     rewriter.replaceUsesWithIf(
-        additionalUsedValues,
+        additionalUsedValues.takeVector(),
         newWhileOp.getAfterArguments().take_back(additionalValueSize),
         [&](OpOperand &use) {
           return ifOp.getThenRegion().isAncestor(

>From 56dddbdece5d2f8144a11914167760e3ba742ad8 Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Tue, 18 Nov 2025 16:40:14 +0800
Subject: [PATCH 4/6] Add more constraints to simplify the transformation
 process and move it into canonicalization. Since there is no suitable cost
 model to evaluate profitability, the extra constraints ensure the rewrite is
 always beneficial and does not introduce excessive compile-time overhead.

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp               |  79 ++++++++++-
 .../SCF/Transforms/UpliftWhileToFor.cpp       | 125 +-----------------
 mlir/test/Dialect/SCF/canonicalize.mlir       |  32 +++++
 mlir/test/Dialect/SCF/uplift-while.mlir       |  31 -----
 4 files changed, 111 insertions(+), 156 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 9bd13f3236cfc..ed579ca68e949 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3546,6 +3546,102 @@ LogicalResult scf::WhileOp::verify() {
 }

 namespace {
+/// Sink an scf.if from the scf.while before region into the after region.
+///
+/// Matches a before region made of exactly three ops: the op defining %cond
+/// (%cond has no other uses), an scf.if on %cond, and scf.condition(%cond)
+/// forwarding the if results. The else branch must only yield the
+/// before-region arguments unchanged: it runs on the exit path and is erased
+/// by the rewrite, so it may not do any other work.
+///
+/// scf.while (%init) : (...) -> ... {
+///   %cond = ...
+///   %res = scf.if %cond -> (...) {
+///     use1(%init)
+///     %then_val = ...
+///     ... // then block
+///     scf.yield %then_val
+///   } else {
+///     scf.yield %init
+///   }
+///   scf.condition(%cond) %res
+/// } do {
+/// ^bb0(%arg):
+///   use2(%arg)
+///   ...
+///
+/// becomes
+/// scf.while (%init) : (...) -> ... {
+///   %cond = ...
+///   scf.condition(%cond) %init
+/// } do {
+/// ^bb0(%arg):
+///   use1(%arg)
+///   %then_val = ...
+///   ... // rest of the then block
+///   use2(%then_val)
+///   ...
+struct WhileMoveIfDown : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    // The first op of the before region must define the loop condition, and
+    // that value must only be used by the scf.if and scf.condition below.
+    Operation &condOp = op.getBeforeBody()->front();
+    if (condOp.getNumResults() != 1 || !condOp.getResult(0).hasNUses(2))
+      return failure();
+
+    Value condVal = condOp.getResult(0);
+    auto ifOp = dyn_cast<scf::IfOp>(condOp.getNextNode());
+    if (!ifOp || ifOp.getCondition() != condVal)
+      return failure();
+
+    auto term = dyn_cast<scf::ConditionOp>(ifOp->getNextNode());
+    if (!term || term.getCondition() != condVal)
+      return failure();
+
+    // The else branch runs on the exit path and is erased below, so it must
+    // not contain any op besides its yield. Without an else region the if
+    // has no results (scf.if requires one to yield values) and nothing to
+    // check, so do not call elseYield() on a null block.
+    Block *elseBlock = ifOp.elseBlock();
+    if (elseBlock && !elseBlock->without_terminator().empty())
+      return failure();
+    ValueRange elseVals =
+        elseBlock ? ValueRange(ifOp.elseYield()->getOperands()) : ValueRange();
+
+    // Check that if results and else yield operands match the scf.condition op
+    // arguments and while before arguments respectively.
+    if (!llvm::equal(ifOp->getResults(), term.getArgs()) ||
+        !llvm::equal(elseVals, op.getBeforeArguments()))
+      return failure();
+
+    // After-region users of the loop arguments now consume the then-values.
+    // A then-value may be a before-region argument, which does not dominate
+    // the after region. The condition forwards the before-region arguments
+    // unchanged, so remap them to the after-region arguments in both the then
+    // region and the after region.
+    rewriter.replaceAllUsesWith(op.getAfterArguments(),
+                                ifOp.thenYield()->getOperands());
+    rewriter.replaceUsesWithIf(
+        op.getBeforeArguments(), op.getAfterArguments(), [&](OpOperand &use) {
+          Region *region = use.getOwner()->getParentRegion();
+          return ifOp.getThenRegion().isAncestor(region) ||
+                 op.getAfter().isAncestor(region);
+        });
+    rewriter.modifyOpInPlace(
+        term, [&]() { term.getArgsMutable().assign(op.getBeforeArguments()); });
+
+    // Move the then block (minus its yield) to the top of the after region.
+    rewriter.eraseOp(ifOp.thenYield());
+    rewriter.inlineBlockBefore(ifOp.thenBlock(), op.getAfterBody(),
+                               op.getAfterBody()->begin());
+    rewriter.eraseOp(ifOp);
+    return success();
+  }
+};
+
 /// Replace uses of the condition within the do block with true, since otherwise
 /// the block would not be evaluated.
 ///
@@ -4258,7 +4334,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<RemoveLoopInvariantArgsFromBeforeBlock,
               RemoveLoopInvariantValueYielded, WhileConditionTruth,
               WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
-              WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
+              WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index 0218339ee321a..ec1044aaa42ac 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -15,132 +15,10 @@
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/RegionUtils.h"
 
 using namespace mlir;
 
 namespace {
-/// Move a scf.if op that is directly before the scf.condition op in the while
-/// before region, and whose condition matches the condition of the
-/// scf.condition op, down into the while after region.
-///
-/// scf.while (..) : (...) -> ... {
-///  %additional_used_values = ...
-///  %cond = ...
-///  ...
-///  %res = scf.if %cond -> (...) {
-///    use(%additional_used_values)
-///    ... // then block
-///    scf.yield %then_value
-///  } else {
-///    scf.yield %else_value
-///  }
-///  scf.condition(%cond) %res, ...
-/// } do {
-/// ^bb0(%res_arg, ...):
-///    use(%res_arg)
-///    ...
-///
-/// becomes
-/// scf.while (..) : (...) -> ... {
-///  %additional_used_values = ...
-///  %cond = ...
-///  ...
-///  scf.condition(%cond) %else_value, ..., %additional_used_values
-/// } do {
-/// ^bb0(%res_arg ..., %additional_args): :
-///    use(%additional_args)
-///    ... // if then block
-///    use(%then_value)
-///    ...
-struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
-  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(scf::WhileOp op,
-                                PatternRewriter &rewriter) const override {
-    auto conditionOp =
-        cast<scf::ConditionOp>(op.getBeforeBody()->getTerminator());
-    auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
-
-    // Check that the ifOp is directly before the conditionOp and that it
-    // matches the condition of the conditionOp. Also ensure that the ifOp has
-    // no else block with content, as that would complicate the transformation.
-    // TODO: support else blocks with content.
-    if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
-        (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
-      return failure();
-
-    assert(ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
-                                 *ifOp->user_begin() == conditionOp) &&
-                                    "ifOp has unexpected uses");
-
-    Location loc = op.getLoc();
-
-    // Replace uses of ifOp results in the conditionOp with the yielded values
-    // from the ifOp branches.
-    for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
-      auto it = llvm::find(ifOp->getResults(), arg);
-      if (it != ifOp->getResults().end()) {
-        size_t ifOpIdx = it.getIndex();
-        Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
-        Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
-
-        rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue);
-        rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue);
-      }
-    }
-
-    // Collect additional used values from before region.
-    SetVector<Value> additionalUsedValues;
-    visitUsedValuesDefinedAbove(ifOp.getThenRegion(), [&](OpOperand *operand) {
-      if (op.getBefore().isAncestor(operand->get().getParentRegion()))
-        additionalUsedValues.insert(operand->get());
-    });
-
-    // Create new whileOp with additional used values as results.
-    auto additionalValueTypes = llvm::map_to_vector(
-        additionalUsedValues, [](Value val) { return val.getType(); });
-    size_t additionalValueSize = additionalUsedValues.size();
-    SmallVector<Type> newResultTypes(op.getResultTypes());
-    newResultTypes.append(additionalValueTypes);
-
-    auto newWhileOp =
-        scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
-
-    newWhileOp.getBefore().takeBody(op.getBefore());
-    newWhileOp.getAfter().takeBody(op.getAfter());
-    newWhileOp.getAfter().addArguments(
-        additionalValueTypes, SmallVector<Location>(additionalValueSize, loc));
-
-    SmallVector<Value> conditionArgs = conditionOp.getArgs();
-    llvm::append_range(conditionArgs, additionalUsedValues);
-
-    // Update conditionOp inside new whileOp before region.
-    rewriter.setInsertionPoint(conditionOp);
-    rewriter.replaceOpWithNewOp<scf::ConditionOp>(
-        conditionOp, conditionOp.getCondition(), conditionArgs);
-
-    // Replace uses of additional used values inside the ifOp then region with
-    // the whileOp after region arguments.
-    rewriter.replaceUsesWithIf(
-        additionalUsedValues.takeVector(),
-        newWhileOp.getAfterArguments().take_back(additionalValueSize),
-        [&](OpOperand &use) {
-          return ifOp.getThenRegion().isAncestor(
-              use.getOwner()->getParentRegion());
-        });
-
-    // Inline ifOp then region into new whileOp after region.
-    rewriter.eraseOp(ifOp.thenYield());
-    rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(),
-                               newWhileOp.getAfterBody()->begin());
-    rewriter.eraseOp(ifOp);
-    rewriter.replaceOp(op,
-                       newWhileOp->getResults().drop_back(additionalValueSize));
-    return success();
-  }
-};
-
 struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -389,6 +267,5 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
 }
 
 void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) {
-  patterns.add<WhileMoveIfDown, UpliftWhileOp>(patterns.getContext());
-  scf::WhileOp::getCanonicalizationPatterns(patterns, patterns.getContext());
+  patterns.add<UpliftWhileOp>(patterns.getContext());
 }
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 2bec63672e783..d3bd2dfda05f2 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -974,6 +974,38 @@ func.func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
 
 // -----
 
+// CHECK-LABEL: @while_move_if_down
+// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: i32)
+func.func @while_move_if_down(%arg0: index, %arg1: i32) -> i32 {
+  %0:2 = scf.while (%init0 = %arg0, %init1 = %arg1) : (index, i32) -> (index, i32) {
+    %condition = "test.condition"() : () -> i1
+    %res:2 = scf.if %condition -> (index, i32) {
+      %then_val:2 = "test.use1"(%init0, %init1) : (index, i32) -> (i32, index)
+      scf.yield %then_val#1, %then_val#0 : index, i32
+    } else {
+      scf.yield %init0, %init1 : index, i32
+    }
+    scf.condition(%condition) %res#0, %res#1 : index, i32
+  } do {
+  ^bb0(%arg2: index, %arg3: i32):
+    %1:2 = "test.use2"(%arg2, %arg3) : (index, i32) -> (i32, index)
+    scf.yield %1#1, %1#0 : index, i32
+  }
+  return %0#1 : i32
+}
+// CHECK-NEXT:      %[[WHILE_0:.*]]:2 = scf.while (%[[VAL_0:.*]] = %[[ARG0]], %[[VAL_1:.*]] = %[[ARG1]]) : (index, i32) -> (index, i32) {
+// CHECK-NEXT:        %[[VAL_2:.*]] = "test.condition"() : () -> i1
+// CHECK-NEXT:        scf.condition(%[[VAL_2]]) %[[VAL_0]], %[[VAL_1]] : index, i32
+// CHECK-NEXT:      } do {
+// CHECK-NEXT:      ^bb0(%[[VAL_3:.*]]: index, %[[VAL_4:.*]]: i32):
+// CHECK-NEXT:        %[[VAL_5:.*]]:2 = "test.use1"(%[[VAL_3]], %[[VAL_4]]) : (index, i32) -> (i32, index)
+// CHECK-NEXT:        %[[VAL_6:.*]]:2 = "test.use2"(%[[VAL_5]]#1, %[[VAL_5]]#0) : (index, i32) -> (i32, index)
+// CHECK-NEXT:        scf.yield %[[VAL_6]]#1, %[[VAL_6]]#0 : index, i32
+// CHECK-NEXT:      }
+// CHECK-NEXT:      return %[[WHILE_0]]#1 : i32
+
+// -----
+
 // CHECK-LABEL: @while_cond_true
 func.func @while_cond_true() -> i1 {
   %0 = scf.while () : () -> i1 {
diff --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir
index 736112824c515..cbe2ce5076ad2 100644
--- a/mlir/test/Dialect/SCF/uplift-while.mlir
+++ b/mlir/test/Dialect/SCF/uplift-while.mlir
@@ -185,34 +185,3 @@ func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32)
 //       CHECK:     %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
 //       CHECK:     scf.yield %[[T1]], %[[T2]] : i32, f32
 //       CHECK:     return %[[RES]]#0, %[[RES]]#1 : i32, f32
-
-// -----
-
-func.func @uplift_while(%low: index, %upper: index, %val : i32) -> i32 {
-  %c1 = arith.constant 1 : index
-  %1:2 = scf.while (%iv = %low, %iter = %val) : (index, i32) -> (index, i32) {
-    %2 = arith.cmpi slt, %iv, %upper : index
-    %3:2 = scf.if %2 -> (index, i32) {
-      %4 = "test.test"(%iter) : (i32) -> i32
-      %5 = arith.addi %iv, %c1 : index
-      scf.yield %5, %4 : index, i32
-    } else {
-      scf.yield %iv, %iter : index, i32
-    }
-    scf.condition(%2) %3#0, %3#1 : index, i32
-  } do {
-  ^bb0(%arg0: index, %arg1: i32):
-    scf.yield %arg0, %arg1 : index, i32
-  }
-  return %1#1 : i32
-}
-
-// CHECK-LABEL:   func.func @uplift_while(
-// CHECK-SAME:      %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: i32) -> i32 {
-// CHECK:           %[[CONSTANT_0:.*]] = arith.constant 1 : index
-// CHECK:           %[[FOR_0:.*]] = scf.for %[[VAL_0:.*]] = %[[ARG0]] to %[[ARG1]] step %[[CONSTANT_0]] iter_args(%[[VAL_1:.*]] = %[[ARG2]]) -> (i32) {
-// CHECK:             %[[VAL_2:.*]] = "test.test"(%[[VAL_1]]) : (i32) -> i32
-// CHECK:             scf.yield %[[VAL_2]] : i32
-// CHECK:           }
-// CHECK:           return %[[FOR_0]] : i32
-// CHECK:         }

>From 9f6043e01e667e647ea1f50895d91baed06a039b Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Tue, 18 Nov 2025 16:52:03 +0800
Subject: [PATCH 5/6] clean up.

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index ed579ca68e949..e32d3ffee7d0b 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3590,7 +3590,7 @@ struct WhileMoveIfDown : public OpRewritePattern<WhileOp> {
 
     Value condVal = condOp.getResult(0);
     auto ifOp = dyn_cast<scf::IfOp>(condOp.getNextNode());
-    if (condOp.getNumResults() != 1 || !ifOp || ifOp.getCondition() != condVal)
+    if (!ifOp || ifOp.getCondition() != condVal)
       return failure();
 
     auto term = dyn_cast<scf::ConditionOp>(ifOp->getNextNode());

>From a2899306ab18307252b7d19e69a39573d26f6487 Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Wed, 19 Nov 2025 11:02:38 +0800
Subject: [PATCH 6/6] Update the comments.

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index e32d3ffee7d0b..08a640e210237 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3582,8 +3582,9 @@ struct WhileMoveIfDown : public OpRewritePattern<WhileOp> {
 
   LogicalResult matchAndRewrite(WhileOp op,
                                 PatternRewriter &rewriter) const override {
-    // Check that the first operation in the before region is an scf.if whose
-    // condition matches the condition of the scf.condition op.
+    // Check that the first operation produces exactly one result and that this
+    // result has exactly two uses (one by the `scf.if` and one by the
+    // `scf.condition` operation).
     Operation &condOp = op.getBeforeBody()->front();
     if (condOp.getNumResults() != 1 || !condOp.getResult(0).hasNUses(2))
       return failure();



More information about the Mlir-commits mailing list