[Mlir-commits] [mlir] [mlir][cf] Canonicalize block args with uniform incoming values (PR #183966)

Fedor Nikolaev llvmlistbot at llvm.org
Sat Feb 28 15:05:26 PST 2026


https://github.com/felichita created https://github.com/llvm/llvm-project/pull/183966

Add a canonicalization pattern that replaces block arguments with a
common SSA value when all predecessors pass the same value for that
argument. This allows the block argument to be removed by dead code
elimination. This is a first iteration.

Idea from #182711

cc: @matthias-springer , @joker-eph 

>From 824c76c32289913c7bac1b1a0c0113b9c7587bde Mon Sep 17 00:00:00 2001
From: Fedor Nikolaev <fridrixnm at gmail.com>
Date: Sat, 28 Feb 2026 23:40:46 +0100
Subject: [PATCH] [mlir][cf] Canonicalize block args with uniform incoming
 values

Add a canonicalization pattern that replaces block arguments with a
common SSA value when all predecessors pass the same value for that
argument. This allows the block argument to be removed by dead code
elimination.

Ref #182711
---
 .../Dialect/ControlFlow/IR/ControlFlowOps.cpp | 83 ++++++++++++++++++-
 .../Dialect/ControlFlow/canonicalize.mlir     | 38 ++++++++-
 2 files changed, 116 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index 0ce0d55f4397c..cdc44122068b3 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -86,7 +86,7 @@ LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
   return failure();
 }
 
-// This side effect models "program termination". 
+// This side effect models "program termination".
 void AssertOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
@@ -204,9 +204,85 @@ static LogicalResult simplifyPassThroughBr(BranchOp op,
   return success();
 }
 
+/// If every incoming edge of `dest` forwards the same SSA value for a block
+/// argument, replace all uses of that argument with that value. The argument
+/// then has no uses and can be removed by dead code elimination. Example:
+///
+///   %c = arith.constant 0 : i32
+///   cf.br ^bb1(%c : i32)      // pred 1
+///   cf.br ^bb1(%c : i32)      // pred 2
+/// ^bb1(%arg0: i32):
+///   use(%arg0)
+/// ->
+/// ^bb1(%arg0: i32):
+///   use(%c)                   // %arg0 has no uses and can be removed
+///
+/// Returns true if at least one argument was rewritten.
+static bool simplifyUniformBlockArgs(Block *dest, PatternRewriter &rewriter) {
+  if (dest->hasNoPredecessors() ||
+      llvm::hasSingleElement(dest->getPredecessors()))
+    return false;
+
+  bool changed = false;
+  for (BlockArgument arg : dest->getArguments()) {
+    if (arg.use_empty())
+      continue;
+
+    unsigned argIdx = arg.getArgNumber();
+    Value commonValue;
+    bool allSame = true;
+    // Walk incoming edges, not predecessor blocks: one terminator may reach
+    // `dest` through several successor slots with different operands, e.g.
+    // `cf.cond_br %cond, ^bb1(%a : i32), ^bb1(%b : i32)`.
+    for (BlockOperand &edge : dest->getUses()) {
+      auto branch = dyn_cast<BranchOpInterface>(edge.getOwner());
+      if (!branch) {
+        allSame = false;
+        break;
+      }
+      SuccessorOperands succOps =
+          branch.getSuccessorOperands(edge.getOperandNumber());
+      // Out-of-range indices and operands produced by the terminator itself
+      // both yield a null value and block the fold.
+      Value incoming = argIdx < succOps.size() ? succOps[argIdx] : Value();
+      if (!incoming || (commonValue && commonValue != incoming)) {
+        allSame = false;
+        break;
+      }
+      commonValue = incoming;
+    }
+
+    // Never substitute a value defined in `dest` itself (including `arg`);
+    // that could make an op use its own result, e.g. in an unreachable loop.
+    if (allSame && commonValue && commonValue.getParentBlock() != dest) {
+      rewriter.replaceAllUsesWith(arg, commonValue);
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+namespace {
+/// For each successor of a cf.cond_br, replaces every block argument that all
+/// predecessors feed with the same SSA value by that value.
+struct SimplifyCondBranchBlockArgWithUniformIncomingValues
+    : public OpRewritePattern<CondBranchOp> {
+  using OpRewritePattern<CondBranchOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CondBranchOp op,
+                                PatternRewriter &rewriter) const override {
+    bool anyChange = false;
+    for (Block *succ : op->getSuccessors())
+      anyChange = simplifyUniformBlockArgs(succ, rewriter) || anyChange;
+    return success(anyChange);
+  }
+};
+} // namespace
+
 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
   return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
-                 succeeded(simplifyPassThroughBr(op, rewriter)));
+                 succeeded(simplifyPassThroughBr(op, rewriter)) ||
+                 simplifyUniformBlockArgs(op.getDest(), rewriter));
 }
 
 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
