[Mlir-commits] [mlir] c5062d7 - Revert "[mlir] move if-condition propagation to a standalone pass" (#159535)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Sep 18 02:54:58 PDT 2025
Author: Mehdi Amini
Date: 2025-09-18T09:54:53Z
New Revision: c5062d7d6358d73931b4791c77500f476606b003
URL: https://github.com/llvm/llvm-project/commit/c5062d7d6358d73931b4791c77500f476606b003
DIFF: https://github.com/llvm/llvm-project/commit/c5062d7d6358d73931b4791c77500f476606b003.diff
LOG: Revert "[mlir] move if-condition propagation to a standalone pass" (#159535)
Reverts llvm/llvm-project#150278
Multiple post-merge comments remained unaddressed, and some more
fundamental issues were also reported in #159165
Added:
Modified:
mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
mlir/test/Dialect/SCF/canonicalize.mlir
Removed:
mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp
mlir/test/Dialect/SCF/if-cond-prop.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 8b891aa374b58..3ac651f53880c 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -41,12 +41,6 @@ def SCFForLoopSpecialization : Pass<"scf-for-loop-specialization"> {
let constructor = "mlir::createForLoopSpecializationPass()";
}
-def SCFIfConditionPropagation : Pass<"scf-if-condition-propagation"> {
- let summary = "Replace usages of if condition with true/false constants in "
- "the conditional regions";
- let dependentDialects = ["arith::ArithDialect"];
-}
-
def SCFParallelLoopFusion : Pass<"scf-parallel-loop-fusion"> {
let summary = "Fuse adjacent parallel loops";
let constructor = "mlir::createParallelLoopFusionPass()";
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index ae55eaded0554..a9da6c2c8320a 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -2453,6 +2453,65 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
}
};
+/// Allow the true region of an if to assume the condition is true
+/// and vice versa. For example:
+///
+/// scf.if %cmp {
+/// print(%cmp)
+/// }
+///
+/// becomes
+///
+/// scf.if %cmp {
+/// print(true)
+/// }
+///
+struct ConditionPropagation : public OpRewritePattern<IfOp> {
+ using OpRewritePattern<IfOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(IfOp op,
+ PatternRewriter &rewriter) const override {
+ // Early exit if the condition is constant since replacing a constant
+ // in the body with another constant isn't a simplification.
+ if (matchPattern(op.getCondition(), m_Constant()))
+ return failure();
+
+ bool changed = false;
+ mlir::Type i1Ty = rewriter.getI1Type();
+
+ // These variables serve to prevent creating duplicate constants
+ // and hold constant true or false values.
+ Value constantTrue = nullptr;
+ Value constantFalse = nullptr;
+
+ for (OpOperand &use :
+ llvm::make_early_inc_range(op.getCondition().getUses())) {
+ if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
+ changed = true;
+
+ if (!constantTrue)
+ constantTrue = rewriter.create<arith::ConstantOp>(
+ op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
+
+ rewriter.modifyOpInPlace(use.getOwner(),
+ [&]() { use.set(constantTrue); });
+ } else if (op.getElseRegion().isAncestor(
+ use.getOwner()->getParentRegion())) {
+ changed = true;
+
+ if (!constantFalse)
+ constantFalse = rewriter.create<arith::ConstantOp>(
+ op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
+
+ rewriter.modifyOpInPlace(use.getOwner(),
+ [&]() { use.set(constantFalse); });
+ }
+ }
+
+ return success(changed);
+ }
+};
+
/// Remove any statements from an if that are equivalent to the condition
/// or its negation. For example:
///
@@ -2835,8 +2894,9 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<CombineIfs, CombineNestedIfs, ConvertTrivialIfToSelect,
- RemoveEmptyElseBranch, RemoveStaticCondition, RemoveUnusedResults,
+ results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
+ ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
+ RemoveStaticCondition, RemoveUnusedResults,
ReplaceIfYieldWithConditionOrValue>(context);
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index a07d9d4953d19..a9ffa9dc208a0 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -4,7 +4,6 @@ add_mlir_dialect_library(MLIRSCFTransforms
ForallToFor.cpp
ForallToParallel.cpp
ForToWhile.cpp
- IfConditionPropagation.cpp
LoopCanonicalization.cpp
LoopPipelining.cpp
LoopRangeFolding.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp b/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp
deleted file mode 100644
index bdc51296ef9f2..0000000000000
--- a/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp
+++ /dev/null
@@ -1,98 +0,0 @@
-//===- IfConditionPropagation.cpp -----------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file contains a pass for constant propagation of the condition of an
-// `scf.if` into its then and else regions as true and false respectively.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/SCF/Transforms/Passes.h"
-
-using namespace mlir;
-
-namespace mlir {
-#define GEN_PASS_DEF_SCFIFCONDITIONPROPAGATION
-#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
-} // namespace mlir
-
-/// Traverses the IR recursively (on region tree) and updates the uses of a
-/// value also as the condition of an `scf.if` to either `true` or `false`
-/// constants in the `then` and `else regions. This is done as a single
-/// post-order sweep over the IR (without `walk`) for efficiency reasons. While
-/// traversing, the function maintains the set of visited regions to quickly
-/// identify whether the value belong to a region that is known to be nested in
-/// the `then` or `else` branch of a specific loop.
-static void propagateIfConditionsImpl(Operation *root,
- llvm::SmallPtrSet<Region *, 8> &visited) {
- if (auto scfIf = dyn_cast<scf::IfOp>(root)) {
- llvm::SmallPtrSet<Region *, 8> thenChildren, elseChildren;
- // Visit the "then" region, collect children.
- for (Block &block : scfIf.getThenRegion()) {
- for (Operation &op : block) {
- propagateIfConditionsImpl(&op, thenChildren);
- }
- }
-
- // Visit the "else" region, collect children.
- for (Block &block : scfIf.getElseRegion()) {
- for (Operation &op : block) {
- propagateIfConditionsImpl(&op, elseChildren);
- }
- }
-
- // Update uses to point to constants instead.
- OpBuilder builder(scfIf);
- Value trueValue = arith::ConstantIntOp::create(builder, scfIf.getLoc(),
- /*value=*/true, /*width=*/1);
- Value falseValue =
- arith::ConstantIntOp::create(builder, scfIf.getLoc(),
- /*value=*/false, /*width=*/1);
-
- for (OpOperand &use : scfIf.getCondition().getUses()) {
- if (thenChildren.contains(use.getOwner()->getParentRegion()))
- use.set(trueValue);
- else if (elseChildren.contains(use.getOwner()->getParentRegion()))
- use.set(falseValue);
- }
- if (trueValue.getUses().empty())
- trueValue.getDefiningOp()->erase();
- if (falseValue.getUses().empty())
- falseValue.getDefiningOp()->erase();
-
- // Append the two lists of children and return them.
- visited.insert_range(thenChildren);
- visited.insert_range(elseChildren);
- return;
- }
-
- for (Region &region : root->getRegions()) {
- for (Block &block : region) {
- for (Operation &op : block) {
- propagateIfConditionsImpl(&op, visited);
- }
- }
- }
-}
-
-/// Traverses the IR recursively (on region tree) and updates the uses of a
-/// value also as the condition of an `scf.if` to either `true` or `false`
-/// constants in the `then` and `else regions
-static void propagateIfConditions(Operation *root) {
- llvm::SmallPtrSet<Region *, 8> visited;
- propagateIfConditionsImpl(root, visited);
-}
-
-namespace {
-/// Pass entrypoint.
-struct SCFIfConditionPropagationPass
- : impl::SCFIfConditionPropagationBase<SCFIfConditionPropagationPass> {
- void runOnOperation() override { propagateIfConditions(getOperation()); }
-};
-} // namespace
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 5e89f74075252..2bec63672e783 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -867,6 +867,41 @@ func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>) -> tensor<?x?xf32> {
// -----
+// CHECK-LABEL: @cond_prop
+func.func @cond_prop(%arg0 : i1) -> index {
+ %res = scf.if %arg0 -> index {
+ %res1 = scf.if %arg0 -> index {
+ %v1 = "test.get_some_value1"() : () -> index
+ scf.yield %v1 : index
+ } else {
+ %v2 = "test.get_some_value2"() : () -> index
+ scf.yield %v2 : index
+ }
+ scf.yield %res1 : index
+ } else {
+ %res2 = scf.if %arg0 -> index {
+ %v3 = "test.get_some_value3"() : () -> index
+ scf.yield %v3 : index
+ } else {
+ %v4 = "test.get_some_value4"() : () -> index
+ scf.yield %v4 : index
+ }
+ scf.yield %res2 : index
+ }
+ return %res : index
+}
+// CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (index) {
+// CHECK-NEXT: %[[c1:.+]] = "test.get_some_value1"() : () -> index
+// CHECK-NEXT: scf.yield %[[c1]] : index
+// CHECK-NEXT: } else {
+// CHECK-NEXT: %[[c4:.+]] = "test.get_some_value4"() : () -> index
+// CHECK-NEXT: scf.yield %[[c4]] : index
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[if]] : index
+// CHECK-NEXT:}
+
+// -----
+
// CHECK-LABEL: @replace_if_with_cond1
func.func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
%true = arith.constant true
diff --git a/mlir/test/Dialect/SCF/if-cond-prop.mlir b/mlir/test/Dialect/SCF/if-cond-prop.mlir
deleted file mode 100644
index 99d113f672014..0000000000000
--- a/mlir/test/Dialect/SCF/if-cond-prop.mlir
+++ /dev/null
@@ -1,34 +0,0 @@
-// RUN: mlir-opt %s --scf-if-condition-propagation --allow-unregistered-dialect | FileCheck %s
-
-// CHECK-LABEL: @cond_prop
-func.func @cond_prop(%arg0 : i1) -> index {
- %res = scf.if %arg0 -> index {
- %res1 = scf.if %arg0 -> index {
- %v1 = "test.get_some_value1"() : () -> index
- scf.yield %v1 : index
- } else {
- %v2 = "test.get_some_value2"() : () -> index
- scf.yield %v2 : index
- }
- scf.yield %res1 : index
- } else {
- %res2 = scf.if %arg0 -> index {
- %v3 = "test.get_some_value3"() : () -> index
- scf.yield %v3 : index
- } else {
- %v4 = "test.get_some_value4"() : () -> index
- scf.yield %v4 : index
- }
- scf.yield %res2 : index
- }
- return %res : index
-}
-// CHECK: %[[if:.+]] = scf.if %arg0 -> (index) {
-// CHECK: %[[c1:.+]] = "test.get_some_value1"() : () -> index
-// CHECK: scf.yield %[[c1]] : index
-// CHECK: } else {
-// CHECK: %[[c4:.+]] = "test.get_some_value4"() : () -> index
-// CHECK: scf.yield %[[c4]] : index
-// CHECK: }
-// CHECK: return %[[if]] : index
-// CHECK:}
More information about the Mlir-commits
mailing list