[Mlir-commits] [mlir] [mlir][cf] Canonicalize block args with uniform incoming values (PR #183966)
Fedor Nikolaev
llvmlistbot at llvm.org
Sat Feb 28 15:05:26 PST 2026
https://github.com/felichita created https://github.com/llvm/llvm-project/pull/183966
Add a canonicalization pattern that replaces block arguments with a
common SSA value when all predecessors pass the same value for that
argument. This allows the block argument to be removed by dead code
elimination. This is a first iteration.
Idea from #182711
cc: @matthias-springer , @joker-eph
>From 824c76c32289913c7bac1b1a0c0113b9c7587bde Mon Sep 17 00:00:00 2001
From: Fedor Nikolaev <fridrixnm at gmail.com>
Date: Sat, 28 Feb 2026 23:40:46 +0100
Subject: [PATCH] [mlir][cf] Canonicalize block args with uniform incoming
values
Add a canonicalization pattern that replaces block arguments with a
common SSA value when all predecessors pass the same value for that
argument. This allows the block argument to be removed by dead code
elimination.
Ref #182711
---
.../Dialect/ControlFlow/IR/ControlFlowOps.cpp | 83 ++++++++++++++++++-
.../Dialect/ControlFlow/canonicalize.mlir | 38 ++++++++-
2 files changed, 116 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index 0ce0d55f4397c..cdc44122068b3 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -86,7 +86,7 @@ LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
return failure();
}
-// This side effect models "program termination".
+// This side effect models "program termination".
void AssertOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
@@ -204,9 +204,85 @@ static LogicalResult simplifyPassThroughBr(BranchOp op,
return success();
}
+/// If every incoming edge of `dest` forwards the same SSA value for a block
+/// argument, replace all uses of that argument with the common value. The
+/// argument is then unused and can be removed by dead code elimination.
+///
+/// %c = arith.constant 0 : i32
+/// cf.br ^bb1(%c : i32) // pred 1
+/// cf.br ^bb1(%c : i32) // pred 2
+/// ^bb1(%arg0: i32):
+/// use(%arg0)
+/// ->
+/// ^bb1(%arg0: i32):
+/// use(%c) // %arg0 has no uses and can be removed
+///
+/// Edges are compared, not predecessor blocks: in
+/// `cf.cond_br %cond, ^bb1(%a : i32), ^bb1(%b : i32)` the same terminator
+/// reaches ^bb1 twice, forwarding %a on one edge and %b on the other, so the
+/// argument of ^bb1 is NOT uniform.
+///
+/// Returns true if at least one block argument was rewritten.
+static bool simplifyUniformBlockArgs(Block *dest, PatternRewriter &rewriter) {
+ // Only blocks with two or more incoming edges are considered.
+ if (dest->hasNoPredecessors() ||
+ llvm::hasSingleElement(dest->getPredecessors()))
+ return false;
+
+ bool changed = false;
+ for (BlockArgument arg : dest->getArguments()) {
+ if (arg.use_empty())
+ continue;
+ unsigned argIdx = arg.getArgNumber();
+
+ Value commonValue;
+ bool allSame = true;
+ // Walk incoming edges. getSuccessorIndex() identifies which successor slot
+ // of the predecessor's terminator this edge corresponds to, so a terminator
+ // that targets `dest` several times has each of its edges checked instead
+ // of only the first one.
+ for (auto it = dest->pred_begin(), e = dest->pred_end(); it != e; ++it) {
+ auto branch = dyn_cast<BranchOpInterface>((*it)->getTerminator());
+ if (!branch) {
+ allSame = false;
+ break;
+ }
+ SuccessorOperands succOps =
+ branch.getSuccessorOperands(it.getSuccessorIndex());
+ // A missing or null incoming value makes the argument non-uniform.
+ Value incoming = argIdx < succOps.size() ? succOps[argIdx] : Value();
+ if (!incoming || (commonValue && commonValue != incoming)) {
+ allSame = false;
+ break;
+ }
+ commonValue = incoming;
+ }
+ if (!allSame || !commonValue)
+ continue;
+
+ // Conservatively skip values defined in `dest` itself. This covers the
+ // self-reference `commonValue == arg`; in valid SSA IR any other such value
+ // can only be forwarded on every edge when `dest` is unreachable.
+ if (commonValue.getParentBlock() == dest)
+ continue;
+
+ rewriter.replaceAllUsesWith(arg, commonValue);
+ changed = true;
+ }
+ return changed;
+}
+
+namespace {
+/// Pattern rooted at `cf.cond_br`: hands each successor block of the branch to
+/// simplifyUniformBlockArgs and reports success if any of those calls
+/// rewrote a block argument.
+struct SimplifyCondBranchBlockArgWithUniformIncomingValues
+ : public OpRewritePattern<CondBranchOp> {
+ using OpRewritePattern<CondBranchOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(CondBranchOp op,
+ PatternRewriter &rewriter) const override {
+ bool anySimplified = false;
+ for (Block *successor : op->getSuccessors())
+ if (simplifyUniformBlockArgs(successor, rewriter))
+ anySimplified = true;
+ return success(anySimplified);
+ }
+};
+} // namespace
+
LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
+ // The simplifications are tried in order and `||` stops at the first one
+ // that succeeds, so simplifyUniformBlockArgs only runs when neither of the
+ // earlier rewrites applied to `op`.
return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
- succeeded(simplifyPassThroughBr(op, rewriter)));
+ succeeded(simplifyPassThroughBr(op, rewriter)) ||
+ simplifyUniformBlockArgs(op.getDest(), rewriter));
}
void BranchOp::setDest(Block *block) { return setSuccessor(block); }
@@ -492,7 +568,8 @@ void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
SimplifyCondBranchIdenticalSuccessors,
SimplifyCondBranchFromCondBranchOnSameCondition,
- CondBranchTruthPropagation, DropUnreachableCondBranch>(context);
+ CondBranchTruthPropagation, DropUnreachableCondBranch,
+ SimplifyCondBranchBlockArgWithUniformIncomingValues>(context);
}
SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
diff --git a/mlir/test/Dialect/ControlFlow/canonicalize.mlir b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
index 8ddfeb7b0841c..5bcd76badea59 100644
--- a/mlir/test/Dialect/ControlFlow/canonicalize.mlir
+++ b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
@@ -131,8 +131,10 @@ func.func @cond_br_passthrough_weights(%arg0 : i32, %arg1 : i32, %cond : i1) ->
// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[COND:.*]]: i1
func.func @cond_br_passthrough(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) -> (i32, i32) {
// CHECK: %[[RES:.*]] = arith.select %[[COND]], %[[ARG0]], %[[ARG2]]
// CHECK: %[[RES2:.*]] = arith.select %[[COND]], %[[ARG1]], %[[ARG2]]
// CHECK: return %[[RES]], %[[RES2]]
+ // The second result is %arg1 on the true path and %arg2 on the false path,
+ // so it must remain a select and must not fold to a single value.
cf.cond_br %cond, ^bb1(%arg0 : i32), ^bb2(%arg2, %arg2 : i32, i32)
@@ -686,3 +685,58 @@ func.func @no_merge_self_arg_loop(%step: i1) -> i1 {
^exit(%result: i1):
return %result : i1
}
+
+// CHECK-LABEL: func @fold_uniform_branch_block_arg
+// CHECK-SAME: %[[COND:.*]]: i1, %[[C:.*]]: i32
+func.func @fold_uniform_branch_block_arg(%cond: i1, %c: i32) -> i32 {
+ cf.cond_br %cond, ^bb1, ^bb2
+^bb1:
+ "foo.op"() : () -> ()
+ cf.br ^bb3(%c : i32)
+^bb2:
+ "foo.op"() : () -> ()
+ cf.br ^bb3(%c : i32)
+^bb3(%arg0: i32):
+ // CHECK: ^bb3:
+ // CHECK: return %[[C]]
+ return %arg0 : i32
+}
+
+// Verify that block arguments are not folded when incoming values differ
+// across predecessors.
+
+// CHECK-LABEL: func @no_fold_non_uniform_block_arg
+// CHECK-SAME: %[[COND:.*]]: i1, %[[A:.*]]: i32, %[[B:.*]]: i32
+func.func @no_fold_non_uniform_block_arg(%cond: i1, %a: i32, %b: i32) -> i32 {
+ cf.cond_br %cond, ^bb1, ^bb2
+^bb1:
+ "foo.op"() : () -> ()
+ cf.br ^bb3(%a : i32)
+^bb2:
+ "foo.op"() : () -> ()
+ cf.br ^bb3(%b : i32)
+^bb3(%arg0: i32):
+ // CHECK: ^bb3(%[[ARG0:.*]]: i32):
+ // CHECK-NEXT: return %[[ARG0]]
+ return %arg0 : i32
+}
+
+// Verify that a terminator reaching the block along two edges with different
+// values does not make the block argument look uniform: every incoming edge
+// has to be compared, not just the first edge of each predecessor.
+
+// CHECK-LABEL: func @no_fold_non_uniform_multi_edge_pred
+// CHECK-SAME: %[[COND:.*]]: i1, %[[COND2:.*]]: i1, %[[A:.*]]: i32, %[[B:.*]]: i32
+func.func @no_fold_non_uniform_multi_edge_pred(%cond: i1, %cond2: i1, %a: i32, %b: i32) -> i32 {
+ cf.cond_br %cond, ^bb1, ^bb2
+^bb1:
+ "foo.op"() : () -> ()
+ cf.br ^bb3(%a : i32)
+^bb2:
+ // CHECK: cf.cond_br %[[COND2]], ^bb3(%[[A]] : i32), ^bb3(%[[B]] : i32)
+ // CHECK: ^bb3(%[[ARG0:.*]]: i32):
+ // CHECK-NEXT: return %[[ARG0]]
+ cf.cond_br %cond2, ^bb3(%a : i32), ^bb3(%b : i32)
+^bb3(%arg0: i32):
+ return %arg0 : i32
+}
More information about the Mlir-commits
mailing list