[Mlir-commits] [mlir] [mlir][cf] Canonicalize block args with uniform incoming values (PR #183966)

Fedor Nikolaev llvmlistbot at llvm.org
Tue Mar 3 05:11:16 PST 2026


https://github.com/felichita updated https://github.com/llvm/llvm-project/pull/183966

>From 84d36287db03150e82097e03ad961309eb64083d Mon Sep 17 00:00:00 2001
From: Fedor Nikolaev <fridrixnm at gmail.com>
Date: Sat, 28 Feb 2026 23:40:46 +0100
Subject: [PATCH] [mlir][cf] Canonicalize block args with uniform incoming
 values

Add a canonicalization pattern that, when every predecessor of a block
passes the same SSA value for one of its arguments, replaces all uses of
that argument with the common value. The argument is then unused and can
be removed by dead code elimination.

Ref #182711
---
 .../Dialect/ControlFlow/IR/ControlFlowOps.cpp | 79 ++++++++++++++++-
 .../Dialect/ControlFlow/canonicalize.mlir     | 86 +++++++++++++++++++
 2 files changed, 162 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index 0ce0d55f4397c..6888340aeea47 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -86,7 +86,7 @@ LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
   return failure();
 }
 
-// This side effect models "program termination". 
+// This side effect models "program termination".
 void AssertOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
@@ -204,9 +204,81 @@ static LogicalResult simplifyPassThroughBr(BranchOp op,
   return success();
 }
 
+/// If every predecessor of `dest` forwards the same SSA value for a block
+/// argument, replace all uses of that argument with the value. The argument
+/// is then unused and can be removed by dead code elimination.
+///
+///   %c = arith.constant 0 : i32
+///   cf.br ^bb1(%c : i32)      // pred 1
+///   cf.br ^bb1(%c : i32)      // pred 2
+/// ^bb1(%arg0: i32):
+///   use(%arg0)                // becomes use(%c)
+///
+/// Blocks with fewer than two predecessors are skipped. An argument is kept if
+/// any incoming edge has an unknown terminator or a produced operand, or if
+/// the common value is an op result defined in `dest` (it would not dominate).
+static bool simplifyUniformBlockArgs(Block *dest, PatternRewriter &rewriter) {
+  if (dest->hasNoPredecessors() ||
+      llvm::hasSingleElement(dest->getPredecessors()))
+    return false;
+
+  // Value forwarded along the CFG edge `it` for argument `argIdx`, or null if
+  // the terminator is not a BranchOpInterface or produces the operand itself.
+  auto incoming = [](Block::pred_iterator it, unsigned argIdx) -> Value {
+    auto branch = dyn_cast<BranchOpInterface>((*it)->getTerminator());
+    if (!branch)
+      return Value();
+    return branch.getSuccessorOperands(it.getSuccessorIndex())[argIdx];
+  };
+
+  bool changed = false;
+  for (BlockArgument arg : dest->getArguments()) {
+    if (arg.use_empty())
+      continue;
+
+    Value commonValue;
+    bool allSame = true;
+    for (auto it = dest->pred_begin(), e = dest->pred_end(); it != e; ++it) {
+      Value val = incoming(it, arg.getArgNumber());
+      if (!val || (commonValue && commonValue != val)) {
+        allSame = false;
+        break;
+      }
+      commonValue = val;
+    }
+
+    // An op result defined in `dest` does not dominate uses of `arg` that
+    // precede its definition, so it cannot replace `arg` in general.
+    if (!allSame || !commonValue || commonValue == arg ||
+        (!isa<BlockArgument>(commonValue) &&
+         commonValue.getParentBlock() == dest))
+      continue;
+    rewriter.replaceAllUsesWith(arg, commonValue);
+    changed = true;
+  }
+  return changed;
+}
+
+namespace {
+/// Runs simplifyUniformBlockArgs on every successor of the matched branch-like
+/// op; succeeds if at least one block argument was folded.
+struct SimplifyUniformBlockArguments
+    : public OpInterfaceRewritePattern<BranchOpInterface> {
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+  LogicalResult matchAndRewrite(BranchOpInterface op,
+                                PatternRewriter &rewriter) const override {
+    bool anyFolded = false;
+    for (unsigned idx = 0, e = op->getNumSuccessors(); idx != e; ++idx)
+      anyFolded |= simplifyUniformBlockArgs(op->getSuccessor(idx), rewriter);
+    return success(anyFolded);
+  }
+};
+} // namespace
+
 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
-  return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
-                 succeeded(simplifyPassThroughBr(op, rewriter)));
+  bool folded = succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
+                succeeded(simplifyPassThroughBr(op, rewriter));
+  return success(folded || simplifyUniformBlockArgs(op.getDest(), rewriter));
 }
 
 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
@@ -492,7 +564,8 @@ void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
               SimplifyCondBranchIdenticalSuccessors,
               SimplifyCondBranchFromCondBranchOnSameCondition,
