[Mlir-commits] [mlir] [MLIR] Fix RemoveDeadValues handling of unreachable code (PR #153973)
Mehdi Amini
llvmlistbot at llvm.org
Sun Aug 17 03:59:51 PDT 2025
https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/153973
>From 0f0602e02ce153cfe36c6187f58790675659cf13 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sun, 17 Aug 2025 03:53:38 -0700
Subject: [PATCH] [MLIR] Fix RemoveDeadValues handling of unreachable code
The code was "conservatively" considering that the absence of liveness
information meant that we weren't sure if a value was dead. However,
the dataflow framework will skip visiting these values when it already
knows that a block is dynamically unreachable.
This leads to a crash in the provided test case since the producer operation
in a reachable block is actually deleted (its liveness is correctly set to
"dead") but the user of the operation was "conservatively" preserved and left
with nullptr operands.
Fixes #153906
---
.../Analysis/DataFlow/LivenessAnalysis.cpp | 29 +++++++++++-
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp | 46 +++++++++++++++++--
mlir/lib/Transforms/RemoveDeadValues.cpp | 46 ++++++++++++++++---
.../DataFlow/test-liveness-analysis.mlir | 20 ++++++++
mlir/test/Transforms/remove-dead-values.mlir | 21 +++++++++
.../DataFlow/TestLivenessAnalysis.cpp | 1 -
6 files changed, 150 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
index 509f5202be08d..761098778a371 100644
--- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
@@ -294,7 +294,34 @@ RunLivenessAnalysis::RunLivenessAnalysis(Operation *op) {
solver.load<LivenessAnalysis>(symbolTable);
LDBG() << "Initializing and running solver";
(void)solver.initializeAndRun(op);
- LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName();
+ LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName()
+ << " check on dead code now:";
+ // The framework doesn't visit operations in dead blocks, so we need to
+ // explicitly mark them as dead.
+ op->walk([&](Operation *op) {
+ if (op->getNumResults() == 0)
+ return;
+ for (auto result : llvm::enumerate(op->getResults())) {
+ if (getLiveness(result.value()))
+ continue;
+ LDBG() << "Result: " << result.index() << " of "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " has no liveness info, mark dead";
+ solver.getOrCreateState<Liveness>(result.value());
+ }
+ for (auto &region : op->getRegions()) {
+ for (auto &block : region) {
+ for (auto blockArg : llvm::enumerate(block.getArguments())) {
+ if (getLiveness(blockArg.value()))
+ continue;
+ LDBG() << "Block argument: " << blockArg.index() << " of "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " has no liveness info, mark dead";
+ solver.getOrCreateState<Liveness>(blockArg.value());
+ }
+ }
+ }
+ });
}
const Liveness *RunLivenessAnalysis::getLiveness(Value val) {
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index e625f626d12fd..13a3e1480c836 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -19,12 +19,15 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
#include <cassert>
#include <optional>
using namespace mlir;
using namespace mlir::dataflow;
+#define DEBUG_TYPE "dataflow"
+
//===----------------------------------------------------------------------===//
// AbstractSparseLattice
//===----------------------------------------------------------------------===//
@@ -64,22 +67,36 @@ AbstractSparseForwardDataFlowAnalysis::initialize(Operation *top) {
LogicalResult
AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) {
+ LDBG() << "Initializing recursively for operation: " << op->getName();
+
// Initialize the analysis by visiting every owner of an SSA value (all
// operations and blocks).
- if (failed(visitOperation(op)))
+ if (failed(visitOperation(op))) {
+ LDBG() << "Failed to visit operation: " << op->getName();
return failure();
+ }
for (Region &region : op->getRegions()) {
+ LDBG() << "Processing region with " << region.getBlocks().size()
+ << " blocks";
for (Block &block : region) {
+ LDBG() << "Processing block with " << block.getNumArguments()
+ << " arguments";
getOrCreate<Executable>(getProgramPointBefore(&block))
->blockContentSubscribe(this);
visitBlock(&block);
- for (Operation &op : block)
- if (failed(initializeRecursively(&op)))
+ for (Operation &op : block) {
+ LDBG() << "Recursively initializing nested operation: " << op.getName();
+ if (failed(initializeRecursively(&op))) {
+ LDBG() << "Failed to initialize nested operation: " << op.getName();
return failure();
+ }
+ }
}
}
+ LDBG() << "Successfully completed recursive initialization for operation: "
+ << op->getName();
return success();
}
@@ -409,11 +426,20 @@ static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) {
LogicalResult
AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
+ LDBG() << "Visiting operation: " << op->getName() << " with "
+ << op->getNumOperands() << " operands and " << op->getNumResults()
+ << " results";
+
// If we're in a dead block, bail out.
if (op->getBlock() != nullptr &&
- !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive())
+ !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))
+ ->isLive()) {
+ LDBG() << "Operation is in dead block, bailing out";
return success();
+ }
+ LDBG() << "Creating lattice elements for " << op->getNumOperands()
+ << " operands and " << op->getNumResults() << " results";
SmallVector<AbstractSparseLattice *> operandLattices =
getLatticeElements(op->getOperands());
SmallVector<const AbstractSparseLattice *> resultLattices =
@@ -422,11 +448,15 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
// Block arguments of region branch operations flow back into the operands
// of the parent op
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
+ LDBG() << "Processing RegionBranchOpInterface operation";
visitRegionSuccessors(branch, operandLattices);
return success();
}
if (auto branch = dyn_cast<BranchOpInterface>(op)) {
+ LDBG() << "Processing BranchOpInterface operation with "
+ << op->getNumSuccessors() << " successors";
+
// Block arguments of successor blocks flow back into our operands.
// We remember all operands not forwarded to any block in a BitVector.
@@ -463,6 +493,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
// For function calls, connect the arguments of the entry blocks to the
// operands of the call op that are forwarded to these arguments.
if (auto call = dyn_cast<CallOpInterface>(op)) {
+ LDBG() << "Processing CallOpInterface operation";
Operation *callableOp = call.resolveCallableInTable(&symbolTable);
if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
// Not all operands of a call op forward to arguments. Such operands are
@@ -513,6 +544,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
// of this op itself and the operands of the terminators of the regions of
// this op.
if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
+ LDBG() << "Processing RegionBranchTerminatorOpInterface operation";
if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
visitRegionSuccessorsFromTerminator(terminator, branch);
return success();
@@ -520,12 +552,16 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
}
if (op->hasTrait<OpTrait::ReturnLike>()) {
+ LDBG() << "Processing ReturnLike operation";
// Going backwards, the operands of the return are derived from the
// results of all CallOps calling this CallableOp.
- if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp()))
+ if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) {
+ LDBG() << "Callable parent found, visiting callable operation";
return visitCallableOperation(op, callable, operandLattices);
+ }
}
+ LDBG() << "Using default visitOperationImpl for operation: " << op->getName();
return visitOperationImpl(op, operandLattices, resultLattices);
}
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 4ccb83f3ad298..02dad69e49614 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -258,18 +258,17 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
RDVFinalCleanupList &cl) {
- LDBG() << "Processing simple op: " << *op;
if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
- LDBG()
- << "Simple op is not memory effect free or has live results, skipping: "
- << *op;
+ LDBG() << "Simple op is not memory effect free or has live results, "
+ "preserving it: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
return;
}
LDBG()
<< "Simple op has all dead results and is memory effect free, scheduling "
"for removal: "
- << *op;
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
cl.operations.push_back(op);
collectNonLiveValues(nonLiveSet, op->getResults(),
BitVector(op->getNumResults(), true));
@@ -728,19 +727,31 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
/// Removes dead values collected in RDVFinalCleanupList.
/// To be run once when all dead values have been collected.
static void cleanUpDeadVals(RDVFinalCleanupList &list) {
+ LDBG() << "Starting cleanup of dead values...";
+
// 1. Operations
+ LDBG() << "Cleaning up " << list.operations.size() << " operations";
for (auto &op : list.operations) {
+ LDBG() << "Erasing operation: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
op->dropAllUses();
op->erase();
}
// 2. Values
+ LDBG() << "Cleaning up " << list.values.size() << " values";
for (auto &v : list.values) {
+ LDBG() << "Dropping all uses of value: " << v;
v.dropAllUses();
}
// 3. Functions
+ LDBG() << "Cleaning up " << list.functions.size() << " functions";
for (auto &f : list.functions) {
+ LDBG() << "Cleaning up function: " << f.funcOp.getOperation()->getName();
+ LDBG() << " Erasing " << f.nonLiveArgs.count() << " non-live arguments";
+ LDBG() << " Erasing " << f.nonLiveRets.count()
+ << " non-live return values";
// Some functions may not allow erasing arguments or results. These calls
// return failure in such cases without modifying the function, so it's okay
// to proceed.
@@ -749,44 +760,67 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
}
// 4. Operands
+ LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
for (OperationToCleanup &o : list.operands) {
- if (o.op->getNumOperands() > 0)
+ if (o.op->getNumOperands() > 0) {
+ LDBG() << "Erasing " << o.nonLive.count()
+ << " non-live operands from operation: "
+ << OpWithFlags(o.op, OpPrintingFlags().skipRegions());
o.op->eraseOperands(o.nonLive);
+ }
}
// 5. Results
+ LDBG() << "Cleaning up " << list.results.size() << " result lists";
for (auto &r : list.results) {
+ LDBG() << "Erasing " << r.nonLive.count()
+ << " non-live results from operation: "
+ << OpWithFlags(r.op, OpPrintingFlags().skipRegions());
dropUsesAndEraseResults(r.op, r.nonLive);
}
// 6. Blocks
+ LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists";
for (auto &b : list.blocks) {
// blocks that are accessed via multiple codepaths processed once
if (b.b->getNumArguments() != b.nonLiveArgs.size())
continue;
+ LDBG() << "Erasing " << b.nonLiveArgs.count()
+ << " non-live arguments from block: " << b.b;
// it iterates backwards because erase invalidates all successor indexes
for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
if (!b.nonLiveArgs[i])
continue;
+ LDBG() << " Erasing block argument " << i << ": " << b.b->getArgument(i);
b.b->getArgument(i).dropAllUses();
b.b->eraseArgument(i);
}
}
// 7. Successor Operands
+ LDBG() << "Cleaning up " << list.successorOperands.size()
+ << " successor operand lists";
for (auto &op : list.successorOperands) {
SuccessorOperands successorOperands =
op.branch.getSuccessorOperands(op.successorIndex);
// blocks that are accessed via multiple codepaths processed once
if (successorOperands.size() != op.nonLiveOperands.size())
continue;
+ LDBG() << "Erasing " << op.nonLiveOperands.count()
+ << " non-live successor operands from successor "
+ << op.successorIndex << " of branch: "
+ << OpWithFlags(op.branch, OpPrintingFlags().skipRegions());
// it iterates backwards because erase invalidates all successor indexes
for (int i = successorOperands.size() - 1; i >= 0; --i) {
if (!op.nonLiveOperands[i])
continue;
+ LDBG() << " Erasing successor operand " << i << ": "
+ << successorOperands[i];
successorOperands.erase(i);
}
}
+
+ LDBG() << "Finished cleanup of dead values";
}
struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
diff --git a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir
index a89a0f4084e99..4fece5520c9d3 100644
--- a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir
+++ b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir
@@ -283,3 +283,23 @@ func.func @test_10_negative() -> (i32) {
%0:2 = func.call @private_1() : () -> (i32, i32)
return %0#0 : i32
}
+
+// -----
+
+// Test that we correctly set a liveness value for operations
+// in dead block. These won't be visited by the dataflow framework
+// so the analysis need to explicit manage them.
+// CHECK-LABEL: test_tag: dead_block_cmpi:
+// CHECK-NEXT: operand #0: not live
+// CHECK-NEXT: operand #1: not live
+// CHECK-NEXT: result #0: not live
+func.func @dead_block() {
+ %false = arith.constant false
+ %zero = arith.constant 0 : i64
+ cf.cond_br %false, ^bb1, ^bb4
+ ^bb1:
+ %3 = arith.cmpi eq, %zero, %zero {tag = "dead_block_cmpi"} : i64
+ cf.br ^bb1
+ ^bb4:
+ return
+}
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 9ded6a30d9c95..0f8d757086e87 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -571,3 +571,24 @@ module @return_void_with_unused_argument {
}
}
+// -----
+
+// CHECK-LABEL: module @dynamically_unreachable
+module @dynamically_unreachable {
+ func.func @dynamically_unreachable() {
+ // This value is used by an operation in a dynamically unreachable block.
+ %zero = arith.constant 0 : i64
+
+ // Dataflow analysis knows from the constant condition that
+ // ^bb1 is unreachable
+ %false = arith.constant false
+ cf.cond_br %false, ^bb1, ^bb4
+ ^bb1:
+ // This unreachable operation should be removed.
+ // CHECK-NOT: arith.cmpi
+ %3 = arith.cmpi eq, %zero, %zero : i64
+ cf.br ^bb1
+ ^bb4:
+ return
+ }
+}
diff --git a/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp
index 43005e22584c2..8e2f03b644e49 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp
@@ -33,7 +33,6 @@ struct TestLivenessAnalysisPass
void runOnOperation() override {
auto &livenessAnalysis = getAnalysis<RunLivenessAnalysis>();
-
Operation *op = getOperation();
raw_ostream &os = llvm::outs();
More information about the Mlir-commits
mailing list