[Mlir-commits] [mlir] [MLIR][SCF] Sink scf.if from scf.while before region into after region. (PR #165216)

Ming Yan llvmlistbot at llvm.org
Mon Nov 3 19:49:40 PST 2025


https://github.com/NexMing updated https://github.com/llvm/llvm-project/pull/165216

>From ecffe33bf9ddd95db18922ba60973dc5a62c81eb Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Mon, 27 Oct 2025 16:35:03 +0800
Subject: [PATCH 1/2] [MLIR][SCF] Sink scf.if from scf.while before region into
 after region.

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp         | 134 +++++++++++++++++++++++-
 mlir/test/Dialect/SCF/canonicalize.mlir |  37 +++++++
 2 files changed, 170 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 9bd13f3236cfc..eaf17546e9281 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3546,6 +3546,137 @@ LogicalResult scf::WhileOp::verify() {
 }
 
 namespace {
+/// Move an scf.if op that is directly before the scf.condition op in the while
+/// before region, and whose condition matches the condition of the
+/// scf.condition op, down into the while after region.
+///
+/// scf.while (..) : (...) -> ... {
+///  %additional_used_values = ...
+///  %cond = ...
+///  ...
+///  %res = scf.if %cond -> (...) {
+///    use(%additional_used_values)
+///    ... // then block
+///    scf.yield %then_value
+///  } else {
+///    scf.yield %else_value
+///  }
+///  scf.condition(%cond) %res, ...
+/// } do {
+/// ^bb0(%res_arg, ...):
+///    use(%res_arg)
+///    ...
+///
+/// becomes
+/// scf.while (..) : (...) -> ... {
+///  %additional_used_values = ...
+///  %cond = ...
+///  ...
+///  scf.condition(%cond) %else_value, ..., %additional_used_values
+/// } do {
+/// ^bb0(%res_arg ..., %additional_args): :
+///    use(%additional_args)
+///    ... // if then block
+///    use(%then_value)
+///    ...
+struct WhileMoveIfDown : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    auto conditionOp =
+        cast<scf::ConditionOp>(op.getBeforeBody()->getTerminator());
+    auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
+
+    // Check that the ifOp is directly before the conditionOp and that it
+    // matches the condition of the conditionOp. Also ensure that the ifOp has
+    // no else block with content, as that would complicate the transformation.
+    // TODO: support else blocks with content.
+    if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
+        (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
+      return failure();
+
+    assert(ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
+                                 *ifOp->user_begin() == conditionOp) &&
+                                    "ifOp has unexpected uses");
+
+    Location loc = op.getLoc();
+
+    // Replace uses of ifOp results in the conditionOp with the yielded values
+    // from the ifOp branches.
+    for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
+      auto it = llvm::find(ifOp->getResults(), arg);
+      if (it != ifOp->getResults().end()) {
+        size_t ifOpIdx = it.getIndex();
+        Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
+        Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
+
+        rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue);
+        rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue);
+      }
+    }
+
+    SmallVector<Value> additionalUsedValues;
+    auto isValueUsedInsideIf = [&](Value val) {
+      return llvm::any_of(val.getUsers(), [&](Operation *user) {
+        return ifOp.getThenRegion().isAncestor(user->getParentRegion());
+      });
+    };
+
+    // Collect additional used values from before region.
+    for (Operation *it = ifOp->getPrevNode(); it != nullptr;
+         it = it->getPrevNode())
+      llvm::copy_if(it->getResults(), std::back_inserter(additionalUsedValues),
+                    isValueUsedInsideIf);
+
+    llvm::copy_if(op.getBeforeArguments(),
+                  std::back_inserter(additionalUsedValues),
+                  isValueUsedInsideIf);
+
+    // Create new whileOp with additional used values as results.
+    auto additionalValueTypes = llvm::map_to_vector(
+        additionalUsedValues, [](Value val) { return val.getType(); });
+    size_t additionalValueSize = additionalUsedValues.size();
+    SmallVector<Type> newResultTypes(op.getResultTypes());
+    newResultTypes.append(additionalValueTypes);
+
+    auto newWhileOp =
+        scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
+
+    newWhileOp.getBefore().takeBody(op.getBefore());
+    newWhileOp.getAfter().takeBody(op.getAfter());
+    newWhileOp.getAfter().addArguments(
+        additionalValueTypes, SmallVector<Location>(additionalValueSize, loc));
+
+    SmallVector<Value> conditionArgs = conditionOp.getArgs();
+    llvm::append_range(conditionArgs, additionalUsedValues);
+
+    // Update conditionOp inside new whileOp before region.
+    rewriter.setInsertionPoint(conditionOp);
+    rewriter.replaceOpWithNewOp<scf::ConditionOp>(
+        conditionOp, conditionOp.getCondition(), conditionArgs);
+
+    // Replace uses of additional used values inside the ifOp then region with
+    // the whileOp after region arguments.
+    rewriter.replaceUsesWithIf(
+        additionalUsedValues,
+        newWhileOp.getAfterArguments().take_back(additionalValueSize),
+        [&](OpOperand &use) {
+          return ifOp.getThenRegion().isAncestor(
+              use.getOwner()->getParentRegion());
+        });
+
+    // Inline ifOp then region into new whileOp after region.
+    rewriter.eraseOp(ifOp.thenYield());
+    rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(),
+                               newWhileOp.getAfterBody()->begin());
+    rewriter.eraseOp(ifOp);
+    rewriter.replaceOp(op,
+                       newWhileOp->getResults().drop_back(additionalValueSize));
+    return success();
+  }
+};
+
 /// Replace uses of the condition within the do block with true, since otherwise
 /// the block would not be evaluated.
 ///