-              CondBranchTruthPropagation, DropUnreachableCondBranch>(context);
+              CondBranchTruthPropagation, DropUnreachableCondBranch,
+              SimplifyUniformBlockArguments>(context);
 }
 
 SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
diff --git a/mlir/test/Dialect/ControlFlow/canonicalize.mlir b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
index 8ddfeb7b0841c..555fa521af2b7 100644
--- a/mlir/test/Dialect/ControlFlow/canonicalize.mlir
+++ b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
@@ -686,3 +686,89 @@ func.func @no_merge_self_arg_loop(%step: i1) -> i1 {
 ^exit(%result: i1):
   return %result : i1
 }
+
+// Verify that a block argument is replaced by its incoming value when every
+// predecessor passes the same SSA value; the dead argument is then dropped.
+
+// CHECK-LABEL: func @fold_uniform_branch_block_arg
+// CHECK-SAME: %[[COND:.*]]: i1, %[[C:.*]]: i32
+func.func @fold_uniform_branch_block_arg(%cond: i1, %c: i32) -> i32 {
+  cf.cond_br %cond, ^bb1, ^bb2
+^bb1:
+  "foo.op"() : () -> ()
+  cf.br ^bb3(%c : i32)
+^bb2:
+  "foo.op"() : () -> ()
+  cf.br ^bb3(%c : i32)
+^bb3(%arg0: i32):
+  // CHECK: ^bb3:
+  // CHECK: return %[[C]]
+  return %arg0 : i32
+}
+
+// Verify that a block argument is kept when predecessors pass different
+// values for it (%a from ^bb1, %b from ^bb2).
+
+// CHECK-LABEL: func @no_fold_non_uniform_block_arg
+// CHECK-SAME: %[[COND:.*]]: i1, %[[A:.*]]: i32, %[[B:.*]]: i32
+func.func @no_fold_non_uniform_block_arg(%cond: i1, %a: i32, %b: i32) -> i32 {
+  cf.cond_br %cond, ^bb1, ^bb2
+^bb1:
+  "foo.op"() : () -> ()
+  cf.br ^bb3(%a : i32)
+^bb2:
+  "foo.op"() : () -> ()
+  cf.br ^bb3(%b : i32)
+^bb3(%arg0: i32):
+  // CHECK: ^bb3(%[[ARG0:.*]]: i32):
+  // CHECK-NEXT: return %[[ARG0]]
+  return %arg0 : i32
+}
+
+// Verify no folding when one terminator reaches the same block through two
+// edges (cf.switch default and case 0) carrying different operands.
+
+// CHECK-LABEL: func @no_fold_same_dest_different_args
+func.func @no_fold_same_dest_different_args(%flag: i32, %a: i32, %b: i32) -> i32 {
+  cf.switch %flag : i32, [
+    default: ^bb1(%a : i32),
+    0: ^bb1(%b : i32)
+  ]
+^bb1(%arg0: i32):
+  // CHECK: ^bb1(%[[ARG0:.*]]: i32):
+  // CHECK-NEXT: return %[[ARG0]]
+  return %arg0 : i32
+}
+
+// Verify no folding when a predecessor ends in an unknown (non-branch-like)
+// terminator. After canonicalization, ^bb3 is the block printed as ^bb2.
+// CHECK-LABEL: func @no_fold_unknown_terminator
+func.func @no_fold_unknown_terminator(%a: i32) -> i32 {
+  cf.br ^bb1
+^bb1:
+  "foo.two_successors"()[^bb2, ^bb3] : () -> ()
+^bb2:
+  // CHECK: ^bb2(%[[ARG0:.*]]: i32):
+  cf.br ^bb3(%a : i32)
+^bb3(%arg0: i32):
+  // CHECK-NEXT: return %[[ARG0]]
+  return %arg0 : i32
+}
+
+// Verify that only block arguments with a uniform incoming value are folded.
+
+// CHECK-LABEL: func @skip_unused_block_arg
+// CHECK-SAME: %{{.*}}: i32, %[[A:[^:]*]]: i32, %{{.*}}: i32
+func.func @skip_unused_block_arg(%flag: i32, %a: i32, %b: i32) -> i32 {
+  "foo.pred"()[^bb1, ^bb2] : () -> ()
+^bb1:
+  cf.br ^bb3(%b, %a : i32, i32)
+^bb2:
+  cf.br ^bb3(%a, %a : i32, i32)
+^bb3(%arg0: i32, %arg1: i32):
+  "foo.use"(%arg0) : (i32) -> ()
+  // CHECK: ^bb3(%[[ARG0:[^:]*]]: i32):
+  // CHECK-NEXT: "foo.use"(%[[ARG0]])
+  // CHECK-NEXT: return %[[A]]
+  return %arg1 : i32
+}



More information about the Mlir-commits mailing list