[Mlir-commits] [mlir] [MLIR][SCF] Canonicalize redundant scf.if from scf.while before region into after region (PR #169892)
Ming Yan
llvmlistbot at llvm.org
Fri Nov 28 06:34:35 PST 2025
https://github.com/NexMing updated https://github.com/llvm/llvm-project/pull/169892
>From 31f82075c3fd174d19c6b80b8aed721f0953369b Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Mon, 27 Oct 2025 16:35:03 +0800
Subject: [PATCH 1/2] [MLIR][SCF] Sink scf.if from scf.while before region into
after region.
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 125 +++++++++++++++++++++++-
mlir/test/Dialect/SCF/canonicalize.mlir | 37 +++++++
2 files changed, 161 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 881e256a8797b..79dcf562db993 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -26,6 +26,7 @@
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
+#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
@@ -3687,6 +3688,127 @@ LogicalResult scf::WhileOp::verify() {
}
namespace {
+/// Move a scf.if op that is directly before the scf.condition op in the while
+/// before region, and whose condition matches the condition of the
+/// scf.condition op, down into the while after region.
+///
+/// scf.while (..) : (...) -> ... {
+/// %additional_used_values = ...
+/// %cond = ...
+/// ...
+/// %res = scf.if %cond -> (...) {
+/// use(%additional_used_values)
+/// ... // then block
+/// scf.yield %then_value
+/// } else {
+/// scf.yield %else_value
+/// }
+/// scf.condition(%cond) %res, ...
+/// } do {
+/// ^bb0(%res_arg, ...):
+/// use(%res_arg)
+/// ...
+///
+/// becomes
+/// scf.while (..) : (...) -> ... {
+/// %additional_used_values = ...
+/// %cond = ...
+/// ...
+/// scf.condition(%cond) %else_value, ..., %additional_used_values
+/// } do {
+/// ^bb0(%res_arg, ..., %additional_args):
+/// use(%additional_args)
+/// ... // if then block
+/// use(%then_value)
+/// ...
+struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
+ using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(scf::WhileOp op,
+ PatternRewriter &rewriter) const override {
+ auto conditionOp =
+ cast<scf::ConditionOp>(op.getBeforeBody()->getTerminator());
+ auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
+
+ // Check that the ifOp is directly before the conditionOp and that it
+ // matches the condition of the conditionOp. Also ensure that the ifOp has
+ // no else block with content, as that would complicate the transformation.
+ // TODO: support else blocks with content.
+ if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
+ (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
+ return failure();
+
+ assert(ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
+ *ifOp->user_begin() == conditionOp) &&
+ "ifOp has unexpected uses");
+
+ Location loc = op.getLoc();
+
+ // Replace uses of ifOp results in the conditionOp with the yielded values
+ // from the ifOp branches.
+ for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
+ auto it = llvm::find(ifOp->getResults(), arg);
+ if (it != ifOp->getResults().end()) {
+ size_t ifOpIdx = it.getIndex();
+ Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
+ Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
+
+ rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue);
+ rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue);
+ }
+ }
+
+ // Collect additional used values from before region.
+ SetVector<Value> additionalUsedValues;
+ visitUsedValuesDefinedAbove(ifOp.getThenRegion(), [&](OpOperand *operand) {
+ if (op.getBefore().isAncestor(operand->get().getParentRegion()))
+ additionalUsedValues.insert(operand->get());
+ });
+
+ // Create new whileOp with additional used values as results.
+ auto additionalValueTypes = llvm::map_to_vector(
+ additionalUsedValues, [](Value val) { return val.getType(); });
+ size_t additionalValueSize = additionalUsedValues.size();
+ SmallVector<Type> newResultTypes(op.getResultTypes());
+ newResultTypes.append(additionalValueTypes);
+
+ auto newWhileOp =
+ scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
+
+ newWhileOp.getBefore().takeBody(op.getBefore());
+ newWhileOp.getAfter().takeBody(op.getAfter());
+ newWhileOp.getAfter().addArguments(
+ additionalValueTypes, SmallVector<Location>(additionalValueSize, loc));
+
+ SmallVector<Value> conditionArgs = conditionOp.getArgs();
+ llvm::append_range(conditionArgs, additionalUsedValues);
+
+ // Update conditionOp inside new whileOp before region.
+ rewriter.setInsertionPoint(conditionOp);
+ rewriter.replaceOpWithNewOp<scf::ConditionOp>(
+ conditionOp, conditionOp.getCondition(), conditionArgs);
+
+ // Replace uses of additional used values inside the ifOp then region with
+ // the whileOp after region arguments.
+ rewriter.replaceUsesWithIf(
+ additionalUsedValues.takeVector(),
+ newWhileOp.getAfterArguments().take_back(additionalValueSize),
+ [&](OpOperand &use) {
+ return ifOp.getThenRegion().isAncestor(
+ use.getOwner()->getParentRegion());
+ });
+
+ // Inline ifOp then region into new whileOp after region.
+ rewriter.eraseOp(ifOp.thenYield());
+ rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(),
+ newWhileOp.getAfterBody()->begin());
+ rewriter.eraseOp(ifOp);
+ rewriter.replaceOp(op,
+ newWhileOp->getResults().drop_back(additionalValueSize));
+ return success();
+ }
+};
+
/// Replace uses of the condition within the do block with true, since otherwise
/// the block would not be evaluated.
///
@@ -4399,7 +4521,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<RemoveLoopInvariantArgsFromBeforeBlock,
RemoveLoopInvariantValueYielded, WhileConditionTruth,
WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
- WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
+ WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 084c3fc065de3..b02cbc07880b9 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -974,6 +974,43 @@ func.func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
// -----
+// CHECK-LABEL: @while_move_if_down
+func.func @while_move_if_down() -> i32 {
+ %0 = scf.while () : () -> (i32) {
+ %additional_used_value = "test.get_some_value1" () : () -> (i32)
+ %else_value = "test.get_some_value2" () : () -> (i32)
+ %condition = "test.condition"() : () -> i1
+ %res = scf.if %condition -> (i32) {
+ "test.use1" (%additional_used_value) : (i32) -> ()
+ %then_value = "test.get_some_value3" () : () -> (i32)
+ scf.yield %then_value : i32
+ } else {
+ scf.yield %else_value : i32
+ }
+ scf.condition(%condition) %res : i32
+ } do {
+ ^bb0(%res_arg: i32):
+ "test.use2" (%res_arg) : (i32) -> ()
+ scf.yield
+ }
+ return %0 : i32
+}
+// CHECK-NEXT: %[[WHILE_0:.*]]:2 = scf.while : () -> (i32, i32) {
+// CHECK-NEXT: %[[VAL_0:.*]] = "test.get_some_value1"() : () -> i32
+// CHECK-NEXT: %[[VAL_1:.*]] = "test.get_some_value2"() : () -> i32
+// CHECK-NEXT: %[[VAL_2:.*]] = "test.condition"() : () -> i1
+// CHECK-NEXT: scf.condition(%[[VAL_2]]) %[[VAL_1]], %[[VAL_0]] : i32, i32
+// CHECK-NEXT: } do {
+// CHECK-NEXT: ^bb0(%[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32):
+// CHECK-NEXT: "test.use1"(%[[VAL_4]]) : (i32) -> ()
+// CHECK-NEXT: %[[VAL_5:.*]] = "test.get_some_value3"() : () -> i32
+// CHECK-NEXT: "test.use2"(%[[VAL_5]]) : (i32) -> ()
+// CHECK-NEXT: scf.yield
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[VAL_6:.*]]#0 : i32
+
+// -----
+
// CHECK-LABEL: @while_cond_true
func.func @while_cond_true() -> i1 {
%0 = scf.while () : () -> i1 {
>From 2c61dcd03ca4cd4e1d8f6c03481ba94774564820 Mon Sep 17 00:00:00 2001
From: Ming Yan <nexming7 at gmail.com>
Date: Fri, 28 Nov 2025 22:22:59 +0800
Subject: [PATCH 2/2] Simplify the code and update the tests.
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 40 ++++++++++++++-----------
mlir/test/Dialect/SCF/canonicalize.mlir | 34 +++++++++++----------
2 files changed, 42 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 79dcf562db993..bb07291036667 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3726,8 +3726,14 @@ struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
LogicalResult matchAndRewrite(scf::WhileOp op,
PatternRewriter &rewriter) const override {
- auto conditionOp =
- cast<scf::ConditionOp>(op.getBeforeBody()->getTerminator());
+ auto conditionOp = op.getConditionOp();
+
+ // Only support ifOp right before the condition at the moment. Relaxing this
+ // would require:
+ // - check that the body does not have side-effects conflicting with
+ // operations between the if and the condition.
+ // - check that results of the if operation are only used as arguments to
+ // the condition.
auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
// Check that the ifOp is directly before the conditionOp and that it
@@ -3759,13 +3765,14 @@ struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
}
// Collect additional used values from before region.
- SetVector<Value> additionalUsedValues;
+ SetVector<Value> additionalUsedValuesSet;
visitUsedValuesDefinedAbove(ifOp.getThenRegion(), [&](OpOperand *operand) {
- if (op.getBefore().isAncestor(operand->get().getParentRegion()))
- additionalUsedValues.insert(operand->get());
+ if (&op.getBefore() == operand->get().getParentRegion())
+ additionalUsedValuesSet.insert(operand->get());
});
// Create new whileOp with additional used values as results.
+ auto additionalUsedValues = additionalUsedValuesSet.getArrayRef();
auto additionalValueTypes = llvm::map_to_vector(
additionalUsedValues, [](Value val) { return val.getType(); });
size_t additionalValueSize = additionalUsedValues.size();
@@ -3775,23 +3782,22 @@ struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
auto newWhileOp =
scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
- newWhileOp.getBefore().takeBody(op.getBefore());
- newWhileOp.getAfter().takeBody(op.getAfter());
- newWhileOp.getAfter().addArguments(
- additionalValueTypes, SmallVector<Location>(additionalValueSize, loc));
-
- SmallVector<Value> conditionArgs = conditionOp.getArgs();
- llvm::append_range(conditionArgs, additionalUsedValues);
+ rewriter.modifyOpInPlace(newWhileOp, [&] {
+ newWhileOp.getBefore().takeBody(op.getBefore());
+ newWhileOp.getAfter().takeBody(op.getAfter());
+ newWhileOp.getAfter().addArguments(
+ additionalValueTypes,
+ SmallVector<Location>(additionalValueSize, loc));
+ });
- // Update conditionOp inside new whileOp before region.
- rewriter.setInsertionPoint(conditionOp);
- rewriter.replaceOpWithNewOp<scf::ConditionOp>(
- conditionOp, conditionOp.getCondition(), conditionArgs);
+ rewriter.modifyOpInPlace(conditionOp, [&] {
+ conditionOp.getArgsMutable().append(additionalUsedValues);
+ });
// Replace uses of additional used values inside the ifOp then region with
// the whileOp after region arguments.
rewriter.replaceUsesWithIf(
- additionalUsedValues.takeVector(),
+ additionalUsedValues,
newWhileOp.getAfterArguments().take_back(additionalValueSize),
[&](OpOperand &use) {
return ifOp.getThenRegion().isAncestor(
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index b02cbc07880b9..3b9e219403986 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -976,12 +976,14 @@ func.func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
// CHECK-LABEL: @while_move_if_down
func.func @while_move_if_down() -> i32 {
+ %defined_outside = "test.get_some_value" () : () -> (i32)
%0 = scf.while () : () -> (i32) {
%additional_used_value = "test.get_some_value1" () : () -> (i32)
%else_value = "test.get_some_value2" () : () -> (i32)
%condition = "test.condition"() : () -> i1
%res = scf.if %condition -> (i32) {
- "test.use1" (%additional_used_value) : (i32) -> ()
+ "test.use1" (%defined_outside) : (i32) -> ()
+ "test.use2" (%additional_used_value) : (i32) -> ()
%then_value = "test.get_some_value3" () : () -> (i32)
scf.yield %then_value : i32
} else {
@@ -990,24 +992,26 @@ func.func @while_move_if_down() -> i32 {
scf.condition(%condition) %res : i32
} do {
^bb0(%res_arg: i32):
- "test.use2" (%res_arg) : (i32) -> ()
+ "test.use3" (%res_arg) : (i32) -> ()
scf.yield
}
return %0 : i32
}
-// CHECK-NEXT: %[[WHILE_0:.*]]:2 = scf.while : () -> (i32, i32) {
-// CHECK-NEXT: %[[VAL_0:.*]] = "test.get_some_value1"() : () -> i32
-// CHECK-NEXT: %[[VAL_1:.*]] = "test.get_some_value2"() : () -> i32
-// CHECK-NEXT: %[[VAL_2:.*]] = "test.condition"() : () -> i1
-// CHECK-NEXT: scf.condition(%[[VAL_2]]) %[[VAL_1]], %[[VAL_0]] : i32, i32
-// CHECK-NEXT: } do {
-// CHECK-NEXT: ^bb0(%[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32):
-// CHECK-NEXT: "test.use1"(%[[VAL_4]]) : (i32) -> ()
-// CHECK-NEXT: %[[VAL_5:.*]] = "test.get_some_value3"() : () -> i32
-// CHECK-NEXT: "test.use2"(%[[VAL_5]]) : (i32) -> ()
-// CHECK-NEXT: scf.yield
-// CHECK-NEXT: }
-// CHECK-NEXT: return %[[VAL_6:.*]]#0 : i32
+// CHECK-NEXT: %[[defined_outside:.*]] = "test.get_some_value"() : () -> i32
+// CHECK-NEXT: %[[while_res:.*]]:2 = scf.while : () -> (i32, i32) {
+// CHECK-NEXT: %[[additional_used_value:.*]] = "test.get_some_value1"() : () -> i32
+// CHECK-NEXT: %[[else_value:.*]] = "test.get_some_value2"() : () -> i32
+// CHECK-NEXT: %[[condition:.*]] = "test.condition"() : () -> i1
+// CHECK-NEXT: scf.condition(%[[condition]]) %[[else_value]], %[[additional_used_value]] : i32, i32
+// CHECK-NEXT: } do {
+// CHECK-NEXT: ^bb0(%[[res_arg:.*]]: i32, %[[additional_used_value_arg:.*]]: i32):
+// CHECK-NEXT: "test.use1"(%[[defined_outside]]) : (i32) -> ()
+// CHECK-NEXT: "test.use2"(%[[additional_used_value_arg]]) : (i32) -> ()
+// CHECK-NEXT: %[[then_value:.*]] = "test.get_some_value3"() : () -> i32
+// CHECK-NEXT: "test.use3"(%[[then_value]]) : (i32) -> ()
+// CHECK-NEXT: scf.yield
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[while_res:.*]]#0 : i32
// -----
More information about the Mlir-commits
mailing list