[Mlir-commits] [mlir] [mlir] move if-condition propagation to a standalone pass (PR #150278)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 23 10:34:19 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Oleksandr "Alex" Zinenko (ftynse)

<details>
<summary>Changes</summary>

This offers a significant speedup over running this as a canonicalization pattern, up to 10x improvement when running on large (>100k operations) inputs coming from Polygeist.

It is also not clear whether this transformation is a reasonable canonicalization as it performs non-local rewrites.

---
Full diff: https://github.com/llvm/llvm-project/pull/150278.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/SCF/Transforms/Passes.td (+6) 
- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+2-62) 
- (modified) mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp (+96) 
- (modified) mlir/test/Dialect/SCF/canonicalize.mlir (-35) 
- (added) mlir/test/Dialect/SCF/if-cond-prop.mlir (+34) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 6e5ef96c450aa..ca2510bb53af9 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -41,6 +41,18 @@ def SCFForLoopSpecialization : Pass<"scf-for-loop-specialization"> {
   let constructor = "mlir::createForLoopSpecializationPass()";
 }
 
+def SCFIfConditionPropagation : Pass<"scf-if-condition-propagation"> {
+  let summary = "Replace usages of if condition with true/false constants in "
+                "the conditional regions";
+  let description = [{
+    Replaces the uses of the condition of an `scf.if` that are nested in its
+    `then` region with a `true` constant, and the uses nested in its `else`
+    region with a `false` constant. This is a standalone pass rather than a
+    canonicalization pattern because the rewrite is non-local.
+  }];
+  let dependentDialects = ["arith::ArithDialect"];
+}
+
 def SCFParallelLoopFusion : Pass<"scf-parallel-loop-fusion"> {
   let summary = "Fuse adjacent parallel loops";
   let constructor = "mlir::createParallelLoopFusionPass()";
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 00c31a1500e17..6cb61900928d6 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -2412,65 +2412,6 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
   }
 };
 
-/// Allow the true region of an if to assume the condition is true
-/// and vice versa. For example:
-///
-///   scf.if %cmp {
-///      print(%cmp)
-///   }
-///
-///  becomes
-///
-///   scf.if %cmp {
-///      print(true)
-///   }
-///
-struct ConditionPropagation : public OpRewritePattern<IfOp> {
-  using OpRewritePattern<IfOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(IfOp op,
-                                PatternRewriter &rewriter) const override {
-    // Early exit if the condition is constant since replacing a constant
-    // in the body with another constant isn't a simplification.
-    if (matchPattern(op.getCondition(), m_Constant()))
-      return failure();
-
-    bool changed = false;
-    mlir::Type i1Ty = rewriter.getI1Type();
-
-    // These variables serve to prevent creating duplicate constants
-    // and hold constant true or false values.
-    Value constantTrue = nullptr;
-    Value constantFalse = nullptr;
-
-    for (OpOperand &use :
-         llvm::make_early_inc_range(op.getCondition().getUses())) {
-      if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
-        changed = true;
-
-        if (!constantTrue)
-          constantTrue = rewriter.create<arith::ConstantOp>(
-              op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
-
-        rewriter.modifyOpInPlace(use.getOwner(),
-                                 [&]() { use.set(constantTrue); });
-      } else if (op.getElseRegion().isAncestor(
-                     use.getOwner()->getParentRegion())) {
-        changed = true;
-
-        if (!constantFalse)
-          constantFalse = rewriter.create<arith::ConstantOp>(
-              op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
-
-        rewriter.modifyOpInPlace(use.getOwner(),
-                                 [&]() { use.set(constantFalse); });
-      }
-    }
-
-    return success(changed);
-  }
-};
-
 /// Remove any statements from an if that are equivalent to the condition
 /// or its negation. For example:
 ///
@@ -2852,9 +2793,8 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
 
 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
-  results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
-              ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
-              RemoveStaticCondition, RemoveUnusedResults,
+  results.add<CombineIfs, CombineNestedIfs, ConvertTrivialIfToSelect,
+              RemoveEmptyElseBranch, RemoveStaticCondition, RemoveUnusedResults,
               ReplaceIfYieldWithConditionOrValue>(context);
 }
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index 84dd992bec53a..6d3bafbbc90e4 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   ForallToFor.cpp
   ForallToParallel.cpp
   ForToWhile.cpp
+  IfConditionPropagation.cpp
   LoopCanonicalization.cpp
   LoopPipelining.cpp
   LoopRangeFolding.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp b/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp
