[Mlir-commits] [mlir] 8ceeba8 - [MLIR][SCF] Canonicalize redundant scf.if from scf.while before region into after region (#169892)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 1 02:54:25 PST 2025


Author: Ming Yan
Date: 2025-12-01T18:54:21+08:00
New Revision: 8ceeba83812d551423a9e50f600cc77ea4718ca2

URL: https://github.com/llvm/llvm-project/commit/8ceeba83812d551423a9e50f600cc77ea4718ca2
DIFF: https://github.com/llvm/llvm-project/commit/8ceeba83812d551423a9e50f600cc77ea4718ca2.diff

LOG: [MLIR][SCF] Canonicalize redundant scf.if from scf.while before region into after region (#169892)

When an `scf.if` directly precedes the `scf.condition` in the before region
of an `scf.while` and both use the same condition value, move the body of the
if's then branch into the after region of the loop (currently only when the
else branch contains nothing but its `scf.yield`). This simplifies the control
flow and helps enable uplifting `scf.while` to `scf.for`.

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 881e256a8797b..bb07291036667 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -26,6 +26,7 @@
 #include "mlir/Interfaces/ParallelCombiningOpInterface.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallPtrSet.h"
@@ -3687,6 +3688,133 @@ LogicalResult scf::WhileOp::verify() {
 }
 
 namespace {
+/// Move a scf.if op that is directly before the scf.condition op in the while
+/// before region, and whose condition matches the condition of the
+/// scf.condition op, down into the while after region.
+///
+/// scf.while (..) : (...) -> ... {
+///  %additional_used_values = ...
+///  %cond = ...
+///  ...
+///  %res = scf.if %cond -> (...) {
+///    use(%additional_used_values)
+///    ... // then block
+///    scf.yield %then_value
+///  } else {
+///    scf.yield %else_value
+///  }
+///  scf.condition(%cond) %res, ...
+/// } do {
+/// ^bb0(%res_arg, ...):
+///    use(%res_arg)
+///    ...
+///
+/// becomes
+/// scf.while (..) : (...) -> ... {
+///  %additional_used_values = ...
+///  %cond = ...
+///  ...
+///  scf.condition(%cond) %else_value, ..., %additional_used_values
+/// } do {
+/// ^bb0(%res_arg, ..., %additional_args):
+///    use(%additional_args)
+///    ... // if then block
+///    use(%then_value)
+///    ...
+///
+/// Values defined in the before region that the moved then block refers to
+/// (including a yielded %then_value that is defined outside of the scf.if) are
+/// forwarded to the after region as extra scf.condition operands; the rewritten
+/// loop therefore has additional trailing results, which are dropped when the
+/// original op is replaced.
+struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
+  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(scf::WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    auto conditionOp = op.getConditionOp();
+
+    // Only support ifOp right before the condition at the moment. Relaxing this
+    // would require to:
+    // - check that the body does not have side-effects conflicting with
+    //    operations between the if and the condition.
+    // - check that results of the if operation are only used as arguments to
+    //    the condition.
+    auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
+
+    // Check that the ifOp is directly before the conditionOp and that it
+    // matches the condition of the conditionOp. Also ensure that the ifOp has
+    // no else block with content, as that would complicate the transformation.
+    // TODO: support else blocks with content.
+    if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
+        (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
+      return failure();
+
+    // The whole condition is parenthesized so that the message string is
+    // attached to the complete check rather than to one disjunct only.
+    assert((ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
+                                  *ifOp->user_begin() == conditionOp)) &&
+           "ifOp has unexpected uses");
+
+    Location loc = op.getLoc();
+
+    // Record which scf.condition operands forward an ifOp result before any
+    // use is rewritten. Replacing the uses of a result also rewrites every
+    // other scf.condition operand that refers to it, so matching operands
+    // while rewriting would miss a result that is forwarded more than once.
+    SmallVector<std::pair<BlockArgument, OpResult>> forwardedResults;
+    for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
+      auto ifResult = dyn_cast<OpResult>(arg);
+      if (ifResult && ifResult.getOwner() == ifOp.getOperation())
+        forwardedResults.emplace_back(op.getAfterArguments()[idx], ifResult);
+    }
+
+    // In the before region the forwarded value becomes the else value (the
+    // only one observable when the loop exits); in the after region, which is
+    // only entered when the condition holds, it becomes the then value.
+    for (auto [afterArg, ifResult] : forwardedResults) {
+      unsigned ifOpIdx = ifResult.getResultNumber();
+      Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
+      Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
+
+      rewriter.replaceAllUsesWith(ifResult, elseValue);
+      rewriter.replaceAllUsesWith(afterArg, thenValue);
+    }
+
+    // Collect additional used values from before region. The then yield is
+    // still in place here, so yielded values that are defined in the before
+    // region (outside of the ifOp) are collected as well.
+    SetVector<Value> additionalUsedValuesSet;
+    visitUsedValuesDefinedAbove(ifOp.getThenRegion(), [&](OpOperand *operand) {
+      if (&op.getBefore() == operand->get().getParentRegion())
+        additionalUsedValuesSet.insert(operand->get());
+    });
+
+    // Create new whileOp with additional used values as results.
+    auto additionalUsedValues = additionalUsedValuesSet.getArrayRef();
+    auto additionalValueTypes = llvm::map_to_vector(
+        additionalUsedValues, [](Value val) { return val.getType(); });
+    size_t additionalValueSize = additionalUsedValues.size();
+    SmallVector<Type> newResultTypes(op.getResultTypes());
+    newResultTypes.append(additionalValueTypes);
+
+    auto newWhileOp =
+        scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
+
+    rewriter.modifyOpInPlace(newWhileOp, [&] {
+      newWhileOp.getBefore().takeBody(op.getBefore());
+      newWhileOp.getAfter().takeBody(op.getAfter());
+      newWhileOp.getAfter().addArguments(
+          additionalValueTypes,
+          SmallVector<Location>(additionalValueSize, loc));
+    });
+
+    rewriter.modifyOpInPlace(conditionOp, [&] {
+      conditionOp.getArgsMutable().append(additionalUsedValues);
+    });
+
+    // Replace uses of additional used values with the matching whileOp after
+    // region arguments. This covers the uses inside the ifOp then region and
+    // the uses already living in the after region: the latter can only stem
+    // from the then-value replacement above and would otherwise refer to a
+    // value of the before region, which is not visible from the after region.
+    rewriter.replaceUsesWithIf(
+        additionalUsedValues,
+        newWhileOp.getAfterArguments().take_back(additionalValueSize),
+        [&](OpOperand &use) {
+          Region *useRegion = use.getOwner()->getParentRegion();
+          return ifOp.getThenRegion().isAncestor(useRegion) ||
+                 newWhileOp.getAfter().isAncestor(useRegion);
+        });
+
+    // Inline ifOp then region into new whileOp after region.
+    rewriter.eraseOp(ifOp.thenYield());
+    rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(),
+                               newWhileOp.getAfterBody()->begin());
+    rewriter.eraseOp(ifOp);
+    rewriter.replaceOp(op,
+                       newWhileOp->getResults().drop_back(additionalValueSize));
+    return success();
+  }
+};
+
 /// Replace uses of the condition within the do block with true, since otherwise
 /// the block would not be evaluated.
 ///