@@ -492,7 +568,8 @@ void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
               SimplifyCondBranchIdenticalSuccessors,
               SimplifyCondBranchFromCondBranchOnSameCondition,
-              CondBranchTruthPropagation, DropUnreachableCondBranch>(context);
+              CondBranchTruthPropagation, DropUnreachableCondBranch,
+              SimplifyCondBranchBlockArgWithUniformIncomingValues>(context);
 }
 
 SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
diff --git a/mlir/test/Dialect/ControlFlow/canonicalize.mlir b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
index 8ddfeb7b0841c..5bcd76badea59 100644
--- a/mlir/test/Dialect/ControlFlow/canonicalize.mlir
+++ b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
@@ -131,8 +131,7 @@ func.func @cond_br_passthrough_weights(%arg0 : i32, %arg1 : i32, %cond : i1) ->
 // CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[COND:.*]]: i1
 func.func @cond_br_passthrough(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) -> (i32, i32) {
   // CHECK: %[[RES:.*]] = arith.select %[[COND]], %[[ARG0]], %[[ARG2]]
-  // CHECK: %[[RES2:.*]] = arith.select %[[COND]], %[[ARG1]], %[[ARG2]]
-  // CHECK: return %[[RES]], %[[RES2]]
+  // CHECK: return %[[RES]], %[[ARG1]]
 
   cf.cond_br %cond, ^bb1(%arg0 : i32), ^bb2(%arg2, %arg2 : i32, i32)
 
@@ -686,3 +685,38 @@ func.func @no_merge_self_arg_loop(%step: i1) -> i1 {
 ^exit(%result: i1):
   return %result : i1
 }
+
+// Verify that uniform incoming values replace the block argument.
+// CHECK-LABEL: func @fold_uniform_branch_block_arg
+// CHECK-SAME: %[[COND:.*]]: i1, %[[C:.*]]: i32
+func.func @fold_uniform_branch_block_arg(%cond: i1, %c: i32) -> i32 {
+  cf.cond_br %cond, ^bb1, ^bb2
+^bb1:
+  "foo.op"() : () -> ()
+  cf.br ^bb3(%c : i32)
+^bb2:
+  "foo.op"() : () -> ()
+  cf.br ^bb3(%c : i32)
+^bb3(%arg0: i32):
+  // CHECK: ^bb3:
+  // CHECK: return %[[C]]
+  return %arg0 : i32
+}
+
+// Verify that block arguments are not folded when incoming values differ
+// across predecessors.
+// CHECK-LABEL: func @no_fold_non_uniform_block_arg
+// CHECK-SAME: %[[COND:.*]]: i1, %[[A:.*]]: i32, %[[B:.*]]: i32
+func.func @no_fold_non_uniform_block_arg(%cond: i1, %a: i32, %b: i32) -> i32 {
+  cf.cond_br %cond, ^bb1, ^bb2
+^bb1:
+  "foo.op"() : () -> ()
+  cf.br ^bb3(%a : i32)
+^bb2:
+  "foo.op"() : () -> ()
+  cf.br ^bb3(%b : i32)
+^bb3(%arg0: i32):
+  // CHECK: ^bb3(%[[ARG0:.*]]: i32):
+  // CHECK-NEXT: return %[[ARG0]]
+  return %arg0 : i32
+}



More information about the Mlir-commits mailing list