new file mode 100644
index 0000000000000..be8d0e805a7a4
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp
@@ -0,0 +1,115 @@
+//===- IfConditionPropagation.cpp -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a pass for constant propagation of the condition of an
+// `scf.if` into its then and else regions as true and false respectively.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+#include "mlir/IR/Matchers.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallPtrSet.h"
+
+using namespace mlir;
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFIFCONDITIONPROPAGATION
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+/// Set of regions, used to record all regions (transitively) nested in a
+/// given region.
+using RegionSet = llvm::SmallPtrSet<Region *, 8>;
+
+static void propagateIfConditionsImpl(Operation *root, RegionSet &visited);
+
+/// Records `region` in `visited` and processes every operation it contains,
+/// which in turn records all regions nested in `region`.
+static void propagateIfConditionsInRegion(Region &region, RegionSet &visited) {
+  visited.insert(&region);
+  for (Block &block : region)
+    for (Operation &op : block)
+      propagateIfConditionsImpl(&op, visited);
+}
+
+/// Replaces the uses of the condition of `scfIf` located in one of
+/// `thenRegions` with a `true` constant and those located in one of
+/// `elseRegions` with a `false` constant. The constants are created right
+/// before `scfIf`, and only if they end up being used.
+static void replaceConditionUses(scf::IfOp scfIf, const RegionSet &thenRegions,
+                                 const RegionSet &elseRegions) {
+  Value condition = scfIf.getCondition();
+  // Replacing a constant with another constant is not a simplification.
+  if (matchPattern(condition, m_Constant()))
+    return;
+
+  OpBuilder builder(scfIf);
+  Value trueValue, falseValue;
+  auto getConstant = [&](Value &cached, bool value) -> Value {
+    if (!cached)
+      cached = arith::ConstantIntOp::create(builder, scfIf.getLoc(),
+                                            builder.getBoolAttr(value));
+    return cached;
+  };
+
+  // `use.set` unlinks `use` from the use list of `condition`, so the iterator
+  // must be advanced before the use is modified.
+  for (OpOperand &use : llvm::make_early_inc_range(condition.getUses())) {
+    Region *useRegion = use.getOwner()->getParentRegion();
+    if (thenRegions.contains(useRegion))
+      use.set(getConstant(trueValue, true));
+    else if (elseRegions.contains(useRegion))
+      use.set(getConstant(falseValue, false));
+  }
+}
+
+/// Traverses the region tree rooted at `root` in a single post-order sweep
+/// (without `walk`, for efficiency) and, for every `scf.if`, replaces the
+/// uses of its condition nested in its `then` region with `true` and those
+/// nested in its `else` region with `false`. All regions nested in `root` are
+/// recorded in `visited` so that an enclosing `scf.if` can quickly tell which
+/// of its branches, if any, contains a given use. Nested `scf.if`s are
+/// processed first, so a use is rewritten according to the innermost
+/// `scf.if` on the same condition.
+static void propagateIfConditionsImpl(Operation *root, RegionSet &visited) {
+  auto scfIf = dyn_cast<scf::IfOp>(root);
+  if (!scfIf) {
+    for (Region &region : root->getRegions())
+      propagateIfConditionsInRegion(region, visited);
+    return;
+  }
+
+  // Collect the regions of each branch (including the branch region itself),
+  // propagating the conditions of nested `scf.if`s along the way.
+  RegionSet thenChildren, elseChildren;
+  propagateIfConditionsInRegion(scfIf.getThenRegion(), thenChildren);
+  propagateIfConditionsInRegion(scfIf.getElseRegion(), elseChildren);
+  replaceConditionUses(scfIf, thenChildren, elseChildren);
+
+  // Report the regions of both branches to the enclosing traversal.
+  visited.insert_range(thenChildren);
+  visited.insert_range(elseChildren);
+}
+
+/// Propagates the condition of every `scf.if` in the IR rooted at `root` into
+/// its `then` and `else` regions as `true` and `false` constants.
+static void propagateIfConditions(Operation *root) {
+  RegionSet visited;
+  propagateIfConditionsImpl(root, visited);
+}
+
+namespace {
+/// Pass entrypoint.
+struct SCFIfConditionPropagationPass
+    : impl::SCFIfConditionPropagationBase<SCFIfConditionPropagationPass> {
+  void runOnOperation() override { propagateIfConditions(getOperation()); }
+};
+} // namespace
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 8ba8013d008a0..12d30e17f4a8f 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -867,41 +867,6 @@ func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>) -> tensor<?x?xf32> {
 
 // -----
 
-// CHECK-LABEL: @cond_prop
-func.func @cond_prop(%arg0 : i1) -> index {
-  %res = scf.if %arg0 -> index {
-    %res1 = scf.if %arg0 -> index {
-      %v1 = "test.get_some_value1"() : () -> index
-      scf.yield %v1 : index
-    } else {
-      %v2 = "test.get_some_value2"() : () -> index
-      scf.yield %v2 : index
-    }
-    scf.yield %res1 : index
-  } else {
-    %res2 = scf.if %arg0 -> index {
-      %v3 = "test.get_some_value3"() : () -> index
-      scf.yield %v3 : index
-    } else {
-      %v4 = "test.get_some_value4"() : () -> index
-      scf.yield %v4 : index
-    }
-    scf.yield %res2 : index
-  }
-  return %res : index
-}
-// CHECK-NEXT:  %[[if:.+]] = scf.if %arg0 -> (index) {
-// CHECK-NEXT:    %[[c1:.+]] = "test.get_some_value1"() : () -> index
-// CHECK-NEXT:    scf.yield %[[c1]] : index
-// CHECK-NEXT:  } else {
-// CHECK-NEXT:    %[[c4:.+]] = "test.get_some_value4"() : () -> index
-// CHECK-NEXT:    scf.yield %[[c4]] : index
-// CHECK-NEXT:  }
-// CHECK-NEXT:  return %[[if]] : index
-// CHECK-NEXT:}
-
-// -----
-
 // CHECK-LABEL: @replace_if_with_cond1
 func.func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
   %true = arith.constant true
diff --git a/mlir/test/Dialect/SCF/if-cond-prop.mlir b/mlir/test/Dialect/SCF/if-cond-prop.mlir
new file mode 100644
index 0000000000000..99d113f672014
--- /dev/null
+++ b/mlir/test/Dialect/SCF/if-cond-prop.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt %s --scf-if-condition-propagation --allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: @cond_prop
+func.func @cond_prop(%arg0 : i1) -> index {
+  %res = scf.if %arg0 -> index {
+    %res1 = scf.if %arg0 -> index {
+      %v1 = "test.get_some_value1"() : () -> index
+      scf.yield %v1 : index
+    } else {
+      %v2 = "test.get_some_value2"() : () -> index
+      scf.yield %v2 : index
+    }
+    scf.yield %res1 : index
+  } else {
+    %res2 = scf.if %arg0 -> index {
+      %v3 = "test.get_some_value3"() : () -> index
+      scf.yield %v3 : index
+    } else {
+      %v4 = "test.get_some_value4"() : () -> index
+      scf.yield %v4 : index
+    }
+    scf.yield %res2 : index
+  }
+  return %res : index
+}
+// The condition is replaced by constants in the nested `scf.if`s; folding them
+// away is left to canonicalization.
+// CHECK-DAG:  %[[TRUE:.+]] = arith.constant true
+// CHECK-DAG:  %[[FALSE:.+]] = arith.constant false
+// CHECK:      %[[IF:.+]] = scf.if %arg0 -> (index) {
+// CHECK-NEXT:   scf.if %[[TRUE]] -> (index) {
+// CHECK-NEXT:     "test.get_some_value1"
+// CHECK:        } else {
+// CHECK-NEXT:     "test.get_some_value2"
+// CHECK:      } else {
+// CHECK-NEXT:   scf.if %[[FALSE]] -> (index) {
+// CHECK-NEXT:     "test.get_some_value3"
+// CHECK:        } else {
+// CHECK-NEXT:     "test.get_some_value4"
+// CHECK:      return %[[IF]] : index
+
+// CHECK-LABEL: @nested_region_cond_prop
+func.func @nested_region_cond_prop(%arg0 : i1) {
+  scf.if %arg0 {
+    "test.use_cond"(%arg0) : (i1) -> ()
+    "test.op_with_region"() ({
+      "test.use_cond"(%arg0) : (i1) -> ()
+      "test.region_end"() : () -> ()
+    }) : () -> ()
+  } else {
+    "test.use_cond"(%arg0) : (i1) -> ()
+  }
+  return
+}
+// Uses nested in regions of other operations within the branches are replaced.
+// CHECK-DAG:  %[[TRUE:.+]] = arith.constant true
+// CHECK-DAG:  %[[FALSE:.+]] = arith.constant false
+// CHECK:      scf.if %arg0 {
+// CHECK-NEXT:   "test.use_cond"(%[[TRUE]])
+// CHECK-NEXT:   "test.op_with_region"()
+// CHECK-NEXT:     "test.use_cond"(%[[TRUE]])
+// CHECK:      } else {
+// CHECK-NEXT:   "test.use_cond"(%[[FALSE]])

``````````

</details>


https://github.com/llvm/llvm-project/pull/150278


More information about the Mlir-commits mailing list