[Mlir-commits] [mlir] 9d11acc - [mlir] move if-condition propagation to a standalone pass (#150278)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jul 23 12:02:44 PDT 2025
Author: Oleksandr "Alex" Zinenko
Date: 2025-07-23T21:02:40+02:00
New Revision: 9d11accf95db0ed08bd3181c25dd75fc793d089d
URL: https://github.com/llvm/llvm-project/commit/9d11accf95db0ed08bd3181c25dd75fc793d089d
DIFF: https://github.com/llvm/llvm-project/commit/9d11accf95db0ed08bd3181c25dd75fc793d089d.diff
LOG: [mlir] move if-condition propagation to a standalone pass (#150278)
This offers a significant speedup over running this as a
canonicalization pattern, up to 10x improvement when running on large
(>100k operations) inputs coming from Polygeist.
It is also not clear whether this transformation is a reasonable
canonicalization as it performs non-local rewrites.
Added:
mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp
mlir/test/Dialect/SCF/if-cond-prop.mlir
Modified:
mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
mlir/test/Dialect/SCF/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 6e5ef96c450aa..ca2510bb53af9 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -41,6 +41,12 @@ def SCFForLoopSpecialization : Pass<"scf-for-loop-specialization"> {
let constructor = "mlir::createForLoopSpecializationPass()";
}
+def SCFIfConditionPropagation : Pass<"scf-if-condition-propagation"> {
+ let summary = "Replace usages of if condition with true/false constants in "
+ "the conditional regions";
+ let dependentDialects = ["arith::ArithDialect"];
+}
+
def SCFParallelLoopFusion : Pass<"scf-parallel-loop-fusion"> {
let summary = "Fuse adjacent parallel loops";
let constructor = "mlir::createParallelLoopFusionPass()";
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index df41eba4ef533..72ab4b13d2b78 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -2414,65 +2414,6 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
}
};
-/// Allow the true region of an if to assume the condition is true
-/// and vice versa. For example:
-///
-/// scf.if %cmp {
-/// print(%cmp)
-/// }
-///
-/// becomes
-///
-/// scf.if %cmp {
-/// print(true)
-/// }
-///
-struct ConditionPropagation : public OpRewritePattern<IfOp> {
- using OpRewritePattern<IfOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(IfOp op,
- PatternRewriter &rewriter) const override {
- // Early exit if the condition is constant since replacing a constant
- // in the body with another constant isn't a simplification.
- if (matchPattern(op.getCondition(), m_Constant()))
- return failure();
-
- bool changed = false;
- mlir::Type i1Ty = rewriter.getI1Type();
-
- // These variables serve to prevent creating duplicate constants
- // and hold constant true or false values.
- Value constantTrue = nullptr;
- Value constantFalse = nullptr;
-
- for (OpOperand &use :
- llvm::make_early_inc_range(op.getCondition().getUses())) {
- if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
- changed = true;
-
- if (!constantTrue)
- constantTrue = rewriter.create<arith::ConstantOp>(
- op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
-
- rewriter.modifyOpInPlace(use.getOwner(),
- [&]() { use.set(constantTrue); });
- } else if (op.getElseRegion().isAncestor(
- use.getOwner()->getParentRegion())) {
- changed = true;
-
- if (!constantFalse)
- constantFalse = rewriter.create<arith::ConstantOp>(
- op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
-
- rewriter.modifyOpInPlace(use.getOwner(),
- [&]() { use.set(constantFalse); });
- }
- }
-
- return success(changed);
- }
-};
-
/// Remove any statements from an if that are equivalent to the condition
/// or its negation. For example:
///
@@ -2854,9 +2795,8 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
- ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
- RemoveStaticCondition, RemoveUnusedResults,
+ results.add<CombineIfs, CombineNestedIfs, ConvertTrivialIfToSelect,
+ RemoveEmptyElseBranch, RemoveStaticCondition, RemoveUnusedResults,
ReplaceIfYieldWithConditionOrValue>(context);
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index 84dd992bec53a..6d3bafbbc90e4 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
ForallToFor.cpp
ForallToParallel.cpp
ForToWhile.cpp
+ IfConditionPropagation.cpp
LoopCanonicalization.cpp
LoopPipelining.cpp
LoopRangeFolding.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp b/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp
new file mode 100644
index 0000000000000..bdc51296ef9f2
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp
@@ -0,0 +1,98 @@
+//===- IfConditionPropagation.cpp -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a pass for constant propagation of the condition of an
+// `scf.if` into its then and else regions as true and false respectively.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+using namespace mlir;
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFIFCONDITIONPROPAGATION
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+/// Traverses the IR recursively (on the region tree) and replaces uses of the
+/// condition of each `scf.if` with `true` or `false` constants inside its
+/// `then` and `else` regions, respectively. This is done as a single
+/// post-order sweep over the IR (without `walk`) for efficiency reasons. On
+/// return, `visited` additionally holds every region nested (transitively)
+/// inside `root`, so an enclosing `scf.if` can check in O(1) whether a use
+/// lies in its `then` or `else` branch by looking up the use's parent region.
+static void propagateIfConditionsImpl(Operation *root,
+                                      llvm::SmallPtrSet<Region *, 8> &visited) {
+  if (auto scfIf = dyn_cast<scf::IfOp>(root)) {
+    llvm::SmallPtrSet<Region *, 8> thenChildren, elseChildren;
+    // Visit the "then" region; collect it and all regions nested in it.
+    for (Block &block : scfIf.getThenRegion())
+      for (Operation &op : block)
+        propagateIfConditionsImpl(&op, thenChildren);
+    thenChildren.insert(&scfIf.getThenRegion());
+    // Visit the (possibly empty) "else" region; collect it likewise.
+    for (Block &block : scfIf.getElseRegion())
+      for (Operation &op : block)
+        propagateIfConditionsImpl(&op, elseChildren);
+    elseChildren.insert(&scfIf.getElseRegion());
+
+    // Update uses to point to constants created right before the `scf.if`.
+    OpBuilder builder(scfIf);
+    Value trueValue = arith::ConstantIntOp::create(builder, scfIf.getLoc(),
+                                                   /*value=*/true, /*width=*/1);
+    Value falseValue =
+        arith::ConstantIntOp::create(builder, scfIf.getLoc(),
+                                     /*value=*/false, /*width=*/1);
+
+    // `use.set` relinks the use into another value's use list: advance first.
+    for (OpOperand &use :
+         llvm::make_early_inc_range(scfIf.getCondition().getUses())) {
+      Region *useRegion = use.getOwner()->getParentRegion();
+      if (thenChildren.contains(useRegion))
+        use.set(trueValue);
+      else if (elseChildren.contains(useRegion))
+        use.set(falseValue);
+    }
+    if (trueValue.getUses().empty())
+      trueValue.getDefiningOp()->erase();
+    if (falseValue.getUses().empty())
+      falseValue.getDefiningOp()->erase();
+
+    // Report all regions nested in this `scf.if` to the caller.
+    visited.insert_range(thenChildren);
+    visited.insert_range(elseChildren);
+    return;
+  }
+
+  // Any other op: recurse into its regions and record them as nested.
+  for (Region &region : root->getRegions()) {
+    for (Block &block : region)
+      for (Operation &op : block)
+        propagateIfConditionsImpl(&op, visited);
+    visited.insert(&region);
+  }
+}
+
+/// Runs condition propagation on `root` and everything nested in it: uses of
+/// each `scf.if` condition inside its `then`/`else` regions are replaced with
+/// `true`/`false` constants, respectively.
+static void propagateIfConditions(Operation *root) {
+  llvm::SmallPtrSet<Region *, 8> nestedRegions;
+  propagateIfConditionsImpl(root, nestedRegions);
+}
+
+namespace {
+/// Pass wrapper: applies the propagation to the op the pass is anchored on.
+struct SCFIfConditionPropagationPass
+    : impl::SCFIfConditionPropagationBase<SCFIfConditionPropagationPass> {
+  void runOnOperation() override { propagateIfConditions(getOperation()); }
+};
+} // namespace
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 8ba8013d008a0..12d30e17f4a8f 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -867,41 +867,6 @@ func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>) -> tensor<?x?xf32> {
// -----
-// CHECK-LABEL: @cond_prop
-func.func @cond_prop(%arg0 : i1) -> index {
- %res = scf.if %arg0 -> index {
- %res1 = scf.if %arg0 -> index {
- %v1 = "test.get_some_value1"() : () -> index
- scf.yield %v1 : index
- } else {
- %v2 = "test.get_some_value2"() : () -> index
- scf.yield %v2 : index
- }
- scf.yield %res1 : index
- } else {
- %res2 = scf.if %arg0 -> index {
- %v3 = "test.get_some_value3"() : () -> index
- scf.yield %v3 : index
- } else {
- %v4 = "test.get_some_value4"() : () -> index
- scf.yield %v4 : index
- }
- scf.yield %res2 : index
- }
- return %res : index
-}
-// CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (index) {
-// CHECK-NEXT: %[[c1:.+]] = "test.get_some_value1"() : () -> index
-// CHECK-NEXT: scf.yield %[[c1]] : index
-// CHECK-NEXT: } else {
-// CHECK-NEXT: %[[c4:.+]] = "test.get_some_value4"() : () -> index
-// CHECK-NEXT: scf.yield %[[c4]] : index
-// CHECK-NEXT: }
-// CHECK-NEXT: return %[[if]] : index
-// CHECK-NEXT:}
-
-// -----
-
// CHECK-LABEL: @replace_if_with_cond1
func.func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
%true = arith.constant true
diff --git a/mlir/test/Dialect/SCF/if-cond-prop.mlir b/mlir/test/Dialect/SCF/if-cond-prop.mlir
new file mode 100644
index 0000000000000..99d113f672014
--- /dev/null
+++ b/mlir/test/Dialect/SCF/if-cond-prop.mlir
@@ -0,0 +1,59 @@
+// RUN: mlir-opt %s --scf-if-condition-propagation --allow-unregistered-dialect | FileCheck %s
+
+// Uses of the condition as the condition of nested `scf.if` ops are replaced
+// with the constant matching the enclosing branch.
+// CHECK-LABEL: @cond_prop
+func.func @cond_prop(%arg0 : i1) -> index {
+  %res = scf.if %arg0 -> index {
+    %res1 = scf.if %arg0 -> index {
+      %v1 = "test.get_some_value1"() : () -> index
+      scf.yield %v1 : index
+    } else {
+      %v2 = "test.get_some_value2"() : () -> index
+      scf.yield %v2 : index
+    }
+    scf.yield %res1 : index
+  } else {
+    %res2 = scf.if %arg0 -> index {
+      %v3 = "test.get_some_value3"() : () -> index
+      scf.yield %v3 : index
+    } else {
+      %v4 = "test.get_some_value4"() : () -> index
+      scf.yield %v4 : index
+    }
+    scf.yield %res2 : index
+  }
+  return %res : index
+}
+// CHECK-DAG: %[[TRUE:.+]] = arith.constant true
+// CHECK-DAG: %[[FALSE:.+]] = arith.constant false
+// CHECK: %[[IF:.+]] = scf.if %arg0 -> (index) {
+// CHECK: scf.if %[[TRUE]] -> (index) {
+// CHECK: "test.get_some_value1"
+// CHECK: "test.get_some_value2"
+// CHECK: } else {
+// CHECK: scf.if %[[FALSE]] -> (index) {
+// CHECK: "test.get_some_value3"
+// CHECK: "test.get_some_value4"
+// CHECK: return %[[IF]] : index
+
+// Every use nested in a branch is replaced, including several uses of the
+// condition in the same region.
+// CHECK-LABEL: @multiple_uses
+func.func @multiple_uses(%arg0 : i1) {
+  scf.if %arg0 {
+    "test.use_cond"(%arg0) : (i1) -> ()
+    "test.use_cond"(%arg0) : (i1) -> ()
+  } else {
+    "test.use_cond"(%arg0) : (i1) -> ()
+  }
+  return
+}
+// CHECK-DAG: %[[TRUE:.+]] = arith.constant true
+// CHECK-DAG: %[[FALSE:.+]] = arith.constant false
+// CHECK: scf.if %arg0 {
+// CHECK-NEXT: "test.use_cond"(%[[TRUE]]) : (i1) -> ()
+// CHECK-NEXT: "test.use_cond"(%[[TRUE]]) : (i1) -> ()
+// CHECK-NEXT: } else {
+// CHECK-NEXT: "test.use_cond"(%[[FALSE]]) : (i1) -> ()
+// CHECK-NEXT: }
More information about the Mlir-commits
mailing list