[Mlir-commits] [mlir] [mlir] move if-condition propagation to a standalone pass (PR #150278)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jul 23 10:34:19 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-scf
Author: Oleksandr "Alex" Zinenko (ftynse)
<details>
<summary>Changes</summary>
This offers a significant speedup over running this as a canonicalization pattern, up to 10x improvement when running on large (>100k operations) inputs coming from Polygeist.
It is also not clear whether this transformation is a reasonable canonicalization as it performs non-local rewrites.
---
Full diff: https://github.com/llvm/llvm-project/pull/150278.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/SCF/Transforms/Passes.td (+6)
- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+2-62)
- (modified) mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp (+96)
- (modified) mlir/test/Dialect/SCF/canonicalize.mlir (-35)
- (added) mlir/test/Dialect/SCF/if-cond-prop.mlir (+34)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 6e5ef96c450aa..ca2510bb53af9 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -41,6 +41,12 @@ def SCFForLoopSpecialization : Pass<"scf-for-loop-specialization"> {
let constructor = "mlir::createForLoopSpecializationPass()";
}
+def SCFIfConditionPropagation : Pass<"scf-if-condition-propagation"> {
+  let summary = "Replace usages of if condition with true/false constants in "
+                "the conditional regions";
+  let description = [{
+    For every `scf.if` operation, replaces the uses of its condition that are
+    located in the "then" region (or in any region nested in it) with an `i1`
+    `true` constant, and the uses located in the "else" region (or in any
+    region nested in it) with an `i1` `false` constant. The constants are
+    created with the `arith` dialect.
+
+    This pass only rewrites the uses of the condition; it does not fold the
+    `scf.if` operations whose condition becomes constant as a result. Run the
+    canonicalizer afterwards to simplify them.
+  }];
+  let dependentDialects = ["arith::ArithDialect"];
+}
+
def SCFParallelLoopFusion : Pass<"scf-parallel-loop-fusion"> {
let summary = "Fuse adjacent parallel loops";
let constructor = "mlir::createParallelLoopFusionPass()";
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 00c31a1500e17..6cb61900928d6 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -2412,65 +2412,6 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
}
};
-/// Allow the true region of an if to assume the condition is true
-/// and vice versa. For example:
-///
-/// scf.if %cmp {
-/// print(%cmp)
-/// }
-///
-/// becomes
-///
-/// scf.if %cmp {
-/// print(true)
-/// }
-///
-struct ConditionPropagation : public OpRewritePattern<IfOp> {
- using OpRewritePattern<IfOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(IfOp op,
- PatternRewriter &rewriter) const override {
- // Early exit if the condition is constant since replacing a constant
- // in the body with another constant isn't a simplification.
- if (matchPattern(op.getCondition(), m_Constant()))
- return failure();
-
- bool changed = false;
- mlir::Type i1Ty = rewriter.getI1Type();
-
- // These variables serve to prevent creating duplicate constants
- // and hold constant true or false values.
- Value constantTrue = nullptr;
- Value constantFalse = nullptr;
-
- for (OpOperand &use :
- llvm::make_early_inc_range(op.getCondition().getUses())) {
- if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
- changed = true;
-
- if (!constantTrue)
- constantTrue = rewriter.create<arith::ConstantOp>(
- op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
-
- rewriter.modifyOpInPlace(use.getOwner(),
- [&]() { use.set(constantTrue); });
- } else if (op.getElseRegion().isAncestor(
- use.getOwner()->getParentRegion())) {
- changed = true;
-
- if (!constantFalse)
- constantFalse = rewriter.create<arith::ConstantOp>(
- op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
-
- rewriter.modifyOpInPlace(use.getOwner(),
- [&]() { use.set(constantFalse); });
- }
- }
-
- return success(changed);
- }
-};
-
/// Remove any statements from an if that are equivalent to the condition
/// or its negation. For example:
///
@@ -2852,9 +2793,8 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
- ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
- RemoveStaticCondition, RemoveUnusedResults,
+ results.add<CombineIfs, CombineNestedIfs, ConvertTrivialIfToSelect,
+ RemoveEmptyElseBranch, RemoveStaticCondition, RemoveUnusedResults,
ReplaceIfYieldWithConditionOrValue>(context);
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index 84dd992bec53a..6d3bafbbc90e4 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
ForallToFor.cpp
ForallToParallel.cpp
ForToWhile.cpp
+ IfConditionPropagation.cpp
LoopCanonicalization.cpp
LoopPipelining.cpp
LoopRangeFolding.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp b/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp
new file mode 100644
index 0000000000000..be8d0e805a7a4
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp
@@ -0,0 +1,96 @@
+//===- IfConditionPropagation.cpp -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a pass for constant propagation of the condition of an
+// `scf.if` into its then and else regions as true and false respectively.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallPtrSet.h"
+
+using namespace mlir;
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFIFCONDITIONPROPAGATION
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+/// Recursively traverses the region tree rooted at `root`. For every `scf.if`
+/// encountered (including `root` itself), it rewrites uses of the `scf.if`
+/// condition:
+///   * uses in the `then` region, or any region nested in it, become an `i1`
+///     `true` constant;
+///   * uses in the `else` region, or any region nested in it, become an `i1`
+///     `false` constant.
+///
+/// This is done as a single post-order sweep over the IR (without `walk`) for
+/// efficiency reasons: operations nested in an `scf.if` are processed before
+/// the `scf.if` itself.
+///
+/// On return, `visited` additionally contains every region nested, directly or
+/// transitively, under `root`. An enclosing `scf.if` uses these sets to check
+/// in constant time whether a use belongs to its `then` or `else` branch.
+static void propagateIfConditionsImpl(Operation *root,
+                                      llvm::SmallPtrSet<Region *, 8> &visited) {
+  if (auto scfIf = dyn_cast<scf::IfOp>(root)) {
+    // Collect each branch region together with all regions nested in it. The
+    // branch regions themselves must be recorded explicitly: a use located
+    // directly in a branch body has that branch region as its parent region.
+    llvm::SmallPtrSet<Region *, 8> thenChildren, elseChildren;
+
+    thenChildren.insert(&scfIf.getThenRegion());
+    for (Block &block : scfIf.getThenRegion())
+      for (Operation &op : block)
+        propagateIfConditionsImpl(&op, thenChildren);
+
+    elseChildren.insert(&scfIf.getElseRegion());
+    for (Block &block : scfIf.getElseRegion())
+      for (Operation &op : block)
+        propagateIfConditionsImpl(&op, elseChildren);
+
+    // Update uses to point to constants instead. The constants are created
+    // lazily, right before the `scf.if`, so that no dead constant is
+    // materialized when a branch does not use the condition.
+    OpBuilder builder(scfIf);
+    Value trueValue, falseValue;
+    // `use.set` unlinks the current use from the condition's use list, so the
+    // iterator must be advanced before the use is modified.
+    for (OpOperand &use :
+         llvm::make_early_inc_range(scfIf.getCondition().getUses())) {
+      Region *useRegion = use.getOwner()->getParentRegion();
+      if (thenChildren.contains(useRegion)) {
+        if (!trueValue)
+          trueValue = arith::ConstantIntOp::create(builder, scfIf.getLoc(),
+                                                   builder.getBoolAttr(true));
+        use.set(trueValue);
+      } else if (elseChildren.contains(useRegion)) {
+        if (!falseValue)
+          falseValue = arith::ConstantIntOp::create(builder, scfIf.getLoc(),
+                                                    builder.getBoolAttr(false));
+        use.set(falseValue);
+      }
+    }
+
+    // Report every region nested under this `scf.if` to the caller.
+    visited.insert_range(thenChildren);
+    visited.insert_range(elseChildren);
+    return;
+  }
+
+  // Any other operation: record its regions and keep descending so that an
+  // enclosing `scf.if` sees uses nested arbitrarily deep in its branches.
+  for (Region &region : root->getRegions()) {
+    visited.insert(&region);
+    for (Block &block : region)
+      for (Operation &op : block)
+        propagateIfConditionsImpl(&op, visited);
+  }
+}
+
+/// Entry point of the transformation. For every `scf.if` located under (or
+/// at) `root`, it replaces uses of the condition inside the `then` and `else`
+/// regions with `true` and `false` constants, respectively.
+static void propagateIfConditions(Operation *root) {
+  // Regions gathered during the sweep; only needed while it runs.
+  llvm::SmallPtrSet<Region *, 8> nestedRegions;
+  propagateIfConditionsImpl(root, nestedRegions);
+}
+
+namespace {
+/// Pass wrapper that runs `scf.if` condition propagation on the operation the
+/// pass is scheduled on.
+struct SCFIfConditionPropagationPass
+    : impl::SCFIfConditionPropagationBase<SCFIfConditionPropagationPass> {
+  void runOnOperation() override {
+    Operation *target = getOperation();
+    propagateIfConditions(target);
+  }
+};
+} // namespace
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 8ba8013d008a0..12d30e17f4a8f 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -867,41 +867,6 @@ func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>) -> tensor<?x?xf32> {
// -----
-// CHECK-LABEL: @cond_prop
-func.func @cond_prop(%arg0 : i1) -> index {
- %res = scf.if %arg0 -> index {
- %res1 = scf.if %arg0 -> index {
- %v1 = "test.get_some_value1"() : () -> index
- scf.yield %v1 : index
- } else {
- %v2 = "test.get_some_value2"() : () -> index
- scf.yield %v2 : index
- }
- scf.yield %res1 : index
- } else {
- %res2 = scf.if %arg0 -> index {
- %v3 = "test.get_some_value3"() : () -> index
- scf.yield %v3 : index
- } else {
- %v4 = "test.get_some_value4"() : () -> index
- scf.yield %v4 : index
- }
- scf.yield %res2 : index
- }
- return %res : index
-}
-// CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (index) {
-// CHECK-NEXT: %[[c1:.+]] = "test.get_some_value1"() : () -> index
-// CHECK-NEXT: scf.yield %[[c1]] : index
-// CHECK-NEXT: } else {
-// CHECK-NEXT: %[[c4:.+]] = "test.get_some_value4"() : () -> index
-// CHECK-NEXT: scf.yield %[[c4]] : index
-// CHECK-NEXT: }
-// CHECK-NEXT: return %[[if]] : index
-// CHECK-NEXT:}
-
-// -----
-
// CHECK-LABEL: @replace_if_with_cond1
func.func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
%true = arith.constant true
diff --git a/mlir/test/Dialect/SCF/if-cond-prop.mlir b/mlir/test/Dialect/SCF/if-cond-prop.mlir
new file mode 100644
index 0000000000000..99d113f672014
--- /dev/null
+++ b/mlir/test/Dialect/SCF/if-cond-prop.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt %s --scf-if-condition-propagation --allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: @cond_prop
+func.func @cond_prop(%arg0 : i1) -> index {
+  %res = scf.if %arg0 -> index {
+    %res1 = scf.if %arg0 -> index {
+      %v1 = "test.get_some_value1"() : () -> index
+      scf.yield %v1 : index
+    } else {
+      %v2 = "test.get_some_value2"() : () -> index
+      scf.yield %v2 : index
+    }
+    scf.yield %res1 : index
+  } else {
+    %res2 = scf.if %arg0 -> index {
+      %v3 = "test.get_some_value3"() : () -> index
+      scf.yield %v3 : index
+    } else {
+      %v4 = "test.get_some_value4"() : () -> index
+      scf.yield %v4 : index
+    }
+    scf.yield %res2 : index
+  }
+  return %res : index
+}
+// The pass only rewrites uses of the condition; folding the resulting
+// constant-condition `scf.if` ops is left to the canonicalizer. Check the
+// rewritten conditions directly, so the test fails if no use is replaced.
+// CHECK-DAG:     %[[TRUE:.+]] = arith.constant true
+// CHECK-DAG:     %[[FALSE:.+]] = arith.constant false
+// CHECK:         %[[IF:.+]] = scf.if %arg0 -> (index) {
+// CHECK:           scf.if %[[TRUE]] -> (index) {
+// CHECK:             "test.get_some_value1"
+// CHECK:           } else {
+// CHECK:             "test.get_some_value2"
+// CHECK:         } else {
+// CHECK:           scf.if %[[FALSE]] -> (index) {
+// CHECK:             "test.get_some_value3"
+// CHECK:           } else {
+// CHECK:             "test.get_some_value4"
+// CHECK:         return %[[IF]] : index
``````````
</details>
https://github.com/llvm/llvm-project/pull/150278
More information about the Mlir-commits
mailing list