@@ -4258,7 +4389,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<RemoveLoopInvariantArgsFromBeforeBlock,
               RemoveLoopInvariantValueYielded, WhileConditionTruth,
               WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
-              WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
+              WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 2bec63672e783..cfae3b34305de 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -974,6 +974,43 @@ func.func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
 
 // -----
 
+// CHECK-LABEL: @while_move_if_down
+func.func @while_move_if_down() -> i32 {
+  %0 = scf.while () : () -> (i32) {
+    %additional_used_value = "test.get_some_value1" () : () -> (i32)
+    %else_value = "test.get_some_value2" () : () -> (i32)
+    %condition = "test.condition"() : () -> i1
+    %res = scf.if %condition -> (i32) {
+      "test.use1" (%additional_used_value) : (i32) -> ()
+      %then_value = "test.get_some_value3" () : () -> (i32)
+      scf.yield %then_value : i32
+    } else {
+      scf.yield %else_value : i32
+    }
+    scf.condition(%condition) %res : i32
+  } do {
+  ^bb0(%res_arg: i32):
+    "test.use2" (%res_arg) : (i32) -> ()
+    scf.yield
+  }
+  return %0 : i32
+}
+// CHECK-NEXT:      %[[WHILE_0:.*]]:2 = scf.while : () -> (i32, i32) {
+// CHECK-NEXT:        %[[VAL_0:.*]] = "test.get_some_value1"() : () -> i32
+// CHECK-NEXT:        %[[VAL_1:.*]] = "test.get_some_value2"() : () -> i32
+// CHECK-NEXT:        %[[VAL_2:.*]] = "test.condition"() : () -> i1
+// CHECK-NEXT:        scf.condition(%[[VAL_2]]) %[[VAL_1]], %[[VAL_0]] : i32, i32
+// CHECK-NEXT:      } do {
+// CHECK-NEXT:      ^bb0(%[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32):
+// CHECK-NEXT:        "test.use1"(%[[VAL_4]]) : (i32) -> ()
+// CHECK-NEXT:        %[[VAL_5:.*]] = "test.get_some_value3"() : () -> i32
+// CHECK-NEXT:        "test.use2"(%[[VAL_5]]) : (i32) -> ()
+// CHECK-NEXT:        scf.yield
+// CHECK-NEXT:      }
+// CHECK-NEXT:      return %[[VAL_6:.*]]#0 : i32
+
+// -----
+
 // CHECK-LABEL: @while_cond_true
 func.func @while_cond_true() -> i1 {
   %0 = scf.while () : () -> i1 {

>From 19a042eab11e8ebe9251201c0fb09c3ab6c1ce7f Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Tue, 4 Nov 2025 11:33:48 +0800
Subject: [PATCH 2/2] Move the pattern into `populateUpliftWhileToForPatterns`.

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 134 +-----------------
 .../SCF/Transforms/UpliftWhileToFor.cpp       | 134 +++++++++++++++++-
 mlir/test/Dialect/SCF/canonicalize.mlir       |  37 -----
 mlir/test/Dialect/SCF/uplift-while.mlir       |  31 ++++
 4 files changed, 165 insertions(+), 171 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index eaf17546e9281..9bd13f3236cfc 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3546,137 +3546,6 @@ LogicalResult scf::WhileOp::verify() {
 }
 
 namespace {
-/// Move an scf.if op that is directly before the scf.condition op in the while
-/// before region, and whose condition matches the condition of the
-/// scf.condition op, down into the while after region.
-///
-/// scf.while (..) : (...) -> ... {
-///  %additional_used_values = ...
-///  %cond = ...
-///  ...
-///  %res = scf.if %cond -> (...) {
-///    use(%additional_used_values)
-///    ... // then block
-///    scf.yield %then_value
-///  } else {
-///    scf.yield %else_value
-///  }
-///  scf.condition(%cond) %res, ...
-/// } do {
-/// ^bb0(%res_arg, ...):
-///    use(%res_arg)
-///    ...
-///
-/// becomes
-/// scf.while (..) : (...) -> ... {
-///  %additional_used_values = ...
-///  %cond = ...
-///  ...
-///  scf.condition(%cond) %else_value, ..., %additional_used_values
-/// } do {
-/// ^bb0(%res_arg ..., %additional_args): :
-///    use(%additional_args)
-///    ... // if then block
-///    use(%then_value)
-///    ...
-struct WhileMoveIfDown : public OpRewritePattern<WhileOp> {
-  using OpRewritePattern<WhileOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(WhileOp op,
-                                PatternRewriter &rewriter) const override {
-    auto conditionOp =
-        cast<scf::ConditionOp>(op.getBeforeBody()->getTerminator());
-    auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
-
-    // Check that the ifOp is directly before the conditionOp and that it
-    // matches the condition of the conditionOp. Also ensure that the ifOp has
-    // no else block with content, as that would complicate the transformation.
-    // TODO: support else blocks with content.
-    if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
-        (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
-      return failure();
-
-    assert(ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
-                                 *ifOp->user_begin() == conditionOp) &&
-                                    "ifOp has unexpected uses");
-
-    Location loc = op.getLoc();
-
-    // Replace uses of ifOp results in the conditionOp with the yielded values
-    // from the ifOp branches.
-    for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
-      auto it = llvm::find(ifOp->getResults(), arg);
-      if (it != ifOp->getResults().end()) {
-        size_t ifOpIdx = it.getIndex();
-        Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
-        Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
-
-        rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue);
-        rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue);
-      }
-    }
-
-    SmallVector<Value> additionalUsedValues;
-    auto isValueUsedInsideIf = [&](Value val) {
-      return llvm::any_of(val.getUsers(), [&](Operation *user) {
-        return ifOp.getThenRegion().isAncestor(user->getParentRegion());
-      });
-    };
-
-    // Collect additional used values from before region.
-    for (Operation *it = ifOp->getPrevNode(); it != nullptr;
-         it = it->getPrevNode())
-      llvm::copy_if(it->getResults(), std::back_inserter(additionalUsedValues),
-                    isValueUsedInsideIf);
-
-    llvm::copy_if(op.getBeforeArguments(),
-                  std::back_inserter(additionalUsedValues),
-                  isValueUsedInsideIf);
-
-    // Create new whileOp with additional used values as results.
-    auto additionalValueTypes = llvm::map_to_vector(
-        additionalUsedValues, [](Value val) { return val.getType(); });
-    size_t additionalValueSize = additionalUsedValues.size();
-    SmallVector<Type> newResultTypes(op.getResultTypes());
-    newResultTypes.append(additionalValueTypes);
-
-    auto newWhileOp =
-        scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
-
-    newWhileOp.getBefore().takeBody(op.getBefore());
-    newWhileOp.getAfter().takeBody(op.getAfter());
-    newWhileOp.getAfter().addArguments(
-        additionalValueTypes, SmallVector<Location>(additionalValueSize, loc));
-
-    SmallVector<Value> conditionArgs = conditionOp.getArgs();
-    llvm::append_range(conditionArgs, additionalUsedValues);
-
-    // Update conditionOp inside new whileOp before region.
-    rewriter.setInsertionPoint(conditionOp);
-    rewriter.replaceOpWithNewOp<scf::ConditionOp>(
-        conditionOp, conditionOp.getCondition(), conditionArgs);
-
-    // Replace uses of additional used values inside the ifOp then region with
-    // the whileOp after region arguments.
-    rewriter.replaceUsesWithIf(
-        additionalUsedValues,
-        newWhileOp.getAfterArguments().take_back(additionalValueSize),
-        [&](OpOperand &use) {
-          return ifOp.getThenRegion().isAncestor(
-              use.getOwner()->getParentRegion());
-        });
-
-    // Inline ifOp then region into new whileOp after region.
-    rewriter.eraseOp(ifOp.thenYield());
-    rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(),
-                               newWhileOp.getAfterBody()->begin());
-    rewriter.eraseOp(ifOp);
-    rewriter.replaceOp(op,
-                       newWhileOp->getResults().drop_back(additionalValueSize));
-    return success();
-  }
-};
-
 /// Replace uses of the condition within the do block with true, since otherwise
 /// the block would not be evaluated.
 ///