@@ -4399,7 +4527,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<RemoveLoopInvariantArgsFromBeforeBlock,
               RemoveLoopInvariantValueYielded, WhileConditionTruth,
               WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
-              WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
+              WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 084c3fc065de3..ac590fc0c47b9 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -974,6 +974,56 @@ func.func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
 
 // -----
 
+// Exercises the WhileMoveIfDown canonicalization: the scf.if below directly
+// precedes the scf.condition, uses the same condition value, and its else
+// branch only yields. The then body is expected to end up at the start of the
+// do region.
+// CHECK-LABEL: @while_move_if_down
+func.func @while_move_if_down() -> i32 {
+  // Defined outside of the loop: expected to stay directly usable after the
+  // move, without being forwarded through scf.condition.
+  %defined_outside = "test.get_some_value0" () : () -> (i32)
+  %0 = scf.while () : () -> (i32) {
+    // Before-region values used by the then body, directly and from a nested
+    // region respectively; both are expected to be forwarded to the do region.
+    %used_value = "test.get_some_value1" () : () -> (i32)
+    %used_by_subregion = "test.get_some_value2" () : () -> (i32)
+    %else_value = "test.get_some_value3" () : () -> (i32)
+    %condition = "test.condition"() : () -> i1
+    %res = scf.if %condition -> (i32) {
+      "test.use0" (%defined_outside) : (i32) -> ()
+      "test.use1" (%used_value) : (i32) -> ()
+      test.alloca_scope_region {
+        "test.use2" (%used_by_subregion) : (i32) -> ()
+      }
+      %then_value = "test.get_some_value4" () : () -> (i32)
+      scf.yield %then_value : i32
+    } else {
+      scf.yield %else_value : i32
+    }
+    scf.condition(%condition) %res : i32
+  } do {
+  ^bb0(%res_arg: i32):
+    "test.use3" (%res_arg) : (i32) -> ()
+    scf.yield
+  }
+  return %0 : i32
+}
+// Expected output: the loop gains two extra results for the forwarded values,
+// scf.condition passes the else value in place of the scf.if result, and the
+// former user of the do-region argument uses the then value directly.
+// CHECK:           %[[defined_outside:.*]] = "test.get_some_value0"() : () -> i32
+// CHECK:           %[[WHILE_RES:.*]]:3 = scf.while : () -> (i32, i32, i32) {
+// CHECK:             %[[used_value:.*]] = "test.get_some_value1"() : () -> i32
+// CHECK:             %[[used_by_subregion:.*]] = "test.get_some_value2"() : () -> i32
+// CHECK:             %[[else_value:.*]] = "test.get_some_value3"() : () -> i32
+// CHECK:             %[[condition:.*]] = "test.condition"() : () -> i1
+// CHECK:             scf.condition(%[[condition]]) %[[else_value]], %[[used_value]], %[[used_by_subregion]] : i32, i32, i32
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[res_arg:.*]]: i32, %[[used_value_arg:.*]]: i32, %[[used_by_subregion_arg:.*]]: i32):
+// CHECK:             "test.use0"(%[[defined_outside]]) : (i32) -> ()
+// CHECK:             "test.use1"(%[[used_value_arg]]) : (i32) -> ()
+// CHECK:             test.alloca_scope_region {
+// CHECK:               "test.use2"(%[[used_by_subregion_arg]]) : (i32) -> ()
+// CHECK:             }
+// CHECK:             %[[then_value:.*]] = "test.get_some_value4"() : () -> i32
+// CHECK:             "test.use3"(%[[then_value]]) : (i32) -> ()
+// CHECK:             scf.yield
+// CHECK:           }
+// CHECK:           return %[[WHILE_RES]]#0 : i32
+// CHECK:         }
+
+// -----
+
 // CHECK-LABEL: @while_cond_true
 func.func @while_cond_true() -> i1 {
   %0 = scf.while () : () -> i1 {


        


More information about the Mlir-commits mailing list