[llvm-branch-commits] [mlir] [mlir][Transforms] `remove-dead-values`: Rely on canonicalizer for region simplification (PR #173505)
Matthias Springer via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sun Dec 28 11:56:04 PST 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/173505
>From 73c080f0363dceb7d99c73bc7db4ec3218e6e1f8 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 24 Dec 2025 13:26:00 +0000
Subject: [PATCH] tmp commit
simple test working
draft: do not erase IR, just replace uses
---
mlir/include/mlir/Transforms/Passes.h | 1 +
mlir/include/mlir/Transforms/Passes.td | 10 +
mlir/lib/Transforms/RemoveDeadValues.cpp | 518 ++++++++-----------
mlir/test/Transforms/remove-dead-values.mlir | 155 ++++--
4 files changed, 325 insertions(+), 359 deletions(-)
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 724da009e70f1..9983944d374c5 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -42,6 +42,7 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_MEM2REG
#define GEN_PASS_DECL_PRINTIRPASS
#define GEN_PASS_DECL_PRINTOPSTATS
+#define GEN_PASS_DECL_REMOVEDEADVALUES
#define GEN_PASS_DECL_SCCP
#define GEN_PASS_DECL_SROA
#define GEN_PASS_DECL_STRIPDEBUGINFO
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 55addfdb693e4..fc2d60d198cd6 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -246,7 +246,17 @@ def RemoveDeadValues : Pass<"remove-dead-values"> {
do = square_and_double_of_y(5)
print(do)
```
+
+ Note: If `canonicalize` is set to "false", this pass does not remove any
+ block arguments / op results from ops that implement the
+ RegionBranchOpInterface. Instead, it only replaces operands that are
+ forwarded exclusively to dead successor inputs with "ub.poison".
}];
+
+ let options = [
+ Option<"canonicalize", "canonicalize", "bool", /*default=*/"true",
+ "Canonicalize region branch ops">,
+ ];
let constructor = "mlir::createRemoveDeadValuesPass()";
let dependentDialects = ["ub::UBDialect"];
}
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 62ce5e0bbb77e..323d7f4cb1616 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -94,8 +94,11 @@ struct ResultsToCleanup {
struct OperandsToCleanup {
Operation *op;
BitVector nonLive;
- Operation *callee =
- nullptr; // Optional: For CallOpInterface ops, stores the callee function
+ // Optional: For CallOpInterface ops, stores the callee function.
+ Operation *callee = nullptr;
+ // If set, the operands marked in `nonLive` are replaced with ub.poison
+ // values instead of being erased.
+ bool replaceWithPoison = false;
};
struct BlockArgsToCleanup {
@@ -199,9 +202,9 @@ static void collectNonLiveValues(DenseSet<Value> &nonLiveSet, ValueRange range,
}
}
-/// Drop the uses of the i-th result of `op` and then erase it iff toErase[i]
-/// is 1.
-static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
+/// Erase result i of `op` iff toErase[i] is set; erased results must be unused.
+static void eraseResults(RewriterBase &rewriter, Operation *op,
+ BitVector toErase) {
assert(op->getNumResults() == toErase.size() &&
"expected the number of results in `op` and the size of `toErase` to "
"be the same");
@@ -210,7 +213,6 @@ static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
for (OpResult result : op->getResults())
if (!toErase[result.getResultNumber()])
newResultTypes.push_back(result.getType());
- IRRewriter rewriter(op);
rewriter.setInsertionPointAfter(op);
OperationState state(op->getLoc(), op->getName().getStringRef(),
op->getOperands(), newResultTypes, op->getAttrs());
@@ -226,14 +228,12 @@ static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
unsigned indexOfNextNewCallOpResultToReplace = 0;
for (auto [index, result] : llvm::enumerate(op->getResults())) {
assert(result && "expected result to be non-null");
- if (toErase[index]) {
- result.dropAllUses();
- } else {
+ if (!toErase[index]) {
result.replaceAllUsesWith(
newOp->getResult(indexOfNextNewCallOpResultToReplace++));
}
}
- op->erase();
+ rewriter.eraseOp(op);
}
/// Convert a list of `Operand`s to a list of `OpOperand`s.
@@ -404,30 +404,20 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
///
/// Scenario 1: If the operation has no memory effects and none of its results
/// are live:
-/// (1') Enqueue all its uses for deletion.
-/// (2') Enqueue the branch itself for deletion.
+/// 1.1. Enqueue the op itself for deletion.
+/// 1.2. During cleanup, remaining uses of its results become ub.poison.
///
/// Scenario 2: Otherwise:
-/// (1) Collect its unnecessary operands (operands forwarded to unnecessary
-/// results or arguments).
-/// (2) Process each of its regions.
-/// (3) Collect the uses of its unnecessary results (results forwarded from
-/// unnecessary operands
-/// or terminator operands).
-/// (4) Add these results to the deletion list.
-///
-/// Processing a region includes:
-/// (a) Collecting the uses of its unnecessary arguments (arguments forwarded
-/// from unnecessary operands
-/// or terminator operands).
-/// (b) Collecting these unnecessary arguments.
-/// (c) Collecting its unnecessary terminator operands (terminator operands
-/// forwarded to unnecessary results
-/// or arguments).
+/// 2.1. Collect block arguments and op results that we would like to keep,
+/// based on their liveness.
+/// 2.2. Find all operands that are forwarded to only dead region successor
+/// inputs. I.e., forwarded to block arguments / op results that we do
+/// not want to keep.
+/// 2.3. Enqueue all such operands for replacement with ub.poison.
///
-/// Value Flow Note: In this operation, values flow as follows:
-/// - From operands and terminator operands (successor operands)
-/// - To arguments and results (successor inputs).
+/// Note: In scenario 2, block arguments and op results are not removed.
+/// However, the IR is simplified such that canonicalization patterns can
+/// remove them later.
static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
@@ -441,284 +431,103 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// case, a non-forwarded operand of `regionBranchOp` could be live/non-live.
// It could never be live because of this op but its liveness could have been
// attributed to something else.
- // Do (1') and (2').
if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
!hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
cl.operations.push_back(regionBranchOp.getOperation());
return;
}
- // Mark live results of `regionBranchOp` in `liveResults`.
- auto markLiveResults = [&](BitVector &liveResults) {
- liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
- };
-
- // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
- auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) {
- for (Region ®ion : regionBranchOp->getRegions()) {
- if (region.empty())
- continue;
- SmallVector<Value> arguments(region.front().getArguments());
- BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
- liveArgs[®ion] = regionLiveArgs;
+ // Compute values that are alive.
+ DenseSet<Value> valuesToKeep;
+ for (Value result : regionBranchOp->getResults()) {
+ if (hasLive(result, nonLiveSet, la))
+ valuesToKeep.insert(result);
+ }
+ for (Region ®ion : regionBranchOp->getRegions()) {
+ if (region.empty())
+ continue;
+ for (Value arg : region.front().getArguments()) {
+ if (hasLive(arg, nonLiveSet, la))
+ valuesToKeep.insert(arg);
}
- };
+ }
- // Return the successors of `region` if the latter is not null. Else return
- // the successors of `regionBranchOp`.
- auto getSuccessors = [&](RegionBranchPoint point) {
+ // Mapping from operands to forwarded successor inputs. An operand can be
+ // forwarded to multiple successors.
+ DenseMap<OpOperand *, SmallVector<Value>> operandToSuccessorInputs;
+ auto helper = [&](RegionBranchPoint point) {
SmallVector<RegionSuccessor> successors;
regionBranchOp.getSuccessorRegions(point, successors);
- return successors;
- };
-
- // Return the operands of `terminator` that are forwarded to `successor` if
- // the former is not null. Else return the operands of `regionBranchOp`
- // forwarded to `successor`.
- auto getForwardedOpOperands = [&](const RegionSuccessor &successor,
- Operation *terminator = nullptr) {
- OperandRange operands =
- terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
- .getSuccessorOperands(successor)
- : regionBranchOp.getEntrySuccessorOperands(successor);
- SmallVector<OpOperand *> opOperands = operandsToOpOperands(operands);
- return opOperands;
- };
-
- // Mark the non-forwarded operands of `regionBranchOp` in
- // `nonForwardedOperands`.
- auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
- nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true);
- for (const RegionSuccessor &successor :
- getSuccessors(RegionBranchPoint::parent())) {
- for (OpOperand *opOperand : getForwardedOpOperands(successor))
- nonForwardedOperands.reset(opOperand->getOperandNumber());
+ for (const RegionSuccessor &successor : successors) {
+ // Handle branch from point --> successor.
+ ValueRange argsOrResults = successor.getSuccessorInputs();
+ OperandRange operands =
+ point.isParent() ? regionBranchOp.getEntrySuccessorOperands(successor)
+ : cast<RegionBranchTerminatorOpInterface>(
+ point.getTerminatorPredecessorOrNull())
+ .getSuccessorOperands(successor);
+ assert(
+ argsOrResults.size() == operands.size() &&
+ "expected the same number of successor inputs as forwarded operands");
+
+ for (auto [opOperand, input] :
+ llvm::zip_equal(operandsToOpOperands(operands), argsOrResults)) {
+ operandToSuccessorInputs[opOperand].push_back(input);
+ }
}
};
- // Mark the non-forwarded terminator operands of the various regions of
- // `regionBranchOp` in `nonForwardedRets`.
- auto markNonForwardedReturnValues =
- [&](DenseMap<Operation *, BitVector> &nonForwardedRets) {
- for (Region ®ion : regionBranchOp->getRegions()) {
- if (region.empty())
- continue;
- // TODO: this isn't correct in face of multiple terminators.
- Operation *terminator = region.front().getTerminator();
- nonForwardedRets[terminator] =
- BitVector(terminator->getNumOperands(), true);
- for (const RegionSuccessor &successor :
- getSuccessors(RegionBranchPoint(
- cast<RegionBranchTerminatorOpInterface>(terminator)))) {
- for (OpOperand *opOperand :
- getForwardedOpOperands(successor, terminator))
- nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
- }
- }
- };
-
- // Update `valuesToKeep` (which is expected to correspond to operands or
- // terminator operands) based on `resultsToKeep` and `argsToKeep`, given
- // `region`. When `valuesToKeep` correspond to operands, `region` is null.
- // Else, `region` is the parent region of the terminator.
- auto updateOperandsOrTerminatorOperandsToKeep =
- [&](BitVector &valuesToKeep, BitVector &resultsToKeep,
- DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) {
- Operation *terminator =
- region ? region->front().getTerminator() : nullptr;
- RegionBranchPoint point =
- terminator
- ? RegionBranchPoint(
- cast<RegionBranchTerminatorOpInterface>(terminator))
- : RegionBranchPoint::parent();
-
- for (const RegionSuccessor &successor : getSuccessors(point)) {
- Region *successorRegion = successor.getSuccessor();
- for (auto [opOperand, input] :
- llvm::zip(getForwardedOpOperands(successor, terminator),
- successor.getSuccessorInputs())) {
- size_t operandNum = opOperand->getOperandNumber();
- bool updateBasedOn =
- successorRegion
- ? argsToKeep[successorRegion]
- [cast<BlockArgument>(input).getArgNumber()]
- : resultsToKeep[cast<OpResult>(input).getResultNumber()];
- valuesToKeep[operandNum] = valuesToKeep[operandNum] | updateBasedOn;
- }
- }
- };
-
- // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep` and
- // `terminatorOperandsToKeep`. Store true in `resultsOrArgsToKeepChanged` if a
- // value is modified, else, false.
- auto recomputeResultsAndArgsToKeep =
- [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
- BitVector &operandsToKeep,
- DenseMap<Operation *, BitVector> &terminatorOperandsToKeep,
- bool &resultsOrArgsToKeepChanged) {
- resultsOrArgsToKeepChanged = false;
-
- // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`.
- for (const RegionSuccessor &successor :
- getSuccessors(RegionBranchPoint::parent())) {
- Region *successorRegion = successor.getSuccessor();
- for (auto [opOperand, input] :
- llvm::zip(getForwardedOpOperands(successor),
- successor.getSuccessorInputs())) {
- bool recomputeBasedOn =
- operandsToKeep[opOperand->getOperandNumber()];
- bool toRecompute =
- successorRegion
- ? argsToKeep[successorRegion]
- [cast<BlockArgument>(input).getArgNumber()]
- : resultsToKeep[cast<OpResult>(input).getResultNumber()];
- if (!toRecompute && recomputeBasedOn)
- resultsOrArgsToKeepChanged = true;
- if (successorRegion) {
- argsToKeep[successorRegion][cast<BlockArgument>(input)
- .getArgNumber()] =
- argsToKeep[successorRegion]
- [cast<BlockArgument>(input).getArgNumber()] |
- recomputeBasedOn;
- } else {
- resultsToKeep[cast<OpResult>(input).getResultNumber()] =
- resultsToKeep[cast<OpResult>(input).getResultNumber()] |
- recomputeBasedOn;
- }
- }
- }
-
- // Recompute `resultsToKeep` and `argsToKeep` based on
- // `terminatorOperandsToKeep`.
- for (Region ®ion : regionBranchOp->getRegions()) {
- if (region.empty())
- continue;
- Operation *terminator = region.front().getTerminator();
- for (const RegionSuccessor &successor :
- getSuccessors(RegionBranchPoint(
- cast<RegionBranchTerminatorOpInterface>(terminator)))) {
- Region *successorRegion = successor.getSuccessor();
- for (auto [opOperand, input] :
- llvm::zip(getForwardedOpOperands(successor, terminator),
- successor.getSuccessorInputs())) {
- bool recomputeBasedOn =
- terminatorOperandsToKeep[region.back().getTerminator()]
- [opOperand->getOperandNumber()];
- bool toRecompute =
- successorRegion
- ? argsToKeep[successorRegion]
- [cast<BlockArgument>(input).getArgNumber()]
- : resultsToKeep[cast<OpResult>(input).getResultNumber()];
- if (!toRecompute && recomputeBasedOn)
- resultsOrArgsToKeepChanged = true;
- if (successorRegion) {
- argsToKeep[successorRegion][cast<BlockArgument>(input)
- .getArgNumber()] =
- argsToKeep[successorRegion]
- [cast<BlockArgument>(input).getArgNumber()] |
- recomputeBasedOn;
- } else {
- resultsToKeep[cast<OpResult>(input).getResultNumber()] =
- resultsToKeep[cast<OpResult>(input).getResultNumber()] |
- recomputeBasedOn;
- }
- }
- }
- }
- };
-
- // Mark the values that we want to keep in `resultsToKeep`, `argsToKeep`,
- // `operandsToKeep`, and `terminatorOperandsToKeep`.
- auto markValuesToKeep =
- [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
- BitVector &operandsToKeep,
- DenseMap<Operation *, BitVector> &terminatorOperandsToKeep) {
- bool resultsOrArgsToKeepChanged = true;
- // We keep updating and recomputing the values until we reach a point
- // where they stop changing.
- while (resultsOrArgsToKeepChanged) {
- // Update the operands that need to be kept.
- updateOperandsOrTerminatorOperandsToKeep(operandsToKeep,
- resultsToKeep, argsToKeep);
-
- // Update the terminator operands that need to be kept.
- for (Region ®ion : regionBranchOp->getRegions()) {
- if (region.empty())
- continue;
- updateOperandsOrTerminatorOperandsToKeep(
- terminatorOperandsToKeep[region.back().getTerminator()],
- resultsToKeep, argsToKeep, ®ion);
- }
-
- // Recompute the results and arguments that need to be kept.
- recomputeResultsAndArgsToKeep(
- resultsToKeep, argsToKeep, operandsToKeep,
- terminatorOperandsToKeep, resultsOrArgsToKeepChanged);
- }
- };
-
- // Scenario 2.
- // At this point, we know that every non-forwarded operand of `regionBranchOp`
- // is live.
-
- // Stores the results of `regionBranchOp` that we want to keep.
- BitVector resultsToKeep;
- // Stores the mapping from regions of `regionBranchOp` to their arguments that
- // we want to keep.
- DenseMap<Region *, BitVector> argsToKeep;
- // Stores the operands of `regionBranchOp` that we want to keep.
- BitVector operandsToKeep;
- // Stores the mapping from region terminators in `regionBranchOp` to their
- // operands that we want to keep.
- DenseMap<Operation *, BitVector> terminatorOperandsToKeep;
-
- // Initializing the above variables...
-
- // The live results of `regionBranchOp` definitely need to be kept.
- markLiveResults(resultsToKeep);
- // Similarly, the live arguments of the regions in `regionBranchOp` definitely
- // need to be kept.
- markLiveArgs(argsToKeep);
- // The non-forwarded operands of `regionBranchOp` definitely need to be kept.
- // A live forwarded operand can be removed but no non-forwarded operand can be
- // removed since it "controls" the flow of data in this control flow op.
- markNonForwardedOperands(operandsToKeep);
- // Similarly, the non-forwarded terminator operands of the regions in
- // `regionBranchOp` definitely need to be kept.
- markNonForwardedReturnValues(terminatorOperandsToKeep);
-
- // Mark the values (results, arguments, operands, and terminator operands)
- // that we want to keep.
- markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep,
- terminatorOperandsToKeep);
-
- // Do (1).
- cl.operands.push_back({regionBranchOp, operandsToKeep.flip()});
-
- // Do (2.a) and (2.b).
+ // Example:
+ //
+ // %0 = scf.while : () -> i32 {
+ // scf.condition(...) %forwarded_value : i32
+ // } do {
+ // ^bb0(%arg0: i32):
+ // scf.yield
+ // }
+ // // No uses of %0.
+ //
+ // In the above example, %forwarded_value is forwarded to %arg0 and %0. Both
+ // %arg0 and %0 are dead, so %forwarded_value can be replaced with a
+ // ub.poison result.
+ //
+ // operandToSuccessorInputs[%forwarded_value] = {%arg0, %0}
+ //
+ helper(RegionBranchPoint::parent());
for (Region ®ion : regionBranchOp->getRegions()) {
if (region.empty())
continue;
- BitVector argsToRemove = argsToKeep[®ion].flip();
- cl.blocks.push_back({®ion.front(), argsToRemove});
- collectNonLiveValues(nonLiveSet, region.front().getArguments(),
- argsToRemove);
+ helper(RegionBranchPoint(cast<RegionBranchTerminatorOpInterface>(
+ region.front().getTerminator())));
}
- // Do (2.c).
- for (Region ®ion : regionBranchOp->getRegions()) {
- if (region.empty())
+ DenseMap<Operation *, BitVector> deadOperandsPerOp;
+ for (auto [opOperand, successorInputs] : operandToSuccessorInputs) {
+ // If one of the successor inputs is live, the respective operand must be
+ // kept.
+ bool anyAlive = llvm::any_of(successorInputs, [&](Value input) {
+ return valuesToKeep.contains(input);
+ });
+ if (anyAlive)
continue;
- Operation *terminator = region.front().getTerminator();
- cl.operands.push_back(
- {terminator, terminatorOperandsToKeep[terminator].flip()});
+
+ // All successor inputs are dead: ub.poison can be passed as operand.
+ // Create an entry in `deadOperandsPerOp` (initialized to "false", i.e.,
+ // no "dead" op operands) if it's the first time that we are seeing an op
+ // operand for this op. Otherwise, just take the existing bit vector from
+ // the map.
+ BitVector &deadOperands =
+ deadOperandsPerOp
+ .try_emplace(opOperand->getOwner(),
+ opOperand->getOwner()->getNumOperands(), false)
+ .first->second;
+ deadOperands.set(opOperand->getOperandNumber());
}
- // Do (3) and (4).
- BitVector resultsToRemove = resultsToKeep.flip();
- collectNonLiveValues(nonLiveSet, regionBranchOp.getOperation()->getResults(),
- resultsToRemove);
- cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove});
+ for (auto [op, deadOperands] : deadOperandsPerOp) {
+ cl.operands.push_back(
+ {op, deadOperands, nullptr, /*replaceWithPoison=*/true});
+ }
}
/// Steps to process a `BranchOpInterface` operation:
@@ -778,11 +587,44 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
}
}
+/// Create one ub.poison op per given value. For a value that has no uses, no
+/// op is created and a null Value is returned in its place.
+static SmallVector<Value> createPoisonedValues(OpBuilder &b,
+ ValueRange values) {
+ return llvm::map_to_vector(values, [&](Value value) {
+ if (value.use_empty())
+ return Value();
+ return ub::PoisonOp::create(b, value.getLoc(), value.getType()).getResult();
+ });
+}
+
+namespace {
+/// A listener that keeps track of ub.poison ops.
+struct TrackingListener : public RewriterBase::Listener {
+ void notifyOperationErased(Operation *op) override {
+ if (auto poisonOp = dyn_cast<ub::PoisonOp>(op))
+ poisonOps.erase(poisonOp);
+ }
+ void notifyOperationInserted(Operation *op,
+ OpBuilder::InsertPoint previous) override {
+ if (auto poisonOp = dyn_cast<ub::PoisonOp>(op))
+ poisonOps.insert(poisonOp);
+ }
+ DenseSet<ub::PoisonOp> poisonOps;
+};
+} // namespace
+
/// Removes dead values collected in RDVFinalCleanupList.
/// To be run once when all dead values have been collected.
-static void cleanUpDeadVals(RDVFinalCleanupList &list) {
+static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) {
LDBG() << "Starting cleanup of dead values...";
+ // New ub.poison ops may be inserted during cleanup. Some of these ops may no
+ // longer be needed after the cleanup. A tracking listener keeps track of all
+ // new ub.poison ops, so that they can be removed again after the cleanup.
+ TrackingListener listener;
+ IRRewriter rewriter(ctx, &listener);
+
// 1. Blocks, We must remove the block arguments and successor operands before
// deleting the operation, as they may reside in the region operation.
LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists";
@@ -800,10 +642,12 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
});
// Note: Iterate from the end to make sure that that indices of not yet
// processes arguments do not change.
+ rewriter.setInsertionPointToStart(b.b);
for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
if (!b.nonLiveArgs[i])
continue;
- b.b->getArgument(i).dropAllUses();
+ b.b->getArgument(i).replaceAllUsesWith(
+ createPoisonedValues(rewriter, b.b->getArgument(i)).front());
b.b->eraseArgument(i);
}
}
@@ -832,22 +676,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
}
}
- // 3. Operations
- LDBG() << "Cleaning up " << list.operations.size() << " operations";
- for (Operation *op : list.operations) {
- LDBG() << "Erasing operation: "
- << OpWithFlags(op,
- OpPrintingFlags().skipRegions().printGenericOpForm());
- if (op->hasTrait<OpTrait::IsTerminator>()) {
- // When erasing a terminator, insert an unreachable op in its place.
- OpBuilder b(op);
- ub::UnreachableOp::create(b, op->getLoc());
- }
- op->dropAllUses();
- op->erase();
- }
-
- // 4. Functions
+ // 3. Functions
LDBG() << "Cleaning up " << list.functions.size() << " functions";
// Record which function arguments were erased so we can shrink call-site
// argument segments for CallOpInterface operations (e.g. ops using
@@ -864,12 +693,18 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
llvm::interleaveComma(f.nonLiveRets.set_bits(), os);
os << "]";
});
- // Drop all uses of the dead arguments.
- for (auto deadIdx : f.nonLiveArgs.set_bits())
- f.funcOp.getArgument(deadIdx).dropAllUses();
// Some functions may not allow erasing arguments or results. These calls
// return failure in such cases without modifying the function, so it's okay
// to proceed.
+ bool hasBody = !f.funcOp.getFunctionBody().empty();
+ if (hasBody) {
+ rewriter.setInsertionPointToStart(&f.funcOp.getFunctionBody().front());
+ for (auto deadIdx : f.nonLiveArgs.set_bits()) {
+ f.funcOp.getArgument(deadIdx).replaceAllUsesWith(
+ createPoisonedValues(rewriter, f.funcOp.getArgument(deadIdx))
+ .front());
+ }
+ }
if (succeeded(f.funcOp.eraseArguments(f.nonLiveArgs))) {
// Record only if we actually erased something.
if (f.nonLiveArgs.any())
@@ -878,7 +713,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
(void)f.funcOp.eraseResults(f.nonLiveRets);
}
- // 5. Operands
+ // 4. Operands
LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
for (OperandsToCleanup &o : list.operands) {
// Handle call-specific cleanup only when we have a cached callee reference.
@@ -923,11 +758,20 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
<< OpWithFlags(o.op,
OpPrintingFlags().skipRegions().printGenericOpForm());
});
- o.op->eraseOperands(o.nonLive);
+ if (o.replaceWithPoison) {
+ rewriter.setInsertionPoint(o.op);
+ for (auto deadIdx : o.nonLive.set_bits()) {
+ o.op->setOperand(
+ deadIdx, createPoisonedValues(rewriter, o.op->getOperand(deadIdx))
+ .front());
+ }
+ } else {
+ o.op->eraseOperands(o.nonLive);
+ }
}
}
- // 6. Results
+ // 5. Results
LDBG() << "Cleaning up " << list.results.size() << " result lists";
for (auto &r : list.results) {
LDBG_OS([&](raw_ostream &os) {
@@ -937,8 +781,34 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
<< OpWithFlags(r.op,
OpPrintingFlags().skipRegions().printGenericOpForm());
});
- dropUsesAndEraseResults(r.op, r.nonLive);
+ rewriter.setInsertionPoint(r.op);
+ for (auto deadIdx : r.nonLive.set_bits()) {
+ r.op->getResult(deadIdx).replaceAllUsesWith(
+ createPoisonedValues(rewriter, r.op->getResult(deadIdx)).front());
+ }
+ eraseResults(rewriter, r.op, r.nonLive);
+ }
+
+ // 6. Operations
+ LDBG() << "Cleaning up " << list.operations.size() << " operations";
+ for (Operation *op : list.operations) {
+ LDBG() << "Erasing operation: "
+ << OpWithFlags(op,
+ OpPrintingFlags().skipRegions().printGenericOpForm());
+ rewriter.setInsertionPoint(op);
+ if (op->hasTrait<OpTrait::IsTerminator>()) {
+ // When erasing a terminator, insert an unreachable op in its place.
+ ub::UnreachableOp::create(rewriter, op->getLoc());
+ }
+ rewriter.replaceOp(op, createPoisonedValues(rewriter, op->getResults()));
}
+
+ // 7. Erase the ub.poison ops created during cleanup that ended up unused.
+ for (ub::PoisonOp poisonOp : listener.poisonOps) {
+ if (poisonOp.use_empty())
+ poisonOp.erase();
+ }
+
LDBG() << "Finished cleanup of dead values";
}
@@ -977,7 +847,27 @@ void RemoveDeadValues::runOnOperation() {
}
});
- cleanUpDeadVals(finalCleanupList);
+ MLIRContext *context = module->getContext();
+ cleanUpDeadVals(context, finalCleanupList);
+
+ if (!canonicalize)
+ return;
+
+ // Canonicalize all region branch ops.
+ SmallVector<Operation *> opsToCanonicalize;
+ module->walk([&](RegionBranchOpInterface regionBranchOp) {
+ opsToCanonicalize.push_back(regionBranchOp.getOperation());
+ });
+ RewritePatternSet owningPatterns(context);
+ for (auto *dialect : context->getLoadedDialects())
+ dialect->getCanonicalizationPatterns(owningPatterns);
+ for (RegisteredOperationName op : context->getRegisteredOperations())
+ op.getCanonicalizationPatterns(owningPatterns, context);
+ if (failed(applyOpPatternsGreedily(opsToCanonicalize,
+ std::move(owningPatterns)))) {
+ module->emitError("greedy pattern rewrite failed to converge");
+ signalPassFailure();
+ }
}
std::unique_ptr<Pass> mlir::createRemoveDeadValuesPass() {
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 71306676d48e9..b9a883dbd524e 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -remove-dead-values -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -remove-dead-values="canonicalize=0" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -remove-dead-values="canonicalize=1" -split-input-file | FileCheck %s --check-prefix=CHECK-CANONICALIZE
// The IR is updated regardless of memref.global private constant
//
@@ -55,19 +56,20 @@ func.func @acceptable_ir_has_cleanable_loop_of_conditional_and_branch_op(%arg0:
// Checking that iter_args are properly handled
//
+// CHECK-CANONICALIZE-LABEL: func @cleanable_loop_iter_args_value
func.func @cleanable_loop_iter_args_value(%arg0: index) -> index {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
%non_live = arith.constant 0 : index
- // CHECK: [[RESULT:%.+]] = scf.for [[ARG_1:%.*]] = %c0 to %c10 step %c1 iter_args([[ARG_2:%.*]] = %arg0) -> (index) {
+ // CHECK-CANONICALIZE: [[RESULT:%.+]] = scf.for [[ARG_1:%.*]] = %c0 to %c10 step %c1 iter_args([[ARG_2:%.*]] = %arg0) -> (index) {
%result, %result_non_live = scf.for %i = %c0 to %c10 step %c1 iter_args(%live_arg = %arg0, %non_live_arg = %non_live) -> (index, index) {
- // CHECK: [[SUM:%.+]] = arith.addi [[ARG_2]], [[ARG_1]] : index
+ // CHECK-CANONICALIZE: [[SUM:%.+]] = arith.addi [[ARG_2]], [[ARG_1]] : index
%new_live = arith.addi %live_arg, %i : index
- // CHECK: scf.yield [[SUM:%.+]]
+ // CHECK-CANONICALIZE: scf.yield [[SUM:%.+]]
scf.yield %new_live, %non_live_arg : index, index
}
- // CHECK: return [[RESULT]] : index
+ // CHECK-CANONICALIZE: return [[RESULT]] : index
return %result : index
}
@@ -79,7 +81,8 @@ func.func @cleanable_loop_iter_args_value(%arg0: index) -> index {
#map = affine_map<(d0, d1, d2) -> (0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
- func.func @main() {
+ // CHECK-LABEL: @dead_linalg_generic
+ func.func @dead_linalg_generic() {
%cst_3 = arith.constant dense<54> : tensor<1x25x13xi32>
%cst_7 = arith.constant dense<11> : tensor<1x25x13xi32>
// CHECK-NOT: arith.constant
@@ -229,18 +232,34 @@ func.func @main() -> (i32, i32) {
// anywhere else. Thus, %arg7 is also not kept in the `scf.yield` op.
//
// Note that this cleanup cannot be done by the `canonicalize` pass.
-//
-// CHECK: func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 {
-// CHECK-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) {
-// CHECK-NEXT: %[[live_0:.*]] = arith.addi %[[arg4]], %[[arg4]]
-// CHECK-NEXT: scf.condition(%arg0) %[[live_0]], %[[arg4]] : i32, i32
+
+// CHECK-LABEL: func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(
+// CHECK-SAME: %[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 {
+// CHECK-NEXT: %[[p0:.*]] = ub.poison : i32
+// CHECK-NEXT: %[[while:.*]]:3 = scf.while (%{{.*}} = %[[p0]], %[[arg4:.*]] = %[[arg2]]) : (i32, i32) -> (i32, i32, i32) {
+// CHECK-NEXT: %[[add1:.*]] = arith.addi %[[arg4]], %[[arg4]] : i32
+// CHECK-NEXT: %[[p1:.*]] = ub.poison : i32
+// CHECK-NEXT: scf.condition(%[[arg0]]) %[[add1]], %[[arg4]], %[[p1]] : i32, i32, i32
// CHECK-NEXT: } do {
-// CHECK-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
-// CHECK-NEXT: %[[live_1:.*]] = arith.addi %[[arg6]], %[[arg6]]
-// CHECK-NEXT: scf.yield %[[live_1]] : i32
+// CHECK-NEXT: ^bb0(%{{.*}}: i32, %[[arg6:.*]]: i32, %{{.*}}: i32):
+// CHECK-NEXT: %[[add2:.*]] = arith.addi %[[arg6]], %[[arg6]] : i32
+// CHECK-NEXT: %[[p2:.*]] = ub.poison : i32
+// CHECK-NEXT: scf.yield %[[p2]], %[[add2]] : i32, i32
// CHECK-NEXT: }
-// CHECK-NEXT: return %[[live_and_non_live]]#0
+// CHECK-NEXT: return %[[while]]#0 : i32
// CHECK-NEXT: }
+
+// CHECK-CANONICALIZE: func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 {
+// CHECK-CANONICALIZE-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) {
+// CHECK-CANONICALIZE-NEXT: %[[live_0:.*]] = arith.addi %[[arg4]], %[[arg4]]
+// CHECK-CANONICALIZE-NEXT: scf.condition(%arg0) %[[live_0]], %[[arg4]] : i32, i32
+// CHECK-CANONICALIZE-NEXT: } do {
+// CHECK-CANONICALIZE-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
+// CHECK-CANONICALIZE-NEXT: %[[live_1:.*]] = arith.addi %[[arg6]], %[[arg6]]
+// CHECK-CANONICALIZE-NEXT: scf.yield %[[live_1]] : i32
+// CHECK-CANONICALIZE-NEXT: }
+// CHECK-CANONICALIZE-NEXT: return %[[live_and_non_live]]#0
+// CHECK-CANONICALIZE-NEXT: }
func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%arg0: i1, %arg1: i32, %arg2: i32) -> (i32) {
%live, %non_live, %non_live_0 = scf.while (%arg3 = %arg1, %arg4 = %arg2) : (i32, i32) -> (i32, i32, i32) {
%live_0 = arith.addi %arg4, %arg4 : i32
@@ -284,21 +303,21 @@ func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_o
//
// Note that this cleanup cannot be done by the `canonicalize` pass.
//
-// CHECK: func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%[[arg2:.*]]: i1) -> i32 {
-// CHECK-NEXT: %[[c0:.*]] = arith.constant 0
-// CHECK-NEXT: %[[c1:.*]] = arith.constant 1
-// CHECK-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) {
-// CHECK-NEXT: func.call @identity() : () -> ()
-// CHECK-NEXT: scf.condition(%[[arg2]]) %[[arg4]], %[[arg3]] : i32, i32
-// CHECK-NEXT: } do {
-// CHECK-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
-// CHECK-NEXT: scf.yield %[[arg5]], %[[arg6]] : i32, i32
-// CHECK-NEXT: }
-// CHECK-NEXT: return %[[live_and_non_live]]#0 : i32
-// CHECK-NEXT: }
-// CHECK: func.func private @identity() {
-// CHECK-NEXT: return
-// CHECK-NEXT: }
+// CHECK-CANONICALIZE: func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%[[arg2:.*]]: i1) -> i32 {
+// CHECK-CANONICALIZE-NEXT: %[[c0:.*]] = arith.constant 0
+// CHECK-CANONICALIZE-NEXT: %[[c1:.*]] = arith.constant 1
+// CHECK-CANONICALIZE-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) {
+// CHECK-CANONICALIZE-NEXT: func.call @identity() : () -> ()
+// CHECK-CANONICALIZE-NEXT: scf.condition(%[[arg2]]) %[[arg3]], %[[arg4]] : i32, i32
+// CHECK-CANONICALIZE-NEXT: } do {
+// CHECK-CANONICALIZE-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
+// CHECK-CANONICALIZE-NEXT: scf.yield %[[arg6]], %[[arg5]] : i32, i32
+// CHECK-CANONICALIZE-NEXT: }
+// CHECK-CANONICALIZE-NEXT: return %[[live_and_non_live]]#1 : i32
+// CHECK-CANONICALIZE-NEXT: }
+// CHECK-CANONICALIZE: func.func private @identity() {
+// CHECK-CANONICALIZE-NEXT: return
+// CHECK-CANONICALIZE-NEXT: }
func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%arg2: i1) -> (i32) {
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
@@ -325,17 +344,17 @@ func.func private @identity(%arg1 : i32) -> (i32) {
//
// Note that this cleanup cannot be done by the `canonicalize` pass.
//
-// CHECK: func.func @clean_region_branch_op_remove_result(%[[arg0:.*]]: index, %[[arg1:.*]]: memref<i32>) {
-// CHECK-NEXT: scf.index_switch %[[arg0]]
-// CHECK-NEXT: case 1 {
-// CHECK-NEXT: %[[c10:.*]] = arith.constant 10
-// CHECK-NEXT: memref.store %[[c10]], %[[arg1]][]
-// CHECK-NEXT: scf.yield
-// CHECK-NEXT: }
-// CHECK-NEXT: default {
-// CHECK-NEXT: }
-// CHECK-NEXT: return
-// CHECK-NEXT: }
+// CHECK-CANONICALIZE: func.func @clean_region_branch_op_remove_result(%[[arg0:.*]]: index, %[[arg1:.*]]: memref<i32>) {
+// CHECK-CANONICALIZE-NEXT: scf.index_switch %[[arg0]]
+// CHECK-CANONICALIZE-NEXT: case 1 {
+// CHECK-CANONICALIZE-NEXT: %[[c10:.*]] = arith.constant 10
+// CHECK-CANONICALIZE-NEXT: memref.store %[[c10]], %[[arg1]][]
+// CHECK-CANONICALIZE: scf.yield
+// CHECK-CANONICALIZE-NEXT: }
+// CHECK-CANONICALIZE-NEXT: default {
+// CHECK-CANONICALIZE: }
+// CHECK-CANONICALIZE-NEXT: return
+// CHECK-CANONICALIZE-NEXT: }
func.func @clean_region_branch_op_remove_result(%arg0 : index, %arg1 : memref<i32>) {
%non_live = scf.index_switch %arg0 -> i32
case 1 {
@@ -539,10 +558,10 @@ module {
}
}
-// CHECK-LABEL: func @test_zero_operands
-// CHECK: memref.alloca_scope
-// CHECK: memref.store
-// CHECK-NOT: memref.alloca_scope.return
+// CHECK-CANONICALIZE-LABEL: func @test_zero_operands
+// CHECK-CANONICALIZE-NEXT: %[[c0:.*]] = arith.constant 0
+// CHECK-CANONICALIZE-NEXT: memref.store %[[c0]]
+// CHECK-CANONICALIZE-NOT: memref.alloca_scope.return
// -----
@@ -714,3 +733,49 @@ func.func private @remove_dead_branch_op(%c: i1, %arg0: i64, %arg1: i64) -> (i64
^bb2:
return %arg1 : i64
}
+
+// -----
+
+// CHECK-LABEL: func @scf_while_dead_iter_args()
+// CHECK: %[[c5:.*]] = arith.constant 5 : i32
+// CHECK: %[[while:.*]]:2 = scf.while (%[[arg0:.*]] = %[[c5]]) : (i32) -> (i32, i32) {
+// CHECK: vector.print %[[arg0]]
+// CHECK: %[[cmpi:.*]] = arith.cmpi
+// CHECK: %[[p0:.*]] = ub.poison : i32
+// CHECK: scf.condition(%[[cmpi]]) %[[arg0]], %[[p0]]
+// CHECK: } do {
+// CHECK: ^bb0(%[[arg1:.*]]: i32, %[[arg2:.*]]: i32):
+// CHECK: %[[p1:.*]] = ub.poison : i32
+// CHECK: scf.yield %[[p1]]
+// CHECK: }
+// CHECK: return %[[while]]#0
+
+// CHECK-CANONICALIZE-LABEL: func @scf_while_dead_iter_args()
+// CHECK-CANONICALIZE: %[[c5:.*]] = arith.constant 5 : i32
+// CHECK-CANONICALIZE: %[[while:.*]] = scf.while (%[[arg0:.*]] = %[[c5]]) : (i32) -> i32 {
+// CHECK-CANONICALIZE: vector.print %[[arg0]]
+// CHECK-CANONICALIZE: %[[cmpi:.*]] = arith.cmpi
+// CHECK-CANONICALIZE: scf.condition(%[[cmpi]]) %[[arg0]]
+// CHECK-CANONICALIZE: } do {
+// CHECK-CANONICALIZE: ^bb0(%[[arg1:.*]]: i32):
+// CHECK-CANONICALIZE: %[[p0:.*]] = ub.poison : i32
+// CHECK-CANONICALIZE: scf.yield %[[p0]]
+// CHECK-CANONICALIZE: }
+// CHECK-CANONICALIZE: return %[[while]]
func.func @scf_while_dead_iter_args() -> i32 {
+ %c5 = arith.constant 5 : i32
+ %result:2 = scf.while (%arg0 = %c5) : (i32) -> (i32, i32) {
+ vector.print %arg0 : i32
+ // Note: %arg0 always equals %c5 (the "after" region only feeds %arg2 back),
+ // so this `5 < 5` is always "false"; the liveness analysis can figure that out.
+ %cmp2 = arith.cmpi slt, %arg0, %c5 : i32
+ scf.condition(%cmp2) %arg0, %arg0 : i32, i32
+ } do {
+ ^bb0(%arg1: i32, %arg2: i32):
+ %x = scf.execute_region -> i32 {
+ scf.yield %arg2 : i32
+ }
+ scf.yield %x : i32
+ }
+ return %result#0 : i32 // Only result #0 is used; result #1 is never read.
+}
More information about the llvm-branch-commits
mailing list