@@ -4389,8 +4258,7 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<RemoveLoopInvariantArgsFromBeforeBlock,
               RemoveLoopInvariantValueYielded, WhileConditionTruth,
               WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
-              WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
-      context);
+              WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index ec1044aaa42ac..3021367e596f7 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -19,6 +19,137 @@
 using namespace mlir;
 
 namespace {
+/// Move a scf.if op that is directly before the scf.condition op in the while
+/// before region, and whose condition matches the condition of the
+/// scf.condition op, down into the while after region.
+///
+/// scf.while (..) : (...) -> ... {
+///  %additional_used_values = ...
+///  %cond = ...
+///  ...
+///  %res = scf.if %cond -> (...) {
+///    use(%additional_used_values)
+///    ... // then block
+///    scf.yield %then_value
+///  } else {
+///    scf.yield %else_value
+///  }
+///  scf.condition(%cond) %res, ...
+/// } do {
+/// ^bb0(%res_arg, ...):
+///    use(%res_arg)
+///    ...
+///
+/// becomes
+/// scf.while (..) : (...) -> ... {
+///  %additional_used_values = ...
+///  %cond = ...
+///  ...
+///  scf.condition(%cond) %else_value, ..., %additional_used_values
+/// } do {
+/// ^bb0(%res_arg ..., %additional_args):
+///    use(%additional_args)
+///    ... // if then block
+///    use(%then_value) // or the matching %additional_args
+///    ...
+struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
+  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(scf::WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    auto conditionOp =
+        cast<scf::ConditionOp>(op.getBeforeBody()->getTerminator());
+    auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
+
+    // Check that the ifOp is directly before the conditionOp and that it
+    // matches the condition of the conditionOp. Also ensure that the ifOp has
+    // no else block with content, as that would complicate the transformation.
+    // TODO: support else blocks with content.
+    if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
+        (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
+      return failure();
+
+    assert((ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
+                                  *ifOp->user_begin() == conditionOp)) &&
+           "ifOp has unexpected uses");
+
+    // Route else values to scf.condition and then values to the after region;
+    // iterate a copy, since RAUW also rewrites later duplicate operands.
+    SmallVector<Value> forwardedArgs = conditionOp.getArgs();
+    for (auto [idx, arg] : llvm::enumerate(forwardedArgs)) {
+      auto it = llvm::find(ifOp->getResults(), arg);
+      if (it == ifOp->getResults().end())
+        continue;
+      size_t ifOpIdx = it.getIndex();
+      rewriter.replaceAllUsesWith(ifOp->getResult(ifOpIdx),
+                                  ifOp.elseYield()->getOperand(ifOpIdx));
+      rewriter.replaceAllUsesWith(op.getAfterArguments()[idx],
+                                  ifOp.thenYield()->getOperand(ifOpIdx));
+    }
+
+    // Collect before-region values used in the then region, including values
+    // it only yields: the after region can see them only as block arguments.
+    SmallVector<Value> additionalUsedValues;
+    auto isValueUsedInsideIf = [&](Value val) {
+      return llvm::any_of(val.getUsers(), [&](Operation *user) {
+        return ifOp.getThenRegion().isAncestor(user->getParentRegion());
+      });
+    };
+    for (Operation *it = ifOp->getPrevNode(); it != nullptr;
+         it = it->getPrevNode())
+      llvm::copy_if(it->getResults(), std::back_inserter(additionalUsedValues),
+                    isValueUsedInsideIf);
+    llvm::copy_if(op.getBeforeArguments(),
+                  std::back_inserter(additionalUsedValues),
+                  isValueUsedInsideIf);
+
+    // Create new whileOp with additional used values as results.
+    Location loc = op.getLoc();
+    auto additionalValueTypes = llvm::map_to_vector(
+        additionalUsedValues, [](Value val) { return val.getType(); });
+    size_t additionalValueSize = additionalUsedValues.size();
+    SmallVector<Type> newResultTypes(op.getResultTypes());
+    newResultTypes.append(additionalValueTypes);
+    auto newWhileOp =
+        scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
+
+    // Move the regions through the rewriter so that listeners are notified.
+    rewriter.inlineRegionBefore(op.getBefore(), newWhileOp.getBefore(),
+                                newWhileOp.getBefore().end());
+    rewriter.inlineRegionBefore(op.getAfter(), newWhileOp.getAfter(),
+                                newWhileOp.getAfter().end());
+    newWhileOp.getAfter().addArguments(
+        additionalValueTypes, SmallVector<Location>(additionalValueSize, loc));
+
+    SmallVector<Value> conditionArgs = conditionOp.getArgs();
+    llvm::append_range(conditionArgs, additionalUsedValues);
+
+    // Update conditionOp inside new whileOp before region.
+    rewriter.setInsertionPoint(conditionOp);
+    rewriter.replaceOpWithNewOp<scf::ConditionOp>(
+        conditionOp, conditionOp.getCondition(), conditionArgs);
+
+    // Inline ifOp then region into new whileOp after region.
+    rewriter.eraseOp(ifOp.thenYield());
+    rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(),
+                               newWhileOp.getAfterBody()->begin());
+    rewriter.eraseOp(ifOp);
+
+    // All remaining uses of the additional values outside the before region
+    // are in the after region now; rewire them to the new after arguments.
+    Region &afterRegion = newWhileOp.getAfter();
+    rewriter.replaceUsesWithIf(
+        additionalUsedValues,
+        newWhileOp.getAfterArguments().take_back(additionalValueSize),
+        [&](OpOperand &use) {
+          return afterRegion.isAncestor(use.getOwner()->getParentRegion());
+        });
+    rewriter.replaceOp(op,
+                       newWhileOp->getResults().drop_back(additionalValueSize));
+    return success();
+  }
+};
+
 struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -267,5 +398,6 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
 }
 
 void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) {
-  patterns.add<UpliftWhileOp>(patterns.getContext());
+  patterns.add<WhileMoveIfDown, UpliftWhileOp>(patterns.getContext());
+  scf::WhileOp::getCanonicalizationPatterns(patterns, patterns.getContext());
 }
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index cfae3b34305de..2bec63672e783 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -974,43 +974,6 @@ func.func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
 
 // -----
 
