[Mlir-commits] [mlir] [mlir][RemoveDeadValues] Simplify branch op handling using ub.poison (PR #182711)
Fedor Nikolaev
llvmlistbot at llvm.org
Fri Mar 6 03:11:43 PST 2026
https://github.com/felichita updated https://github.com/llvm/llvm-project/pull/182711
>From f7b395cc20efef8428205df4a35504f6ea90b886 Mon Sep 17 00:00:00 2001
From: Fedor Nikolaev <fridrixnm at gmail.com>
Date: Sun, 22 Feb 2026 00:01:54 +0100
Subject: [PATCH] [mlir][RemoveDeadValues] Simplify branch op handling using
ub.poison
Replace the complex block argument removal logic in processBranchOp with
a simpler design based on ub.poison replacement, similar to the approach
used for region branch ops.
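Dead successor operands are now replaced with ub.poison and block
arguments are left intact, relying on the canonicalizer to remove them
once all incoming operands are poison. Accordingly, the optional
canonicalization step now also collects canonicalization patterns for
BranchOpInterface ops and applies them greedily over the whole module
(applyPatternsGreedily) instead of only to the collected region branch
ops (applyOpPatternsGreedily).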
Fixes #182263
---
mlir/lib/Transforms/RemoveDeadValues.cpp | 133 +++++--------------
mlir/test/Dialect/SPIRV/IR/return-ops.mlir | 6 +-
mlir/test/Transforms/remove-dead-values.mlir | 80 ++++++++---
3 files changed, 97 insertions(+), 122 deletions(-)
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 957621c16bf2b..f30678f8cd664 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -20,6 +20,8 @@
// terminator operands of region branch ops, and,
// (D) Removes simple and region branch ops that have all non-live results and
// don't affect memory in any way,
+// (E) Replaces dead operands of branch ops with `ub.poison`, relying on the
+// canonicalizer to remove the corresponding block arguments.
//
// iff
//
@@ -101,24 +103,11 @@ struct OperandsToCleanup {
bool replaceWithPoison = false;
};
-struct BlockArgsToCleanup {
- Block *b;
- BitVector nonLiveArgs;
-};
-
-struct SuccessorOperandsToCleanup {
- BranchOpInterface branch;
- unsigned successorIndex;
- BitVector nonLiveOperands;
-};
-
struct RDVFinalCleanupList {
SmallVector<Operation *> operations;
SmallVector<FunctionToCleanUp> functions;
SmallVector<OperandsToCleanup> operands;
SmallVector<ResultsToCleanup> results;
- SmallVector<BlockArgsToCleanup> blocks;
- SmallVector<SuccessorOperandsToCleanup> successorOperands;
};
// Some helper functions...
@@ -476,11 +465,10 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
///
/// Otherwise, iterate through each successor block of `branchOp`.
/// (1) For each successor block, gather all operands from all successors.
-/// (2) Fetch their associated liveness analysis data and collect for future
-/// removal.
-/// (3) Identify and collect the dead operands from the successor block
-/// as well as their corresponding arguments.
-
+/// (2) Determine which operands are dead using liveness analysis.
+/// (3) Replace dead successor operands with ub.poison instead of erasing them.
+/// Block arguments are left intact — the canonicalizer will remove them
+/// once it sees all incoming operands are poison.
static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
RDVFinalCleanupList &cl) {
@@ -490,7 +478,8 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
BitVector deadNonForwardedOperands =
markLives(branchOp->getOperands(), nonLiveSet, la).flip();
unsigned numSuccessors = branchOp->getNumSuccessors();
- for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
+
+ for (unsigned succIdx : llvm::seq<unsigned>(0, numSuccessors)) {
SuccessorOperands successorOperands =
branchOp.getSuccessorOperands(succIdx);
// Remove all non-forwarded operands from the bit vector.
@@ -502,28 +491,20 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
return;
}
- for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
- Block *successorBlock = branchOp->getSuccessor(succIdx);
-
- // Do (1)
- SuccessorOperands successorOperands =
- branchOp.getSuccessorOperands(succIdx);
- SmallVector<Value> operandValues;
- for (unsigned operandIdx = 0; operandIdx < successorOperands.size();
- ++operandIdx) {
- operandValues.push_back(successorOperands[operandIdx]);
+ // For each successor, find dead forwarded operands and
+ // schedule them for replacement with ub.poison.
+ BitVector opNonLive(branchOp->getNumOperands(), false);
+ for (unsigned succIdx : llvm::seq<unsigned>(0, numSuccessors)) {
+ for (OpOperand &opOperand :
+ branchOp.getSuccessorOperands(succIdx).getMutableForwardedOperands()) {
+ if (!hasLive(opOperand.get(), nonLiveSet, la))
+ opNonLive.set(opOperand.getOperandNumber());
}
-
- // Do (2)
- BitVector successorNonLive =
- markLives(operandValues, nonLiveSet, la).flip();
- collectNonLiveValues(nonLiveSet, successorBlock->getArguments(),
- successorNonLive);
-
- // Do (3)
- cl.blocks.push_back({successorBlock, successorNonLive});
- cl.successorOperands.push_back({branchOp, succIdx, successorNonLive});
}
+
+ if (opNonLive.any())
+ cl.operands.push_back({branchOp.getOperation(), opNonLive,
+ /*callee=*/nullptr, /*replaceWithPoison=*/true});
}
/// Create ub.poison ops for the given values. If a value has no uses, return
@@ -564,56 +545,7 @@ static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) {
TrackingListener listener;
IRRewriter rewriter(ctx, &listener);
- // 1. Blocks, We must remove the block arguments and successor operands before
- // deleting the operation, as they may reside in the region operation.
- LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists";
- for (auto &b : list.blocks) {
- // blocks that are accessed via multiple codepaths processed once
- if (b.b->getNumArguments() != b.nonLiveArgs.size())
- continue;
- LDBG_OS([&](raw_ostream &os) {
- os << "Erasing non-live arguments [";
- llvm::interleaveComma(b.nonLiveArgs.set_bits(), os);
- os << "] from block #" << b.b->computeBlockNumber() << " in region #"
- << b.b->getParent()->getRegionNumber() << " of operation "
- << OpWithFlags(b.b->getParent()->getParentOp(),
- OpPrintingFlags().skipRegions().printGenericOpForm());
- });
- // Note: Iterate from the end to make sure that that indices of not yet
- // processes arguments do not change.
- for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
- if (!b.nonLiveArgs[i])
- continue;
- b.b->getArgument(i).dropAllUses();
- b.b->eraseArgument(i);
- }
- }
-
- // 2. Successor Operands
- LDBG() << "Cleaning up " << list.successorOperands.size()
- << " successor operand lists";
- for (auto &op : list.successorOperands) {
- SuccessorOperands successorOperands =
- op.branch.getSuccessorOperands(op.successorIndex);
- // blocks that are accessed via multiple codepaths processed once
- if (successorOperands.size() != op.nonLiveOperands.size())
- continue;
- LDBG_OS([&](raw_ostream &os) {
- os << "Erasing non-live successor operands [";
- llvm::interleaveComma(op.nonLiveOperands.set_bits(), os);
- os << "] from successor " << op.successorIndex << " of branch: "
- << OpWithFlags(op.branch.getOperation(),
- OpPrintingFlags().skipRegions().printGenericOpForm());
- });
- // it iterates backwards because erase invalidates all successor indexes
- for (int i = successorOperands.size() - 1; i >= 0; --i) {
- if (!op.nonLiveOperands[i])
- continue;
- successorOperands.erase(i);
- }
- }
-
- // 3. Functions
+ // 1. Functions
LDBG() << "Cleaning up " << list.functions.size() << " functions";
// Record which function arguments were erased so we can shrink call-site
// argument segments for CallOpInterface operations (e.g. ops using
@@ -647,7 +579,7 @@ static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) {
(void)f.funcOp.eraseResults(f.nonLiveRets);
}
- // 4. Operands
+ // 2. Operands
LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
for (OperandsToCleanup &o : list.operands) {
// Handle call-specific cleanup only when we have a cached callee reference.
@@ -705,7 +637,7 @@ static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) {
}
}
- // 5. Results
+ // 3. Results
LDBG() << "Cleaning up " << list.results.size() << " result lists";
for (auto &r : list.results) {
LDBG_OS([&](raw_ostream &os) {
@@ -718,7 +650,7 @@ static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) {
dropUsesAndEraseResults(rewriter, r.op, r.nonLive);
}
- // 6. Operations
+ // 4. Operations
LDBG() << "Cleaning up " << list.operations.size() << " operations";
for (Operation *op : list.operations) {
LDBG() << "Erasing operation: "
@@ -755,7 +687,7 @@ static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) {
rewriter.eraseOp(op);
}
- // 7. Remove all dead poison ops.
+ // 5. Remove all dead poison ops.
for (ub::PoisonOp poisonOp : listener.poisonOps) {
if (poisonOp.use_empty())
poisonOp.erase();
@@ -808,20 +740,17 @@ void RemoveDeadValues::runOnOperation() {
if (!canonicalize)
return;
- // Canonicalize all region branch ops.
- SmallVector<Operation *> opsToCanonicalize;
- module->walk([&](RegionBranchOpInterface regionBranchOp) {
- opsToCanonicalize.push_back(regionBranchOp.getOperation());
- });
- // Collect all canonicalization patterns for region branch ops.
+ // Canonicalize all region branch ops and branch ops.
RewritePatternSet owningPatterns(context);
DenseSet<RegisteredOperationName> populatedPatterns;
- for (Operation *op : opsToCanonicalize)
+ module->walk([&](Operation *op) {
+ if (!isa<RegionBranchOpInterface, BranchOpInterface>(op))
+ return;
if (std::optional<RegisteredOperationName> info = op->getRegisteredInfo())
if (populatedPatterns.insert(*info).second)
info->getCanonicalizationPatterns(owningPatterns, context);
- if (failed(applyOpPatternsGreedily(opsToCanonicalize,
- std::move(owningPatterns)))) {
+ });
+ if (failed(applyPatternsGreedily(module, std::move(owningPatterns)))) {
module->emitError("greedy pattern rewrite failed to converge");
signalPassFailure();
}
diff --git a/mlir/test/Dialect/SPIRV/IR/return-ops.mlir b/mlir/test/Dialect/SPIRV/IR/return-ops.mlir
index 2f945b24d24fd..b12cba5f7a074 100644
--- a/mlir/test/Dialect/SPIRV/IR/return-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/return-ops.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s --remove-dead-values | FileCheck %s
+// RUN: mlir-opt %s -remove-dead-values="canonicalize=0" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -remove-dead-values="canonicalize=1" -split-input-file | FileCheck %s --check-prefix=CHECK-CANONICALIZE
// Make sure that the return value op is considered as a return-like op and
// remains live.
@@ -8,6 +9,9 @@
// CHECK-NEXT: %[[BITCAST0:.*]] = spirv.Bitcast %[[ARG1]] : vector<2xi32> to vector<2xf32>
// CHECK-NEXT: %[[BITCAST1:.*]] = spirv.Bitcast %[[BITCAST0]] : vector<2xf32> to vector<2xi32>
// CHECK-NEXT: spirv.ReturnValue %[[BITCAST1]] : vector<2xi32>
+// CHECK-CANONICALIZE-LABEL: @preserve_return_value
+// CHECK-CANONICALIZE-SAME: (%[[ARG0:.*]]: vector<2xi32>, %[[ARG1:.*]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-CANONICALIZE-NEXT: spirv.ReturnValue %[[ARG1]] : vector<2xi32>
spirv.func @preserve_return_value(%arg0: vector<2xi32>, %arg1: vector<2xi32>) -> vector<2xi32> "None" {
%0 = spirv.Bitcast %arg0 : vector<2xi32> to vector<2xf32>
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 19bc6b2fddd66..22e4d66ef0ea5 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -35,20 +35,24 @@ module @named_module_acceptable {
func.func @acceptable_ir_has_cleanable_loop_of_conditional_and_branch_op(%arg0: i1) {
%non_live = arith.constant 0 : i32
// CHECK-NOT: arith.constant
+ // CHECK-CANONICALIZE-NOT: arith.constant
cf.br ^bb1(%non_live : i32)
- // CHECK: cf.br ^[[BB1:bb[0-9]+]]
+ // CHECK: cf.br ^[[BB1:bb[0-9]+]](%{{.*}} : i32)
+ // CHECK-CANONICALIZE: cf.br ^[[BB1:bb[0-9]+]]
^bb1(%non_live_1 : i32):
- // CHECK: ^[[BB1]]:
+ // CHECK: ^[[BB1]](%{{.*}}: i32):
+ // CHECK-CANONICALIZE: ^[[BB1]]:
%non_live_5 = arith.constant 1 : i32
cf.br ^bb3(%non_live_1, %non_live_5 : i32, i32)
- // CHECK: cf.br ^[[BB3:bb[0-9]+]]
- // CHECK-NOT: i32
+ // CHECK: cf.br ^[[BB3:bb[0-9]+]](%{{.*}}, %{{.*}} : i32, i32)
+ // CHECK-CANONICALIZE: cf.cond_br %arg0, ^[[BB1]], ^[[BB2:bb[0-9]+]]
^bb3(%non_live_2 : i32, %non_live_6 : i32):
- // CHECK: ^[[BB3]]:
+ // CHECK: ^[[BB3]](%{{.*}}: i32, %{{.*}}: i32):
cf.cond_br %arg0, ^bb1(%non_live_2 : i32), ^bb4(%non_live_2 : i32)
- // CHECK: cf.cond_br %arg0, ^[[BB1]], ^[[BB4:bb[0-9]+]]
+ // CHECK: cf.cond_br %arg0, ^[[BB1]](%{{.*}} : i32), ^[[BB4:bb[0-9]+]](%{{.*}} : i32)
^bb4(%non_live_4 : i32):
- // CHECK: ^[[BB4]]:
+ // CHECK: ^[[BB4]](%{{.*}}: i32):
+ // CHECK-CANONICALIZE: ^[[BB2]]:
return
}
@@ -345,9 +349,9 @@ func.func private @identity(%arg1 : i32) -> (i32) {
// Note that this cleanup cannot be done by the `canonicalize` pass.
//
// CHECK-CANONICALIZE: func.func @clean_region_branch_op_remove_result(%[[arg0:.*]]: index, %[[arg1:.*]]: memref<i32>) {
+// CHECK-CANONICALIZE: %[[c10:.*]] = arith.constant 10
// CHECK-CANONICALIZE-NEXT: scf.index_switch %[[arg0]]
// CHECK-CANONICALIZE-NEXT: case 1 {
-// CHECK-CANONICALIZE-NEXT: %[[c10:.*]] = arith.constant 10
// CHECK-CANONICALIZE-NEXT: memref.store %[[c10]], %[[arg1]][]
// CHECK-CANONICALIZE: scf.yield
// CHECK-CANONICALIZE-NEXT: }
@@ -476,6 +480,47 @@ func.func @kernel(%arg0: memref<18xf32>) {
// -----
+// Test that RemoveDeadValues does not crash when gpu.launch appears in a block
+// with multiple predecessors. The dead branch operand (%c20) must be replaced
+// with ub.poison, gpu.launch and its grid/block size operands must be
+// preserved, and the live block argument must remain intact.
+//
+// CHECK-LABEL: func.func @gpu_launch_in_multi_predecessor_block
+// CHECK: arith.constant true
+// CHECK: cf.cond_br
+// CHECK: arith.constant 10
+// CHECK: cf.br ^[[BB3:bb[0-9]+]](%{{.*}} : i64)
+// CHECK: ub.poison
+// CHECK: cf.br ^[[BB3]](%{{.*}} : i64)
+// CHECK: ^[[BB3]](%{{.*}}: i64):
+// CHECK: return
+// CHECK-NOT: arith.constant 20
+//
+// CHECK-CANONICALIZE-LABEL: func.func @gpu_launch_in_multi_predecessor_block
+// CHECK-CANONICALIZE-NEXT: %[[c10:.*]] = arith.constant 10 : i64
+// CHECK-CANONICALIZE-NEXT: return %[[c10]]
+func.func @gpu_launch_in_multi_predecessor_block() -> i64 {
+ %cond = arith.constant true
+ cf.cond_br %cond, ^bb1, ^bb2
+^bb1:
+ %c10 = arith.constant 10 : i64
+ cf.br ^bb3(%c10 : i64)
+^bb2:
+ %c20 = arith.constant 20 : i64
+ cf.br ^bb3(%c20 : i64)
+^bb3(%arg0: i64):
+ %c1 = arith.constant 1 : index
+ gpu.launch
+ blocks(%bx, %by, %bz) in (%gx = %c1, %gy = %c1, %gz = %c1)
+ threads(%tx, %ty, %tz) in (%bsx = %c1, %bsy = %c1, %bsz = %c1) {
+ %blk_x = gpu.block_id x
+ %thr_x = gpu.thread_id x
+ gpu.terminator
+ }
+ func.return %arg0 : i64
+}
+
+// -----
// CHECK-LABEL: llvm_unreachable
// CHECK-LABEL: @fn_with_llvm_unreachable
@@ -768,6 +813,7 @@ func.func @affine_loop_no_use_iv_has_side_effect_op() {
// CHECK: return %[[while]]#0
// CHECK-CANONICALIZE-LABEL: func @scf_while_dead_iter_args()
+// CHECK-CANONICALIZE: %[[p0:.*]] = ub.poison : i32
// CHECK-CANONICALIZE: %[[c5:.*]] = arith.constant 5 : i32
// CHECK-CANONICALIZE: %[[while:.*]] = scf.while (%[[arg0:.*]] = %[[c5]]) : (i32) -> i32 {
// CHECK-CANONICALIZE: vector.print %[[arg0]]
@@ -775,7 +821,6 @@ func.func @affine_loop_no_use_iv_has_side_effect_op() {
// CHECK-CANONICALIZE: scf.condition(%[[cmpi]]) %[[arg0]]
// CHECK-CANONICALIZE: } do {
// CHECK-CANONICALIZE: ^bb0(%[[arg1:.*]]: i32):
-// CHECK-CANONICALIZE: %[[p0:.*]] = ub.poison : i32
// CHECK-CANONICALIZE: scf.yield %[[p0]]
// CHECK-CANONICALIZE: }
// CHECK-CANONICALIZE: return %[[while]]
@@ -799,7 +844,13 @@ func.func @scf_while_dead_iter_args() -> i32 {
// -----
-// CHECK-LABEL: func.func @replace_dead_operation_results_with_poison
+// Check that this prevents a crash in the canonicalization phase which
+// happens after the dead value removal phase. Also check that only used
+// results of an erased op are replaced with ub.poison.
+
+// CHECK-CANONICALIZE-LABEL: func.func @replace_dead_operation_results_with_poison
+// CHECK-CANONICALIZE-NEXT: %[[p:.*]] = ub.poison : vector<1xindex>
+// CHECK-CANONICALIZE-NEXT: return %[[p]]
func.func @replace_dead_operation_results_with_poison(%0: vector<1xindex>) -> vector<1xindex> {
%1 = scf.while (%arg0 = %0) : (vector<1xindex>) -> vector<1xindex> {
%cond = arith.constant true
@@ -809,15 +860,6 @@ func.func @replace_dead_operation_results_with_poison(%0: vector<1xindex>) -> ve
scf.yield %arg0 : vector<1xindex>
}
%2 = scf.while (%arg0 = %1) : (vector<1xindex>) -> vector<1xindex> {
- // Check that the binary value in condition is replaced with poison, and
- // the condition itself is well-formed IR. This prevents a crash in the
- // canonicalization phase which happens after the dead value removal phase.
- // Also check that only used results of an erased op are replaced with ub.poison.
- // CHECK-CANONICALIZE: %[[COND:.*]] = ub.poison : i1
- // CHECK-CANONICALIZE-NEXT: %[[NEXT:.*]] = ub.poison : vector<1xindex>
- // CHECK-CANONICALIZE-NEXT: scf.condition(%[[COND]]) %[[NEXT]]
- // CHECK-CANONICALIZE-NOT: ub.poison : i32
- // CHECK-CANONICALIZE-NOT: "test.three"
%cond, %unused, %next = "test.three"(%1) : (vector<1xindex>) -> (i1, i32, vector<1xindex>)
scf.condition(%cond) %next : vector<1xindex>
} do {
More information about the Mlir-commits
mailing list