[Mlir-commits] [mlir] [mlir][RemoveDeadValues] Simplify branch op handling using ub.poison (PR #182711)

Fedor Nikolaev llvmlistbot at llvm.org
Thu May 14 06:43:09 PDT 2026


https://github.com/felichita updated https://github.com/llvm/llvm-project/pull/182711

>From 5635748145a61ba5b734845dbd169e45c250b288 Mon Sep 17 00:00:00 2001
From: Fedor Nikolaev <fridrixnm at gmail.com>
Date: Sun, 22 Feb 2026 00:01:54 +0100
Subject: [PATCH] [mlir][RemoveDeadValues] Simplify branch op handling using
 ub.poison

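Replace direct operand erasure in processBranchOp with ub.poison
replacement, unifying the approach with region branch op handling.
Dead successor operands are replaced with ub.poison and block
arguments are left intact; a subsequent canonicalization run is
relied upon to remove them once all incoming operands are poison.

When the `canonicalize` option is enabled, the set of ops canonicalized
at the end of the pass is extended from region branch ops only to also
include branch ops, return-like ops, and the terminators of successor
blocks reachable from them.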

Fixes #182263
---
 mlir/lib/Transforms/RemoveDeadValues.cpp     | 157 +++++++------------
 mlir/test/Transforms/remove-dead-values.mlir |  59 ++++++-
 2 files changed, 110 insertions(+), 106 deletions(-)

diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index f0a210a2ededb..3881c2f189724 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -19,7 +19,12 @@
 // (C) Removes unneccesary operands, results, region arguments, and region
 // terminator operands of region branch ops, and,
 // (D) Removes simple and region branch ops that have all non-live results and
-// don't affect memory in any way.
+// don't affect memory in any way,
+// (E) Replaces dead operands of branch ops with `ub.poison`. When the
+//     `canonicalize` option is enabled, the pass also canonicalizes branch,
+//     region branch, and return-like ops, plus the terminators of successor
+//     blocks reachable from them. Dead block arguments are not erased here;
+//     a follow-up `--canonicalize` run is needed for full cleanup.
 //
 // Here, a "simple op" refers to an op that isn't a symbol op, symbol-user op,
 // region branch op, branch op, region branch terminator op, or return-like.
@@ -96,24 +101,11 @@ struct OperandsToCleanup {
   bool replaceWithPoison = false;
 };
 
-struct BlockArgsToCleanup {
-  Block *b;
-  BitVector nonLiveArgs;
-};
-
-struct SuccessorOperandsToCleanup {
-  BranchOpInterface branch;
-  unsigned successorIndex;
-  BitVector nonLiveOperands;
-};
-
 struct RDVFinalCleanupList {
   SmallVector<Operation *> operations;
   SmallVector<FunctionToCleanUp> functions;
   SmallVector<OperandsToCleanup> operands;
   SmallVector<ResultsToCleanup> results;
-  SmallVector<BlockArgsToCleanup> blocks;
-  SmallVector<SuccessorOperandsToCleanup> successorOperands;
 };
 
 // Some helper functions...
@@ -470,11 +462,10 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
 ///
 /// Otherwise, iterate through each successor block of `branchOp`.
 /// (1) For each successor block, gather all operands from all successors.
-/// (2) Fetch their associated liveness analysis data and collect for future
-///     removal.
-/// (3) Identify and collect the dead operands from the successor block
-///     as well as their corresponding arguments.
-
+/// (2) Determine which operands are dead using liveness analysis.
+/// (3) Replace dead successor operands with ub.poison instead of erasing them.
+///     Block arguments are left intact — the canonicalizer will remove them
+///     once it sees all incoming operands are poison.
 static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
                             DenseSet<Value> &nonLiveSet,
                             RDVFinalCleanupList &cl) {
@@ -484,7 +475,8 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
   BitVector deadNonForwardedOperands =
       markLives(branchOp->getOperands(), nonLiveSet, la).flip();
   unsigned numSuccessors = branchOp->getNumSuccessors();
-  for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
+
+  for (unsigned succIdx : llvm::seq<unsigned>(0, numSuccessors)) {
     SuccessorOperands successorOperands =
         branchOp.getSuccessorOperands(succIdx);
     // Remove all non-forwarded operands from the bit vector.
@@ -496,28 +488,20 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
     return;
   }
 
-  for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
-    Block *successorBlock = branchOp->getSuccessor(succIdx);
-
-    // Do (1)
-    SuccessorOperands successorOperands =
-        branchOp.getSuccessorOperands(succIdx);
-    SmallVector<Value> operandValues;
-    for (unsigned operandIdx = 0; operandIdx < successorOperands.size();
-         ++operandIdx) {
-      operandValues.push_back(successorOperands[operandIdx]);
+  // For each successor, find dead forwarded operands and
+  // schedule them for replacement with ub.poison.
+  BitVector opNonLive(branchOp->getNumOperands(), false);
+  for (unsigned succIdx : llvm::seq<unsigned>(0, numSuccessors)) {
+    for (OpOperand &opOperand :
+         branchOp.getSuccessorOperands(succIdx).getMutableForwardedOperands()) {
+      if (!hasLive(opOperand.get(), nonLiveSet, la))
+        opNonLive.set(opOperand.getOperandNumber());
     }
-
-    // Do (2)
-    BitVector successorNonLive =
-        markLives(operandValues, nonLiveSet, la).flip();
-    collectNonLiveValues(nonLiveSet, successorBlock->getArguments(),
-                         successorNonLive);
-
-    // Do (3)
-    cl.blocks.push_back({successorBlock, successorNonLive});
-    cl.successorOperands.push_back({branchOp, succIdx, successorNonLive});
   }
+
+  if (opNonLive.any())
+    cl.operands.push_back({branchOp.getOperation(), opNonLive,
+                           /*callee=*/nullptr, /*replaceWithPoison=*/true});
 }
 
 /// Create ub.poison ops for the given values. If a value has no uses, return
@@ -558,56 +542,7 @@ static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) {
   TrackingListener listener;
   IRRewriter rewriter(ctx, &listener);
 
-  // 1. Blocks, We must remove the block arguments and successor operands before
-  // deleting the operation, as they may reside in the region operation.
-  LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists";
-  for (auto &b : list.blocks) {
-    // blocks that are accessed via multiple codepaths processed once
-    if (b.b->getNumArguments() != b.nonLiveArgs.size())
-      continue;
-    LDBG_OS([&](raw_ostream &os) {
-      os << "Erasing non-live arguments [";
-      llvm::interleaveComma(b.nonLiveArgs.set_bits(), os);
-      os << "] from block #" << b.b->computeBlockNumber() << " in region #"
-         << b.b->getParent()->getRegionNumber() << " of operation "
-         << OpWithFlags(b.b->getParent()->getParentOp(),
-                        OpPrintingFlags().skipRegions().printGenericOpForm());
-    });
-    // Note: Iterate from the end to make sure that that indices of not yet
-    // processes arguments do not change.
-    for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
-      if (!b.nonLiveArgs[i])
-        continue;
-      b.b->getArgument(i).dropAllUses();
-      b.b->eraseArgument(i);
-    }
-  }
-
-  // 2. Successor Operands
-  LDBG() << "Cleaning up " << list.successorOperands.size()
-         << " successor operand lists";
-  for (auto &op : list.successorOperands) {
-    SuccessorOperands successorOperands =
-        op.branch.getSuccessorOperands(op.successorIndex);
-    // blocks that are accessed via multiple codepaths processed once
-    if (successorOperands.size() != op.nonLiveOperands.size())
-      continue;
-    LDBG_OS([&](raw_ostream &os) {
-      os << "Erasing non-live successor operands [";
-      llvm::interleaveComma(op.nonLiveOperands.set_bits(), os);
-      os << "] from successor " << op.successorIndex << " of branch: "
-         << OpWithFlags(op.branch.getOperation(),
-                        OpPrintingFlags().skipRegions().printGenericOpForm());
-    });
-    // it iterates backwards because erase invalidates all successor indexes
-    for (int i = successorOperands.size() - 1; i >= 0; --i) {
-      if (!op.nonLiveOperands[i])
-        continue;
-      successorOperands.erase(i);
-    }
-  }
-
-  // 3. Functions
+  // 1. Functions
   LDBG() << "Cleaning up " << list.functions.size() << " functions";
   // Record which function arguments were erased so we can shrink call-site
   // argument segments for CallOpInterface operations (e.g. ops using
@@ -641,7 +576,7 @@ static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) {
     (void)f.funcOp.eraseResults(f.nonLiveRets);
   }
 
-  // 4. Operands
+  // 2. Operands
   LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
   for (OperandsToCleanup &o : list.operands) {
     // Handle call-specific cleanup only when we have a cached callee reference.
@@ -699,7 +634,7 @@ static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) {
     }
   }
 
-  // 5. Results
+  // 3. Results
   LDBG() << "Cleaning up " << list.results.size() << " result lists";
   for (auto &r : list.results) {
     LDBG_OS([&](raw_ostream &os) {
@@ -712,7 +647,7 @@ static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) {
     dropUsesAndEraseResults(rewriter, r.op, r.nonLive);
   }
 
-  // 6. Operations
+  // 4. Operations
   LDBG() << "Cleaning up " << list.operations.size() << " operations";
   for (Operation *op : list.operations) {
     LDBG() << "Erasing operation: "
@@ -749,7 +684,7 @@ static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) {
     rewriter.eraseOp(op);
   }
 
-  // 7. Remove all dead poison ops.
+  // 5. Remove all dead poison ops.
   for (ub::PoisonOp poisonOp : listener.poisonOps) {
     if (poisonOp.use_empty())
       poisonOp.erase();
@@ -802,18 +737,42 @@ void RemoveDeadValues::runOnOperation() {
   if (!canonicalize)
     return;
 
-  // Canonicalize all region branch ops.
-  SmallVector<Operation *> opsToCanonicalize;
-  module->walk([&](RegionBranchOpInterface regionBranchOp) {
-    opsToCanonicalize.push_back(regionBranchOp.getOperation());
+  // Collect ops to canonicalize via BFS over successor blocks.
+  // Seed: all branch ops, region branch ops, and return-like ops.
+  // Then transitively follow successor block terminators to cover
+  // reachable blocks. Note: block arguments are not removed here;
+  // a follow-up --canonicalize pass is needed for full cleanup.
+  SmallVector<Operation *> worklist;
+  DenseSet<Operation *> visited;
+
+  module->walk([&](Operation *op) {
+    if (!isa<RegionBranchOpInterface, BranchOpInterface>(op) &&
+        !op->hasTrait<OpTrait::ReturnLike>())
+      return;
+    if (visited.insert(op).second)
+      worklist.push_back(op);
   });
-  // Collect all canonicalization patterns for region branch ops.
+
+  // BFS: follow successor block terminators transitively.
+  SmallVector<Operation *> opsToCanonicalize;
+  while (!worklist.empty()) {
+    Operation *op = worklist.pop_back_val();
+    opsToCanonicalize.push_back(op);
+    for (Block *succ : op->getSuccessors()) {
+      Operation *term = succ->getTerminator();
+      if (term && visited.insert(term).second)
+        worklist.push_back(term);
+    }
+  }
+
+  // Collect canonicalization patterns only for ops in the list.
   RewritePatternSet owningPatterns(context);
   DenseSet<RegisteredOperationName> populatedPatterns;
   for (Operation *op : opsToCanonicalize)
     if (std::optional<RegisteredOperationName> info = op->getRegisteredInfo())
       if (populatedPatterns.insert(*info).second)
         info->getCanonicalizationPatterns(owningPatterns, context);
+
   if (failed(applyOpPatternsGreedily(opsToCanonicalize,
                                      std::move(owningPatterns)))) {
     module->emitError("greedy pattern rewrite failed to converge");
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 64088ce15cd48..197262a631ff5 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -57,19 +57,18 @@ func.func @acceptable_ir_has_cleanable_loop_of_conditional_and_branch_op(%arg0:
   %non_live = arith.constant 0 : i32
   // CHECK-NOT: arith.constant
   cf.br ^bb1(%non_live : i32)
-  // CHECK: cf.br ^[[BB1:bb[0-9]+]]
+  // CHECK: cf.br ^[[BB1:bb[0-9]+]](%{{.*}} : i32)
 ^bb1(%non_live_1 : i32):
-  // CHECK: ^[[BB1]]:
+  // CHECK: ^[[BB1]](%{{.*}}: i32):
   %non_live_5 = arith.constant 1 : i32
   cf.br ^bb3(%non_live_1, %non_live_5 : i32, i32)
-  // CHECK: cf.br ^[[BB3:bb[0-9]+]]
-  // CHECK-NOT: i32
+  // CHECK: cf.br ^[[BB3:bb[0-9]+]](%{{.*}}, %{{.*}} : i32, i32)
 ^bb3(%non_live_2 : i32, %non_live_6 : i32):
-  // CHECK: ^[[BB3]]:
+  // CHECK: ^[[BB3]](%{{.*}}: i32, %{{.*}}: i32):
   cf.cond_br %arg0, ^bb1(%non_live_2 : i32), ^bb4(%non_live_2 : i32)
-  // CHECK: cf.cond_br %arg0, ^[[BB1]], ^[[BB4:bb[0-9]+]]
+  // CHECK: cf.cond_br %arg0, ^[[BB1]](%{{.*}} : i32), ^[[BB4:bb[0-9]+]](%{{.*}} : i32)
 ^bb4(%non_live_4 : i32):
-  // CHECK: ^[[BB4]]:
+  // CHECK: ^[[BB4]](%{{.*}}: i32):
   return
 }
 
@@ -497,6 +496,52 @@ func.func @kernel(%arg0: memref<18xf32>) {
 
 // -----
 
+// Test that RemoveDeadValues does not crash when gpu.launch appears in a block
+// with multiple predecessors. The dead branch operand (%c20) must be replaced
+// with ub.poison; gpu.launch, its grid/block size operands, and the live block
+// argument must all be preserved.
+//
+// CHECK-LABEL: func.func @gpu_launch_in_multi_predecessor_block
+// CHECK: arith.constant true
+// CHECK: cf.cond_br
+// CHECK: arith.constant 10
+// CHECK: cf.br ^[[BB3:bb[0-9]+]](%{{.*}} : i64)
+// CHECK: ub.poison
+// CHECK: cf.br ^[[BB3]](%{{.*}} : i64)
+// CHECK: ^[[BB3]](%{{.*}}: i64):
+// CHECK: return
+// CHECK-NOT: arith.constant 20
+//
+// CHECK-CANONICALIZE-LABEL: func.func @gpu_launch_in_multi_predecessor_block
+// CHECK-CANONICALIZE:         %[[c10:.*]] = arith.constant 10 : i64
+// CHECK-CANONICALIZE:         cf.br ^[[BB2:bb[0-9]+]](%[[c10]] : i64)
+// CHECK-CANONICALIZE:       ^[[BB1:bb[0-9]+]]: // no predecessors
+// CHECK-CANONICALIZE:         %[[p:.*]] = ub.poison : i64
+// CHECK-CANONICALIZE:         cf.br ^[[BB2]](%[[p]] : i64)
+// CHECK-CANONICALIZE:       ^[[BB2]](%{{.*}}: i64):
+// CHECK-CANONICALIZE:         return %{{.*}} : i64
+func.func @gpu_launch_in_multi_predecessor_block() -> i64 {
+  %cond = arith.constant true
+  cf.cond_br %cond, ^bb1, ^bb2
+^bb1:
+  %c10 = arith.constant 10 : i64
+  cf.br ^bb3(%c10 : i64)
+^bb2:
+  %c20 = arith.constant 20 : i64
+  cf.br ^bb3(%c20 : i64)
+^bb3(%arg0: i64):
+  %c1 = arith.constant 1 : index
+  gpu.launch
+    blocks(%bx, %by, %bz) in (%gx = %c1, %gy = %c1, %gz = %c1)
+    threads(%tx, %ty, %tz) in (%bsx = %c1, %bsy = %c1, %bsz = %c1) {
+    %blk_x = gpu.block_id x
+    %thr_x = gpu.thread_id x
+    gpu.terminator
+  }
+  func.return %arg0 : i64
+}
+
+// -----
 
 // CHECK-LABEL: llvm_unreachable
 // CHECK-LABEL: @fn_with_llvm_unreachable



More information about the Mlir-commits mailing list