-// CHECK-LABEL: @while_move_if_down
-func.func @while_move_if_down() -> i32 {
-  %0 = scf.while () : () -> (i32) {
-    %additional_used_value = "test.get_some_value1" () : () -> (i32)
-    %else_value = "test.get_some_value2" () : () -> (i32)
-    %condition = "test.condition"() : () -> i1
-    %res = scf.if %condition -> (i32) {
-      "test.use1" (%additional_used_value) : (i32) -> ()
-      %then_value = "test.get_some_value3" () : () -> (i32)
-      scf.yield %then_value : i32
-    } else {
-      scf.yield %else_value : i32
-    }
-    scf.condition(%condition) %res : i32
-  } do {
-  ^bb0(%res_arg: i32):
-    "test.use2" (%res_arg) : (i32) -> ()
-    scf.yield
-  }
-  return %0 : i32
-}
-// CHECK-NEXT:      %[[WHILE_0:.*]]:2 = scf.while : () -> (i32, i32) {
-// CHECK-NEXT:        %[[VAL_0:.*]] = "test.get_some_value1"() : () -> i32
-// CHECK-NEXT:        %[[VAL_1:.*]] = "test.get_some_value2"() : () -> i32
-// CHECK-NEXT:        %[[VAL_2:.*]] = "test.condition"() : () -> i1
-// CHECK-NEXT:        scf.condition(%[[VAL_2]]) %[[VAL_1]], %[[VAL_0]] : i32, i32
-// CHECK-NEXT:      } do {
-// CHECK-NEXT:      ^bb0(%[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32):
-// CHECK-NEXT:        "test.use1"(%[[VAL_4]]) : (i32) -> ()
-// CHECK-NEXT:        %[[VAL_5:.*]] = "test.get_some_value3"() : () -> i32
-// CHECK-NEXT:        "test.use2"(%[[VAL_5]]) : (i32) -> ()
-// CHECK-NEXT:        scf.yield
-// CHECK-NEXT:      }
-// CHECK-NEXT:      return %[[VAL_6:.*]]#0 : i32
-
-// -----
-
 // CHECK-LABEL: @while_cond_true
 func.func @while_cond_true() -> i1 {
   %0 = scf.while () : () -> i1 {
diff --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir
index cbe2ce5076ad2..736112824c515 100644
--- a/mlir/test/Dialect/SCF/uplift-while.mlir
+++ b/mlir/test/Dialect/SCF/uplift-while.mlir
@@ -185,3 +185,34 @@ func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32)
 //       CHECK:     %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
 //       CHECK:     scf.yield %[[T1]], %[[T2]] : i32, f32
 //       CHECK:     return %[[RES]]#0, %[[RES]]#1 : i32, f32
+
+// -----
+
+func.func @uplift_while(%low: index, %upper: index, %val: i32) -> i32 {
+  %step = arith.constant 1 : index
+  %res:2 = scf.while (%iv = %low, %iter = %val) : (index, i32) -> (index, i32) {
+    %cond = arith.cmpi slt, %iv, %upper : index
+    %next:2 = scf.if %cond -> (index, i32) {
+      %body = "test.test"(%iter) : (i32) -> i32
+      %iv_next = arith.addi %iv, %step : index
+      scf.yield %iv_next, %body : index, i32
+    } else {
+      scf.yield %iv, %iter : index, i32
+    }
+    scf.condition(%cond) %next#0, %next#1 : index, i32
+  } do {
+  ^bb0(%iv_arg: index, %iter_arg: i32):
+    scf.yield %iv_arg, %iter_arg : index, i32
+  }
+  return %res#1 : i32
+}
+
+// CHECK-LABEL:   func.func @uplift_while(
+// CHECK-SAME:      %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: i32) -> i32 {
+// CHECK:           %[[CONSTANT_0:.*]] = arith.constant 1 : index
+// CHECK:           %[[FOR_0:.*]] = scf.for %[[VAL_0:.*]] = %[[ARG0]] to %[[ARG1]] step %[[CONSTANT_0]] iter_args(%[[VAL_1:.*]] = %[[ARG2]]) -> (i32) {
+// CHECK:             %[[VAL_2:.*]] = "test.test"(%[[VAL_1]]) : (i32) -> i32
+// CHECK:             scf.yield %[[VAL_2]] : i32
+// CHECK:           }
+// CHECK:           return %[[FOR_0]] : i32
+// CHECK:         }



More information about the Mlir-commits mailing list