[Mlir-commits] [mlir] 25d027b - [MLIR][SCF] Sink scf.if from scf.while before region into after region in scf-uplift-while-to-for (#165216)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 27 07:30:09 PST 2025
Author: Ming Yan
Date: 2025-11-27T23:30:05+08:00
New Revision: 25d027b8ab3acd65b58fce278f4173b431326934
URL: https://github.com/llvm/llvm-project/commit/25d027b8ab3acd65b58fce278f4173b431326934
DIFF: https://github.com/llvm/llvm-project/commit/25d027b8ab3acd65b58fce278f4173b431326934.diff
LOG: [MLIR][SCF] Sink scf.if from scf.while before region into after region in scf-uplift-while-to-for (#165216)
When an `scf.if` directly precedes the `scf.condition` in the before
region of an `scf.while` and both use the same condition value, move the
`scf.if` into the after region of the loop. This simplifies the control
flow so that the `scf.while` can be uplifted to an `scf.for`.
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
mlir/test/Dialect/SCF/uplift-while.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index ec1044aaa42ac..9f242f9e62b8e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -19,6 +19,83 @@
using namespace mlir;
namespace {
+/// Move an scf.if op that is directly before the scf.condition op in the while
+/// before region, and whose condition matches the condition of the
+/// scf.condition op, down into the while after region.
+///
+/// scf.while (%init) : (...) -> ... {
+///   %cond = ...
+///   %res = scf.if %cond -> (...) {
+///     use1(%init)
+///     %then_val = ...
+///     ... // then block
+///     scf.yield %then_val
+///   } else {
+///     scf.yield %init
+///   }
+///   scf.condition(%cond) %res
+/// } do {
+/// ^bb0(%arg):
+///   use2(%arg)
+///   ...
+///
+/// becomes
+///
+/// scf.while (%init) : (...) -> ... {
+///   %cond = ...
+///   scf.condition(%cond) %init
+/// } do {
+/// ^bb0(%arg):
+///   use1(%arg)
+///   ... // if then block
+///   %then_val = ...
+///   use2(%then_val)
+///   ...
+///
+/// The after region runs exactly when %cond is true, i.e. exactly when the
+/// then branch would have run. When %cond is false the loop exits with the
+/// scf.condition operands, which the else branch must leave equal to the
+/// before-region arguments. The pattern therefore only fires when:
+///  - the before region is exactly `%cond = ...`, the scf.if and the
+///    scf.condition, and %cond has no other uses;
+///  - the scf.if results are forwarded, in order, as the scf.condition
+///    operands;
+///  - the else branch (if present) contains nothing but a yield of the
+///    before-region arguments, in order. Any other op there would run on loop
+///    exit and would be lost when the scf.if is erased.
+struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
+  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(scf::WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    // Check that the first operation produces one result and that result has
+    // exactly two uses (these two uses come from the `scf.if` and
+    // `scf.condition` operations).
+    Operation &condOp = op.getBeforeBody()->front();
+    if (condOp.getNumResults() != 1 || !condOp.getResult(0).hasNUses(2))
+      return failure();
+
+    Value condVal = condOp.getResult(0);
+    auto ifOp = dyn_cast<scf::IfOp>(condOp.getNextNode());
+    if (!ifOp || ifOp.getCondition() != condVal)
+      return failure();
+
+    auto term = dyn_cast<scf::ConditionOp>(ifOp->getNextNode());
+    if (!term || term.getCondition() != condVal)
+      return failure();
+
+    // The else branch runs on loop exit, where the rewritten loop forwards the
+    // before-region arguments directly, so it may do nothing but yield them.
+    // A result-less scf.if may omit the else region entirely, which behaves
+    // like an empty yield (elseYield() must not be called in that case). Any
+    // op besides the yield would be silently dropped by the rewrite.
+    ValueRange elseVals;
+    if (Block *elseBlock = ifOp.elseBlock()) {
+      if (!llvm::hasSingleElement(elseBlock->getOperations()))
+        return failure();
+      elseVals = ifOp.elseYield()->getOperands();
+    }
+
+    // Check that if results and else yield operands match the scf.condition op
+    // arguments and while before arguments respectively.
+    if (!llvm::equal(ifOp->getResults(), term.getArgs()) ||
+        !llvm::equal(elseVals, op.getBeforeArguments()))
+      return failure();
+
+    Block *afterBody = op.getAfterBody();
+    Region &thenRegion = ifOp.getThenRegion();
+    auto thenYield = ifOp.thenYield();
+
+    // Record the current uses of the after-region arguments before touching
+    // the IR. The then branch is about to gain its own uses of these
+    // arguments, which must not be rewritten. A plain replaceAllUsesWith with
+    // the then-yield operands would also leave before-region arguments (not
+    // visible in the after region) behind whenever the then branch forwards
+    // one of them unchanged.
+    SmallVector<std::pair<OpOperand *, unsigned>> afterArgUses;
+    for (BlockArgument arg : afterBody->getArguments())
+      for (OpOperand &use : arg.getUses())
+        afterArgUses.emplace_back(&use, arg.getArgNumber());
+
+    // Once inlined at the top of the after region, the then branch sees the
+    // loop-carried values through the after-region arguments.
+    rewriter.replaceUsesWithIf(op.getBeforeArguments(), op.getAfterArguments(),
+                               [&](OpOperand &use) {
+                                 return thenRegion.isAncestor(
+                                     use.getOwner()->getParentRegion());
+                               });
+
+    // The recorded uses now take the values yielded by the then branch. After
+    // the update above, these values are all valid in the after region.
+    for (const auto &entry : afterArgUses) {
+      OpOperand *use = entry.first;
+      Value newVal = thenYield->getOperand(entry.second);
+      rewriter.modifyOpInPlace(use->getOwner(), [&] { use->set(newVal); });
+    }
+
+    // On exit the loop now forwards the before-region arguments unchanged,
+    // which is exactly what the else branch used to yield.
+    rewriter.modifyOpInPlace(
+        term, [&]() { term.getArgsMutable().assign(op.getBeforeArguments()); });
+
+    rewriter.eraseOp(thenYield);
+    rewriter.inlineBlockBefore(ifOp.thenBlock(), afterBody, afterBody->begin());
+    rewriter.eraseOp(ifOp);
+    return success();
+  }
+};
+
struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
using OpRewritePattern::OpRewritePattern;
@@ -267,5 +344,5 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
}
void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) {
- patterns.add<UpliftWhileOp>(patterns.getContext());
+ // WhileMoveIfDown sinks an scf.if guarded by the loop condition out of the
+ // scf.while before region, so that more loops can then be uplifted to
+ // scf.for by UpliftWhileOp.
+ patterns.add<UpliftWhileOp, WhileMoveIfDown>(patterns.getContext());
}
diff --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir
index cbe2ce5076ad2..736112824c515 100644
--- a/mlir/test/Dialect/SCF/uplift-while.mlir
+++ b/mlir/test/Dialect/SCF/uplift-while.mlir
@@ -185,3 +185,34 @@ func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32)
// CHECK: %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
// CHECK: scf.yield %[[T1]], %[[T2]] : i32, f32
// CHECK: return %[[RES]]#0, %[[RES]]#1 : i32, f32
+
+// -----
+
+// An scf.if that is guarded by the loop condition and directly precedes
+// scf.condition is sunk into the after region; the resulting loop is then
+// uplifted to scf.for, with only the then-branch body left inside it.
+func.func @uplift_while(%low: index, %upper: index, %val : i32) -> i32 {
+ %c1 = arith.constant 1 : index
+ %1:2 = scf.while (%iv = %low, %iter = %val) : (index, i32) -> (index, i32) {
+ %2 = arith.cmpi slt, %iv, %upper : index
+ %3:2 = scf.if %2 -> (index, i32) {
+ %4 = "test.test"(%iter) : (i32) -> i32
+ %5 = arith.addi %iv, %c1 : index
+ scf.yield %5, %4 : index, i32
+ } else {
+ scf.yield %iv, %iter : index, i32
+ }
+ scf.condition(%2) %3#0, %3#1 : index, i32
+ } do {
+ ^bb0(%arg0: index, %arg1: i32):
+ scf.yield %arg0, %arg1 : index, i32
+ }
+ return %1#1 : i32
+}
+
+// CHECK-LABEL: func.func @uplift_while(
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: i32) -> i32 {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : index
+// CHECK: %[[FOR_0:.*]] = scf.for %[[VAL_0:.*]] = %[[ARG0]] to %[[ARG1]] step %[[CONSTANT_0]] iter_args(%[[VAL_1:.*]] = %[[ARG2]]) -> (i32) {
+// CHECK: %[[VAL_2:.*]] = "test.test"(%[[VAL_1]]) : (i32) -> i32
+// CHECK: scf.yield %[[VAL_2]] : i32
+// CHECK: }
+// CHECK: return %[[FOR_0]] : i32
+// CHECK: }
More information about the Mlir-commits
mailing list