[Mlir-commits] [mlir] [MLIR][SCF] Sink scf.if from scf.while before region into after region. (PR #165216)
Ming Yan
llvmlistbot at llvm.org
Mon Nov 3 19:49:40 PST 2025
https://github.com/NexMing updated https://github.com/llvm/llvm-project/pull/165216
>From ecffe33bf9ddd95db18922ba60973dc5a62c81eb Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Mon, 27 Oct 2025 16:35:03 +0800
Subject: [PATCH 1/2] [MLIR][SCF] Sink scf.if from scf.while before region into
after region.
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 134 +++++++++++++++++++++++-
mlir/test/Dialect/SCF/canonicalize.mlir | 37 +++++++
2 files changed, 170 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 9bd13f3236cfc..eaf17546e9281 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3546,6 +3546,137 @@ LogicalResult scf::WhileOp::verify() {
}
namespace {
+/// Move an scf.if op that is directly before the scf.condition op in the while
+/// before region, and whose condition matches the condition of the
+/// scf.condition op, down into the while after region.
+///
+/// scf.while (..) : (...) -> ... {
+/// %additional_used_values = ...
+/// %cond = ...
+/// ...
+/// %res = scf.if %cond -> (...) {
+/// use(%additional_used_values)
+/// ... // then block
+/// scf.yield %then_value
+/// } else {
+/// scf.yield %else_value
+/// }
+/// scf.condition(%cond) %res, ...
+/// } do {
+/// ^bb0(%res_arg, ...):
+/// use(%res_arg)
+/// ...
+///
+/// becomes
+/// scf.while (..) : (...) -> ... {
+/// %additional_used_values = ...
+/// %cond = ...
+/// ...
+/// scf.condition(%cond) %else_value, ..., %additional_used_values
+/// } do {
+/// ^bb0(%res_arg ..., %additional_args): :
+/// use(%additional_args)
+/// ... // if then block
+/// use(%then_value)
+/// ...
+struct WhileMoveIfDown : public OpRewritePattern<WhileOp> {
+ using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(WhileOp op,
+ PatternRewriter &rewriter) const override {
+ auto conditionOp =
+ cast<scf::ConditionOp>(op.getBeforeBody()->getTerminator());
+ auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
+
+ // Check that the ifOp is directly before the conditionOp and that it
+ // matches the condition of the conditionOp. Also ensure that the ifOp has
+ // no else block with content, as that would complicate the transformation.
+ // TODO: support else blocks with content.
+ if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
+ (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
+ return failure();
+
+ assert(ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
+ *ifOp->user_begin() == conditionOp) &&
+ "ifOp has unexpected uses");
+
+ Location loc = op.getLoc();
+
+ // Replace uses of ifOp results in the conditionOp with the yielded values
+ // from the ifOp branches.
+ for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
+ auto it = llvm::find(ifOp->getResults(), arg);
+ if (it != ifOp->getResults().end()) {
+ size_t ifOpIdx = it.getIndex();
+ Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
+ Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
+
+ rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue);
+ rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue);
+ }
+ }
+
+ SmallVector<Value> additionalUsedValues;
+ auto isValueUsedInsideIf = [&](Value val) {
+ return llvm::any_of(val.getUsers(), [&](Operation *user) {
+ return ifOp.getThenRegion().isAncestor(user->getParentRegion());
+ });
+ };
+
+ // Collect additional used values from before region.
+ for (Operation *it = ifOp->getPrevNode(); it != nullptr;
+ it = it->getPrevNode())
+ llvm::copy_if(it->getResults(), std::back_inserter(additionalUsedValues),
+ isValueUsedInsideIf);
+
+ llvm::copy_if(op.getBeforeArguments(),
+ std::back_inserter(additionalUsedValues),
+ isValueUsedInsideIf);
+
+ // Create new whileOp with additional used values as results.
+ auto additionalValueTypes = llvm::map_to_vector(
+ additionalUsedValues, [](Value val) { return val.getType(); });
+ size_t additionalValueSize = additionalUsedValues.size();
+ SmallVector<Type> newResultTypes(op.getResultTypes());
+ newResultTypes.append(additionalValueTypes);
+
+ auto newWhileOp =
+ scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
+
+ newWhileOp.getBefore().takeBody(op.getBefore());
+ newWhileOp.getAfter().takeBody(op.getAfter());
+ newWhileOp.getAfter().addArguments(
+ additionalValueTypes, SmallVector<Location>(additionalValueSize, loc));
+
+ SmallVector<Value> conditionArgs = conditionOp.getArgs();
+ llvm::append_range(conditionArgs, additionalUsedValues);
+
+ // Update conditionOp inside new whileOp before region.
+ rewriter.setInsertionPoint(conditionOp);
+ rewriter.replaceOpWithNewOp<scf::ConditionOp>(
+ conditionOp, conditionOp.getCondition(), conditionArgs);
+
+ // Replace uses of additional used values inside the ifOp then region with
+ // the whileOp after region arguments.
+ rewriter.replaceUsesWithIf(
+ additionalUsedValues,
+ newWhileOp.getAfterArguments().take_back(additionalValueSize),
+ [&](OpOperand &use) {
+ return ifOp.getThenRegion().isAncestor(
+ use.getOwner()->getParentRegion());
+ });
+
+ // Inline ifOp then region into new whileOp after region.
+ rewriter.eraseOp(ifOp.thenYield());
+ rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(),
+ newWhileOp.getAfterBody()->begin());
+ rewriter.eraseOp(ifOp);
+ rewriter.replaceOp(op,
+ newWhileOp->getResults().drop_back(additionalValueSize));
+ return success();
+ }
+};
+
/// Replace uses of the condition within the do block with true, since otherwise
/// the block would not be evaluated.
///
@@ -4258,7 +4389,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<RemoveLoopInvariantArgsFromBeforeBlock,
RemoveLoopInvariantValueYielded, WhileConditionTruth,
WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
- WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
+ WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 2bec63672e783..cfae3b34305de 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -974,6 +974,43 @@ func.func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
// -----
+// CHECK-LABEL: @while_move_if_down
+func.func @while_move_if_down() -> i32 {
+ %0 = scf.while () : () -> (i32) {
+ %additional_used_value = "test.get_some_value1" () : () -> (i32)
+ %else_value = "test.get_some_value2" () : () -> (i32)
+ %condition = "test.condition"() : () -> i1
+ %res = scf.if %condition -> (i32) {
+ "test.use1" (%additional_used_value) : (i32) -> ()
+ %then_value = "test.get_some_value3" () : () -> (i32)
+ scf.yield %then_value : i32
+ } else {
+ scf.yield %else_value : i32
+ }
+ scf.condition(%condition) %res : i32
+ } do {
+ ^bb0(%res_arg: i32):
+ "test.use2" (%res_arg) : (i32) -> ()
+ scf.yield
+ }
+ return %0 : i32
+}
+// CHECK-NEXT: %[[WHILE_0:.*]]:2 = scf.while : () -> (i32, i32) {
+// CHECK-NEXT: %[[VAL_0:.*]] = "test.get_some_value1"() : () -> i32
+// CHECK-NEXT: %[[VAL_1:.*]] = "test.get_some_value2"() : () -> i32
+// CHECK-NEXT: %[[VAL_2:.*]] = "test.condition"() : () -> i1
+// CHECK-NEXT: scf.condition(%[[VAL_2]]) %[[VAL_1]], %[[VAL_0]] : i32, i32
+// CHECK-NEXT: } do {
+// CHECK-NEXT: ^bb0(%[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32):
+// CHECK-NEXT: "test.use1"(%[[VAL_4]]) : (i32) -> ()
+// CHECK-NEXT: %[[VAL_5:.*]] = "test.get_some_value3"() : () -> i32
+// CHECK-NEXT: "test.use2"(%[[VAL_5]]) : (i32) -> ()
+// CHECK-NEXT: scf.yield
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[VAL_6:.*]]#0 : i32
+
+// -----
+
// CHECK-LABEL: @while_cond_true
func.func @while_cond_true() -> i1 {
%0 = scf.while () : () -> i1 {
>From 19a042eab11e8ebe9251201c0fb09c3ab6c1ce7f Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Tue, 4 Nov 2025 11:33:48 +0800
Subject: [PATCH 2/2] Move the pattern into `populateUpliftWhileToForPatterns`.
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 134 +-----------------
.../SCF/Transforms/UpliftWhileToFor.cpp | 134 +++++++++++++++++-
mlir/test/Dialect/SCF/canonicalize.mlir | 37 -----
mlir/test/Dialect/SCF/uplift-while.mlir | 31 ++++
4 files changed, 165 insertions(+), 171 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index eaf17546e9281..9bd13f3236cfc 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3546,137 +3546,6 @@ LogicalResult scf::WhileOp::verify() {
}
namespace {
-/// Move an scf.if op that is directly before the scf.condition op in the while
-/// before region, and whose condition matches the condition of the
-/// scf.condition op, down into the while after region.
-///
-/// scf.while (..) : (...) -> ... {
-/// %additional_used_values = ...
-/// %cond = ...
-/// ...
-/// %res = scf.if %cond -> (...) {
-/// use(%additional_used_values)
-/// ... // then block
-/// scf.yield %then_value
-/// } else {
-/// scf.yield %else_value
-/// }
-/// scf.condition(%cond) %res, ...
-/// } do {
-/// ^bb0(%res_arg, ...):
-/// use(%res_arg)
-/// ...
-///
-/// becomes
-/// scf.while (..) : (...) -> ... {
-/// %additional_used_values = ...
-/// %cond = ...
-/// ...
-/// scf.condition(%cond) %else_value, ..., %additional_used_values
-/// } do {
-/// ^bb0(%res_arg ..., %additional_args): :
-/// use(%additional_args)
-/// ... // if then block
-/// use(%then_value)
-/// ...
-struct WhileMoveIfDown : public OpRewritePattern<WhileOp> {
- using OpRewritePattern<WhileOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(WhileOp op,
- PatternRewriter &rewriter) const override {
- auto conditionOp =
- cast<scf::ConditionOp>(op.getBeforeBody()->getTerminator());
- auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
-
- // Check that the ifOp is directly before the conditionOp and that it
- // matches the condition of the conditionOp. Also ensure that the ifOp has
- // no else block with content, as that would complicate the transformation.
- // TODO: support else blocks with content.
- if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
- (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
- return failure();
-
- assert(ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
- *ifOp->user_begin() == conditionOp) &&
- "ifOp has unexpected uses");
-
- Location loc = op.getLoc();
-
- // Replace uses of ifOp results in the conditionOp with the yielded values
- // from the ifOp branches.
- for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
- auto it = llvm::find(ifOp->getResults(), arg);
- if (it != ifOp->getResults().end()) {
- size_t ifOpIdx = it.getIndex();
- Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
- Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
-
- rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue);
- rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue);
- }
- }
-
- SmallVector<Value> additionalUsedValues;
- auto isValueUsedInsideIf = [&](Value val) {
- return llvm::any_of(val.getUsers(), [&](Operation *user) {
- return ifOp.getThenRegion().isAncestor(user->getParentRegion());
- });
- };
-
- // Collect additional used values from before region.
- for (Operation *it = ifOp->getPrevNode(); it != nullptr;
- it = it->getPrevNode())
- llvm::copy_if(it->getResults(), std::back_inserter(additionalUsedValues),
- isValueUsedInsideIf);
-
- llvm::copy_if(op.getBeforeArguments(),
- std::back_inserter(additionalUsedValues),
- isValueUsedInsideIf);
-
- // Create new whileOp with additional used values as results.
- auto additionalValueTypes = llvm::map_to_vector(
- additionalUsedValues, [](Value val) { return val.getType(); });
- size_t additionalValueSize = additionalUsedValues.size();
- SmallVector<Type> newResultTypes(op.getResultTypes());
- newResultTypes.append(additionalValueTypes);
-
- auto newWhileOp =
- scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
-
- newWhileOp.getBefore().takeBody(op.getBefore());
- newWhileOp.getAfter().takeBody(op.getAfter());
- newWhileOp.getAfter().addArguments(
- additionalValueTypes, SmallVector<Location>(additionalValueSize, loc));
-
- SmallVector<Value> conditionArgs = conditionOp.getArgs();
- llvm::append_range(conditionArgs, additionalUsedValues);
-
- // Update conditionOp inside new whileOp before region.
- rewriter.setInsertionPoint(conditionOp);
- rewriter.replaceOpWithNewOp<scf::ConditionOp>(
- conditionOp, conditionOp.getCondition(), conditionArgs);
-
- // Replace uses of additional used values inside the ifOp then region with
- // the whileOp after region arguments.
- rewriter.replaceUsesWithIf(
- additionalUsedValues,
- newWhileOp.getAfterArguments().take_back(additionalValueSize),
- [&](OpOperand &use) {
- return ifOp.getThenRegion().isAncestor(
- use.getOwner()->getParentRegion());
- });
-
- // Inline ifOp then region into new whileOp after region.
- rewriter.eraseOp(ifOp.thenYield());
- rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(),
- newWhileOp.getAfterBody()->begin());
- rewriter.eraseOp(ifOp);
- rewriter.replaceOp(op,
- newWhileOp->getResults().drop_back(additionalValueSize));
- return success();
- }
-};
-
/// Replace uses of the condition within the do block with true, since otherwise
/// the block would not be evaluated.
///
@@ -4389,8 +4258,7 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<RemoveLoopInvariantArgsFromBeforeBlock,
RemoveLoopInvariantValueYielded, WhileConditionTruth,
WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
- WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
- context);
+ WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index ec1044aaa42ac..3021367e596f7 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -19,6 +19,162 @@
 using namespace mlir;
 namespace {
+/// Sink an scf.if that sits directly before the scf.condition of an
+/// scf.while "before" region, and whose condition is the loop condition, into
+/// the "after" region.
+///
+/// The "after" region runs exactly when the loop condition is true, which is
+/// exactly when the then-branch of such an scf.if runs, so the then-branch can
+/// be executed at the start of the "after" region instead. The else-branch may
+/// only yield values; those are forwarded through scf.condition, so the loop
+/// results are unchanged. Values from the "before" region that the then-branch
+/// uses are forwarded as additional trailing scf.condition operands.
+///
+/// scf.while (..) : (...) -> ... {
+///   %additional_used_values = ...
+///   %cond = ...
+///   ...
+///   %res = scf.if %cond -> (...) {
+///     use(%additional_used_values)
+///     ... // then block
+///     scf.yield %then_value
+///   } else {
+///     scf.yield %else_value
+///   }
+///   scf.condition(%cond) %res, ...
+/// } do {
+/// ^bb0(%res_arg, ...):
+///   use(%res_arg)
+///   ...
+///
+/// becomes
+///
+/// scf.while (..) : (...) -> ... {
+///   %additional_used_values = ...
+///   %cond = ...
+///   ...
+///   scf.condition(%cond) %else_value, ..., %additional_used_values
+/// } do {
+/// ^bb0(%res_arg, ..., %additional_args):
+///   use(%additional_args)
+///   ... // then block
+///   use(%then_value)
+///   ...
+struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
+  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(scf::WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    auto conditionOp =
+        cast<scf::ConditionOp>(op.getBeforeBody()->getTerminator());
+    auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
+
+    // Check that the ifOp is directly before the conditionOp and that it
+    // matches the condition of the conditionOp. Also ensure that the ifOp has
+    // no else block with content, as that would complicate the transformation.
+    // TODO: support else blocks with content.
+    if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
+        (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
+      return failure();
+
+    // The ifOp is the last op before the terminator, so only the conditionOp
+    // can use its results.
+    assert(ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
+                                 *ifOp->user_begin() == conditionOp) &&
+                                    "ifOp has unexpected uses");
+
+    Location loc = op.getLoc();
+
+    // "after" arguments that received an ifOp result now use the matching
+    // then-yielded value. Do this for all forwarded operands before touching
+    // the ifOp results: a result may be forwarded more than once, and
+    // `getArgs()` reads operands lazily, so replacing the results first would
+    // hide the later occurrences.
+    for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
+      auto result = dyn_cast<OpResult>(arg);
+      if (!result || result.getOwner() != ifOp.getOperation())
+        continue;
+      rewriter.replaceAllUsesWith(
+          op.getAfterArguments()[idx],
+          ifOp.thenYield()->getOperand(result.getResultNumber()));
+    }
+
+    // The "before" region now forwards the else-yielded values instead. An
+    // scf.if with results always has an else block.
+    if (ifOp->getNumResults() != 0)
+      rewriter.replaceAllUsesWith(ifOp->getResults(),
+                                  ifOp.elseYield()->getOperands());
+
+    SmallVector<Value> additionalUsedValues;
+    auto isValueUsedInsideIf = [&](Value val) {
+      return llvm::any_of(val.getUsers(), [&](Operation *user) {
+        return ifOp.getThenRegion().isAncestor(user->getParentRegion());
+      });
+    };
+
+    // Collect the "before"-region values used inside the then region: results
+    // of ops preceding the ifOp and "before" block arguments. This includes
+    // values the then region only yields, which the "after" region now uses.
+    for (Operation *it = ifOp->getPrevNode(); it != nullptr;
+         it = it->getPrevNode())
+      llvm::copy_if(it->getResults(), std::back_inserter(additionalUsedValues),
+                    isValueUsedInsideIf);
+
+    llvm::copy_if(op.getBeforeArguments(),
+                  std::back_inserter(additionalUsedValues),
+                  isValueUsedInsideIf);
+
+    // Create new whileOp with additional used values as results.
+    auto additionalValueTypes = llvm::map_to_vector(
+        additionalUsedValues, [](Value val) { return val.getType(); });
+    size_t additionalValueSize = additionalUsedValues.size();
+    SmallVector<Type> newResultTypes(op.getResultTypes());
+    newResultTypes.append(additionalValueTypes);
+
+    auto newWhileOp =
+        scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
+
+    // Move the regions through the rewriter so that listeners (e.g. the greedy
+    // pattern driver) are notified about the moved ops.
+    rewriter.inlineRegionBefore(op.getBefore(), newWhileOp.getBefore(),
+                                newWhileOp.getBefore().end());
+    rewriter.inlineRegionBefore(op.getAfter(), newWhileOp.getAfter(),
+                                newWhileOp.getAfter().end());
+    newWhileOp.getAfter().addArguments(
+        additionalValueTypes, SmallVector<Location>(additionalValueSize, loc));
+
+    SmallVector<Value> conditionArgs = conditionOp.getArgs();
+    llvm::append_range(conditionArgs, additionalUsedValues);
+
+    // Update conditionOp inside new whileOp before region.
+    rewriter.setInsertionPoint(conditionOp);
+    rewriter.replaceOpWithNewOp<scf::ConditionOp>(
+        conditionOp, conditionOp.getCondition(), conditionArgs);
+
+    // Replace uses of the additional values inside the then region and the
+    // "after" region with the new "after" arguments. The "after" region uses
+    // them when a then-yielded value is defined in the "before" region; those
+    // uses would otherwise violate dominance.
+    rewriter.replaceUsesWithIf(
+        additionalUsedValues,
+        newWhileOp.getAfterArguments().take_back(additionalValueSize),
+        [&](OpOperand &use) {
+          Region *region = use.getOwner()->getParentRegion();
+          return ifOp.getThenRegion().isAncestor(region) ||
+                 newWhileOp.getAfter().isAncestor(region);
+        });
+
+    // Inline ifOp then region into new whileOp after region.
+    rewriter.eraseOp(ifOp.thenYield());
+    rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(),
+                               newWhileOp.getAfterBody()->begin());
+    rewriter.eraseOp(ifOp);
+    rewriter.replaceOp(op,
+                       newWhileOp->getResults().drop_back(additionalValueSize));
+    return success();
+  }
+};
+
struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
using OpRewritePattern::OpRewritePattern;
@@ -267,5 +398,10 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
 }
 void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) {
- patterns.add<UpliftWhileOp>(patterns.getContext());
+  patterns.add<WhileMoveIfDown, UpliftWhileOp>(patterns.getContext());
+  // WhileMoveIfDown appends extra operands to scf.condition (often duplicates
+  // of values that are already forwarded); the scf.while canonicalizations are
+  // presumably added so these get cleaned up and UpliftWhileOp can match. Note
+  // that every user of this function now also gets all of them.
+  scf::WhileOp::getCanonicalizationPatterns(patterns, patterns.getContext());
 }
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index cfae3b34305de..2bec63672e783 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -974,43 +974,6 @@ func.func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
// -----
-// CHECK-LABEL: @while_move_if_down
-func.func @while_move_if_down() -> i32 {
- %0 = scf.while () : () -> (i32) {
- %additional_used_value = "test.get_some_value1" () : () -> (i32)
- %else_value = "test.get_some_value2" () : () -> (i32)
- %condition = "test.condition"() : () -> i1
- %res = scf.if %condition -> (i32) {
- "test.use1" (%additional_used_value) : (i32) -> ()
- %then_value = "test.get_some_value3" () : () -> (i32)
- scf.yield %then_value : i32
- } else {
- scf.yield %else_value : i32
- }
- scf.condition(%condition) %res : i32
- } do {
- ^bb0(%res_arg: i32):
- "test.use2" (%res_arg) : (i32) -> ()
- scf.yield
- }
- return %0 : i32
-}
-// CHECK-NEXT: %[[WHILE_0:.*]]:2 = scf.while : () -> (i32, i32) {
-// CHECK-NEXT: %[[VAL_0:.*]] = "test.get_some_value1"() : () -> i32
-// CHECK-NEXT: %[[VAL_1:.*]] = "test.get_some_value2"() : () -> i32
-// CHECK-NEXT: %[[VAL_2:.*]] = "test.condition"() : () -> i1
-// CHECK-NEXT: scf.condition(%[[VAL_2]]) %[[VAL_1]], %[[VAL_0]] : i32, i32
-// CHECK-NEXT: } do {
-// CHECK-NEXT: ^bb0(%[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32):
-// CHECK-NEXT: "test.use1"(%[[VAL_4]]) : (i32) -> ()
-// CHECK-NEXT: %[[VAL_5:.*]] = "test.get_some_value3"() : () -> i32
-// CHECK-NEXT: "test.use2"(%[[VAL_5]]) : (i32) -> ()
-// CHECK-NEXT: scf.yield
-// CHECK-NEXT: }
-// CHECK-NEXT: return %[[VAL_6:.*]]#0 : i32
-
-// -----
-
// CHECK-LABEL: @while_cond_true
func.func @while_cond_true() -> i1 {
%0 = scf.while () : () -> i1 {
diff --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir
index cbe2ce5076ad2..736112824c515 100644
--- a/mlir/test/Dialect/SCF/uplift-while.mlir
+++ b/mlir/test/Dialect/SCF/uplift-while.mlir
@@ -185,3 +185,34 @@ func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32)
// CHECK: %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
// CHECK: scf.yield %[[T1]], %[[T2]] : i32, f32
// CHECK: return %[[RES]]#0, %[[RES]]#1 : i32, f32
+
+// -----
+
+func.func @uplift_while(%low: index, %upper: index, %val : i32) -> i32 {
+ %c1 = arith.constant 1 : index
+ %1:2 = scf.while (%iv = %low, %iter = %val) : (index, i32) -> (index, i32) {
+ %2 = arith.cmpi slt, %iv, %upper : index
+ %3:2 = scf.if %2 -> (index, i32) {
+ %4 = "test.test"(%iter) : (i32) -> i32
+ %5 = arith.addi %iv, %c1 : index
+ scf.yield %5, %4 : index, i32
+ } else {
+ scf.yield %iv, %iter : index, i32
+ }
+ scf.condition(%2) %3#0, %3#1 : index, i32
+ } do {
+ ^bb0(%arg0: index, %arg1: i32):
+ scf.yield %arg0, %arg1 : index, i32
+ }
+ return %1#1 : i32
+}
+
+// CHECK-LABEL: func.func @uplift_while(
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: i32) -> i32 {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : index
+// CHECK: %[[FOR_0:.*]] = scf.for %[[VAL_0:.*]] = %[[ARG0]] to %[[ARG1]] step %[[CONSTANT_0]] iter_args(%[[VAL_1:.*]] = %[[ARG2]]) -> (i32) {
+// CHECK: %[[VAL_2:.*]] = "test.test"(%[[VAL_1]]) : (i32) -> i32
+// CHECK: scf.yield %[[VAL_2]] : i32
+// CHECK: }
+// CHECK: return %[[FOR_0]] : i32
+// CHECK: }
More information about the Mlir-commits
mailing list