[Mlir-commits] [mlir] [mlir][Transforms] Do not erase IR in `remove-dead-values` (PR #173505)
Matthias Springer
llvmlistbot at llvm.org
Wed Dec 24 10:52:49 PST 2025
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/173505
None
>From 7a318f689ef40f1db3b085b72e3d578fbe8e1405 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 24 Dec 2025 13:26:00 +0000
Subject: [PATCH 1/3] tmp commit
---
mlir/lib/Transforms/RemoveDeadValues.cpp | 137 +++++++++++------------
1 file changed, 66 insertions(+), 71 deletions(-)
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 10942923ba1e1..fe58ab4e825bd 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -434,9 +434,14 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
RDVFinalCleanupList &cl) {
- LDBG() << "Processing region branch op: "
- << OpWithFlags(regionBranchOp,
- OpPrintingFlags().skipRegions().printGenericOpForm());
+ llvm::errs() << "=== Processing region branch op === "
+ << OpWithFlags(
+ regionBranchOp,
+ OpPrintingFlags().skipRegions().printGenericOpForm())
+ << "\n";
+ // LDBG() << "Processing region branch op: "
+ // << OpWithFlags(regionBranchOp,
+ // OpPrintingFlags().skipRegions().printGenericOpForm());
// Scenario 1. This is the only case where the entire `regionBranchOp`
// is removed. It will not happen in any other scenario. Note that in this
@@ -450,21 +455,21 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
return;
}
- // Mark live results of `regionBranchOp` in `liveResults`.
- auto markLiveResults = [&](BitVector &liveResults) {
- liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
- };
+ // Results of the `regionBranchOp` that we want to keep. We definitely want
+ // to keep all live results.
+ BitVector resultsToKeep =
+ markLives(regionBranchOp->getResults(), nonLiveSet, la);
- // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
- auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) {
- for (Region ®ion : regionBranchOp->getRegions()) {
- if (region.empty())
- continue;
- SmallVector<Value> arguments(region.front().getArguments());
- BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
- liveArgs[®ion] = regionLiveArgs;
- }
- };
+ // Block arguments of the `regionBranchOp` that we want to keep. We
+ // definitely want to keep all live arguments.
+ DenseMap<Region *, BitVector> argsToKeep;
+ for (Region ®ion : regionBranchOp->getRegions()) {
+ if (region.empty())
+ continue;
+ SmallVector<Value> arguments(region.front().getArguments());
+ BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
+ argsToKeep[®ion] = regionLiveArgs;
+ }
// Return the successors of `region` if the latter is not null. Else return
// the successors of `regionBranchOp`.
@@ -487,37 +492,40 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
return opOperands;
};
- // Mark the non-forwarded operands of `regionBranchOp` in
- // `nonForwardedOperands`.
- auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
- nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true);
- for (const RegionSuccessor &successor :
- getSuccessors(RegionBranchPoint::parent())) {
- for (OpOperand *opOperand : getForwardedOpOperands(successor))
- nonForwardedOperands.reset(opOperand->getOperandNumber());
- }
- };
+ // Compute the operands of `regionBranchOp` that we want to keep. We
+ // definitely want to keep the non-forwarded operands of the region branch
+ // op. And potentially some forwarded operands (to be determined later).
+ BitVector operandsToKeep;
+ operandsToKeep.resize(regionBranchOp->getNumOperands(), true);
+ for (const RegionSuccessor &successor :
+ getSuccessors(RegionBranchPoint::parent())) {
+ for (OpOperand *opOperand : getForwardedOpOperands(successor))
+ operandsToKeep.reset(opOperand->getOperandNumber());
+ }
- // Mark the non-forwarded terminator operands of the various regions of
- // `regionBranchOp` in `nonForwardedRets`.
- auto markNonForwardedReturnValues =
- [&](DenseMap<Operation *, BitVector> &nonForwardedRets) {
- for (Region ®ion : regionBranchOp->getRegions()) {
- if (region.empty())
- continue;
- // TODO: this isn't correct in face of multiple terminators.
- Operation *terminator = region.front().getTerminator();
- nonForwardedRets[terminator] =
- BitVector(terminator->getNumOperands(), true);
- for (const RegionSuccessor &successor :
- getSuccessors(RegionBranchPoint(
- cast<RegionBranchTerminatorOpInterface>(terminator)))) {
- for (OpOperand *opOperand :
- getForwardedOpOperands(successor, terminator))
- nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
- }
- }
- };
+ // Compute the operands of each terminator that we want to keep. We definitely
+ // want to keep the non-forwarded terminator operands. And potentially some
+ // forwarded terminator operands (to be determined later).
+ DenseMap<Operation *, BitVector> terminatorOperandsToKeep;
+ for (Region ®ion : regionBranchOp->getRegions()) {
+ if (region.empty())
+ continue;
+ // TODO: this isn't correct in face of multiple terminators.
+ Operation *terminator = region.front().getTerminator();
+ terminatorOperandsToKeep[terminator] =
+ BitVector(terminator->getNumOperands(), true);
+ for (const RegionSuccessor &successor : getSuccessors(RegionBranchPoint(
+ cast<RegionBranchTerminatorOpInterface>(terminator)))) {
+ for (OpOperand *opOperand :
+ getForwardedOpOperands(successor, terminator)) {
+ terminatorOperandsToKeep[terminator].reset(
+ opOperand->getOperandNumber());
+ llvm::errs() << " reset terminator " << terminator->getName()
+ << " operand: " << opOperand->getOperandNumber()
+ << " to false\n";
+ }
+ }
+ }
// Update `valuesToKeep` (which is expected to correspond to operands or
// terminator operands) based on `resultsToKeep` and `argsToKeep`, given
@@ -663,32 +671,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// At this point, we know that every non-forwarded operand of `regionBranchOp`
// is live.
- // Stores the results of `regionBranchOp` that we want to keep.
- BitVector resultsToKeep;
- // Stores the mapping from regions of `regionBranchOp` to their arguments that
- // we want to keep.
- DenseMap<Region *, BitVector> argsToKeep;
- // Stores the operands of `regionBranchOp` that we want to keep.
- BitVector operandsToKeep;
- // Stores the mapping from region terminators in `regionBranchOp` to their
- // operands that we want to keep.
- DenseMap<Operation *, BitVector> terminatorOperandsToKeep;
-
// Initializing the above variables...
- // The live results of `regionBranchOp` definitely need to be kept.
- markLiveResults(resultsToKeep);
- // Similarly, the live arguments of the regions in `regionBranchOp` definitely
- // need to be kept.
- markLiveArgs(argsToKeep);
- // The non-forwarded operands of `regionBranchOp` definitely need to be kept.
- // A live forwarded operand can be removed but no non-forwarded operand can be
- // removed since it "controls" the flow of data in this control flow op.
- markNonForwardedOperands(operandsToKeep);
- // Similarly, the non-forwarded terminator operands of the regions in
- // `regionBranchOp` definitely need to be kept.
- markNonForwardedReturnValues(terminatorOperandsToKeep);
-
// Mark the values (results, arguments, operands, and terminator operands)
// that we want to keep.
markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep,
@@ -712,6 +696,14 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
if (region.empty())
continue;
Operation *terminator = region.front().getTerminator();
+ llvm::errs() << "(2.c) terminator: "
+ << OpWithFlags(
+ terminator,
+ OpPrintingFlags().skipRegions().printGenericOpForm())
+ << "\n";
+ for (auto i : terminatorOperandsToKeep[terminator].set_bits()) {
+ llvm::errs() << " keep operand: " << i << "\n";
+ }
cl.operands.push_back(
{terminator, terminatorOperandsToKeep[terminator].flip()});
}
@@ -721,6 +713,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
collectNonLiveValues(nonLiveSet, regionBranchOp.getOperation()->getResults(),
resultsToRemove);
cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove});
+
+ llvm::errs() << "=== DONE PROCESSING REGION BRANCH OP ===\n";
}
/// Steps to process a `BranchOpInterface` operation:
@@ -984,6 +978,7 @@ void RemoveDeadValues::runOnOperation() {
});
cleanUpDeadVals(finalCleanupList);
+ getOperation()->dump();
}
std::unique_ptr<Pass> mlir::createRemoveDeadValuesPass() {
>From a117ebf6a76a8edac44a59f5a199bbbf8a723aee Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 24 Dec 2025 15:19:47 +0000
Subject: [PATCH 2/3] simple test working
---
mlir/lib/Transforms/RemoveDeadValues.cpp | 353 ++++++++---------------
1 file changed, 127 insertions(+), 226 deletions(-)
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index fe58ab4e825bd..242b01d797966 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -455,231 +455,149 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
return;
}
- // Results of the `regionBranchOp` that we want to keep. We definitely want
- // to keep all live results.
- BitVector resultsToKeep =
- markLives(regionBranchOp->getResults(), nonLiveSet, la);
-
- // Block arguments of the `regionBranchOp` that we want to keep. We
- // definitely want to keep all live arguments.
- DenseMap<Region *, BitVector> argsToKeep;
+ // Compute values that we definitely want to keep.
+ DenseSet<Value> valuesToKeep;
+ for (Value result : regionBranchOp->getResults()) {
+ if (hasLive(result, nonLiveSet, la))
+ valuesToKeep.insert(result);
+ }
for (Region ®ion : regionBranchOp->getRegions()) {
- if (region.empty())
- continue;
- SmallVector<Value> arguments(region.front().getArguments());
- BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
- argsToKeep[®ion] = regionLiveArgs;
+ for (Value arg : region.front().getArguments()) {
+ if (hasLive(arg, nonLiveSet, la))
+ valuesToKeep.insert(arg);
+ }
+ }
+
+ llvm::errs() << "*** INITIAL VALUES TO KEEP: \n";
+ for (Value value : valuesToKeep) {
+ if (auto arg = dyn_cast<BlockArgument>(value)) {
+ llvm::errs() << " keep: " << arg << " (arg " << arg.getArgNumber() << " in region " << arg.getOwner()->getParent()->getRegionNumber() << ")\n";
+ }
}
- // Return the successors of `region` if the latter is not null. Else return
- // the successors of `regionBranchOp`.
- auto getSuccessors = [&](RegionBranchPoint point) {
+ auto forAllRegionBranchPoints = [&](auto callback) -> bool {
+ bool result = false;
+ result |= callback(RegionBranchPoint::parent());
+ for (Region ®ion : regionBranchOp->getRegions()) {
+ if (region.empty())
+ continue;
+ result |= callback(RegionBranchPoint(cast<RegionBranchTerminatorOpInterface>(region.front().getTerminator())));
+ }
+ return result;
+ };
+
+ // TODO: Iterate to fixed point?
+ auto fixupLiveSet = [&](RegionBranchPoint point) {
+ bool changedLiveSet = false;
+ // Mapping of operands to successor inputs. An operand can be forwarded to
+ // multiple successors.
+ DenseMap<OpOperand *, SmallVector<Value>> operandToSuccessorInputs;
SmallVector<RegionSuccessor> successors;
regionBranchOp.getSuccessorRegions(point, successors);
- return successors;
- };
+ for (const RegionSuccessor &successor : successors) {
+ // Handle branch from point --> successor.
+ ValueRange argsOrResults = successor.getSuccessorInputs();
+ OperandRange operands =
+ point.isParent() ? regionBranchOp.getEntrySuccessorOperands(successor) : cast<RegionBranchTerminatorOpInterface>(point.getTerminatorPredecessorOrNull())
+ .getSuccessorOperands(successor);
+ assert(argsOrResults.size() == operands.size() && "expected the same number of successor inputs as forwarded operands");
+
+ for (auto [opOperand, input] : llvm::zip(operandsToOpOperands(operands), argsOrResults)) {
+ operandToSuccessorInputs[opOperand].push_back(input);
+ }
+ }
- // Return the operands of `terminator` that are forwarded to `successor` if
- // the former is not null. Else return the operands of `regionBranchOp`
- // forwarded to `successor`.
- auto getForwardedOpOperands = [&](const RegionSuccessor &successor,
- Operation *terminator = nullptr) {
- OperandRange operands =
- terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
- .getSuccessorOperands(successor)
- : regionBranchOp.getEntrySuccessorOperands(successor);
- SmallVector<OpOperand *> opOperands = operandsToOpOperands(operands);
- return opOperands;
- };
+ for (auto [opOperand, successorInputs] : operandToSuccessorInputs) {
+ // If one of the successor inputs is live, the respective operand must be
+ // kept. In that case, all matching successor inputs must be kept.
+ bool anyAlive = llvm::any_of(successorInputs, [&](Value input) { return valuesToKeep.contains(input); });
+ if (!anyAlive) continue;
+
+ //llvm::errs() << "LIVE SET:\n";
+ //for (Value value : successorInputs) {
+ // if (auto arg = dyn_cast<BlockArgument>(value)) {
+ // llvm::errs() << " live: " << arg << " (arg " << arg.getArgNumber() << " in region " << arg.getOwner()->getParent()->getRegionNumber() << ")\n";
+ // }
+ //}
+
+ for (Value input : successorInputs) {
+ changedLiveSet |= valuesToKeep.insert(input).second;
+ }
+ OpOperand *opr = opOperand;
+ Value v = opr->get();
+ changedLiveSet |= valuesToKeep.insert(v).second;
+ }
- // Compute the operands of `regionBranchOp` that we want to keep. We
- // definitely want to keep the non-forwarded operands of the region branch
- // op. And potentially some forwarded operands (to be determined later).
- BitVector operandsToKeep;
- operandsToKeep.resize(regionBranchOp->getNumOperands(), true);
- for (const RegionSuccessor &successor :
- getSuccessors(RegionBranchPoint::parent())) {
- for (OpOperand *opOperand : getForwardedOpOperands(successor))
- operandsToKeep.reset(opOperand->getOperandNumber());
+ return changedLiveSet;
+ };
+
+ // Iterate to fixed point. TODO: Add example.
+ while (forAllRegionBranchPoints(fixupLiveSet)) {}
+
+ llvm::errs() << "*** FIXED UP VALUES TO KEEP: \n";
+ for (Value value : valuesToKeep) {
+ if (auto arg = dyn_cast<BlockArgument>(value)) {
+ llvm::errs() << " keep: " << arg << " (arg " << arg.getArgNumber() << " in region " << arg.getOwner()->getParent()->getRegionNumber() << ")\n";
+ }
}
- // Compute the operands of each terminator that we want to keep. We definitely
- // want to keep the non-forwarded terminator operands. And potentially some
- // forwarded terminator operands (to be determined later).
- DenseMap<Operation *, BitVector> terminatorOperandsToKeep;
+ // Now we know which values are needed. Compute the operands that must be
+ // kept. First initialize all operands to "true".
+ DenseMap<Operation *, BitVector> operandsToKeep;
+ operandsToKeep[regionBranchOp.getOperation()] = BitVector(regionBranchOp->getNumOperands(), true);
for (Region ®ion : regionBranchOp->getRegions()) {
if (region.empty())
continue;
- // TODO: this isn't correct in face of multiple terminators.
- Operation *terminator = region.front().getTerminator();
- terminatorOperandsToKeep[terminator] =
- BitVector(terminator->getNumOperands(), true);
- for (const RegionSuccessor &successor : getSuccessors(RegionBranchPoint(
- cast<RegionBranchTerminatorOpInterface>(terminator)))) {
- for (OpOperand *opOperand :
- getForwardedOpOperands(successor, terminator)) {
- terminatorOperandsToKeep[terminator].reset(
- opOperand->getOperandNumber());
- llvm::errs() << " reset terminator " << terminator->getName()
- << " operand: " << opOperand->getOperandNumber()
- << " to false\n";
+ operandsToKeep[region.front().getTerminator()] = BitVector(region.front().getTerminator()->getNumOperands(), true);
+ }
+ auto computeOperandsToKeep = [&](RegionBranchPoint point) {
+ // Mapping of operands to successor inputs. An operand can be forwarded to
+ // multiple successors.
+ DenseMap<OpOperand *, SmallVector<Value>> operandToSuccessorInputs;
+ SmallVector<RegionSuccessor> successors;
+ regionBranchOp.getSuccessorRegions(point, successors);
+ for (const RegionSuccessor &successor : successors) {
+ // Handle branch from point --> successor.
+ ValueRange argsOrResults = successor.getSuccessorInputs();
+ OperandRange operands =
+ point.isParent() ? regionBranchOp.getEntrySuccessorOperands(successor) : cast<RegionBranchTerminatorOpInterface>(point.getTerminatorPredecessorOrNull())
+ .getSuccessorOperands(successor);
+ assert(argsOrResults.size() == operands.size() && "expected the same number of successor inputs as forwarded operands");
+
+ for (auto [opOperand, input] : llvm::zip(operandsToOpOperands(operands), argsOrResults)) {
+ operandToSuccessorInputs[opOperand].push_back(input);
}
}
- }
-
- // Update `valuesToKeep` (which is expected to correspond to operands or
- // terminator operands) based on `resultsToKeep` and `argsToKeep`, given
- // `region`. When `valuesToKeep` correspond to operands, `region` is null.
- // Else, `region` is the parent region of the terminator.
- auto updateOperandsOrTerminatorOperandsToKeep =
- [&](BitVector &valuesToKeep, BitVector &resultsToKeep,
- DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) {
- Operation *terminator =
- region ? region->front().getTerminator() : nullptr;
- RegionBranchPoint point =
- terminator
- ? RegionBranchPoint(
- cast<RegionBranchTerminatorOpInterface>(terminator))
- : RegionBranchPoint::parent();
-
- for (const RegionSuccessor &successor : getSuccessors(point)) {
- Region *successorRegion = successor.getSuccessor();
- for (auto [opOperand, input] :
- llvm::zip(getForwardedOpOperands(successor, terminator),
- successor.getSuccessorInputs())) {
- size_t operandNum = opOperand->getOperandNumber();
- bool updateBasedOn =
- successorRegion
- ? argsToKeep[successorRegion]
- [cast<BlockArgument>(input).getArgNumber()]
- : resultsToKeep[cast<OpResult>(input).getResultNumber()];
- valuesToKeep[operandNum] = valuesToKeep[operandNum] | updateBasedOn;
- }
- }
- };
-
- // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep` and
- // `terminatorOperandsToKeep`. Store true in `resultsOrArgsToKeepChanged` if a
- // value is modified, else, false.
- auto recomputeResultsAndArgsToKeep =
- [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
- BitVector &operandsToKeep,
- DenseMap<Operation *, BitVector> &terminatorOperandsToKeep,
- bool &resultsOrArgsToKeepChanged) {
- resultsOrArgsToKeepChanged = false;
-
- // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`.
- for (const RegionSuccessor &successor :
- getSuccessors(RegionBranchPoint::parent())) {
- Region *successorRegion = successor.getSuccessor();
- for (auto [opOperand, input] :
- llvm::zip(getForwardedOpOperands(successor),
- successor.getSuccessorInputs())) {
- bool recomputeBasedOn =
- operandsToKeep[opOperand->getOperandNumber()];
- bool toRecompute =
- successorRegion
- ? argsToKeep[successorRegion]
- [cast<BlockArgument>(input).getArgNumber()]
- : resultsToKeep[cast<OpResult>(input).getResultNumber()];
- if (!toRecompute && recomputeBasedOn)
- resultsOrArgsToKeepChanged = true;
- if (successorRegion) {
- argsToKeep[successorRegion][cast<BlockArgument>(input)
- .getArgNumber()] =
- argsToKeep[successorRegion]
- [cast<BlockArgument>(input).getArgNumber()] |
- recomputeBasedOn;
- } else {
- resultsToKeep[cast<OpResult>(input).getResultNumber()] =
- resultsToKeep[cast<OpResult>(input).getResultNumber()] |
- recomputeBasedOn;
- }
- }
- }
-
- // Recompute `resultsToKeep` and `argsToKeep` based on
- // `terminatorOperandsToKeep`.
- for (Region ®ion : regionBranchOp->getRegions()) {
- if (region.empty())
- continue;
- Operation *terminator = region.front().getTerminator();
- for (const RegionSuccessor &successor :
- getSuccessors(RegionBranchPoint(
- cast<RegionBranchTerminatorOpInterface>(terminator)))) {
- Region *successorRegion = successor.getSuccessor();
- for (auto [opOperand, input] :
- llvm::zip(getForwardedOpOperands(successor, terminator),
- successor.getSuccessorInputs())) {
- bool recomputeBasedOn =
- terminatorOperandsToKeep[region.back().getTerminator()]
- [opOperand->getOperandNumber()];
- bool toRecompute =
- successorRegion
- ? argsToKeep[successorRegion]
- [cast<BlockArgument>(input).getArgNumber()]
- : resultsToKeep[cast<OpResult>(input).getResultNumber()];
- if (!toRecompute && recomputeBasedOn)
- resultsOrArgsToKeepChanged = true;
- if (successorRegion) {
- argsToKeep[successorRegion][cast<BlockArgument>(input)
- .getArgNumber()] =
- argsToKeep[successorRegion]
- [cast<BlockArgument>(input).getArgNumber()] |
- recomputeBasedOn;
- } else {
- resultsToKeep[cast<OpResult>(input).getResultNumber()] =
- resultsToKeep[cast<OpResult>(input).getResultNumber()] |
- recomputeBasedOn;
- }
- }
- }
- }
- };
-
- // Mark the values that we want to keep in `resultsToKeep`, `argsToKeep`,
- // `operandsToKeep`, and `terminatorOperandsToKeep`.
- auto markValuesToKeep =
- [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
- BitVector &operandsToKeep,
- DenseMap<Operation *, BitVector> &terminatorOperandsToKeep) {
- bool resultsOrArgsToKeepChanged = true;
- // We keep updating and recomputing the values until we reach a point
- // where they stop changing.
- while (resultsOrArgsToKeepChanged) {
- // Update the operands that need to be kept.
- updateOperandsOrTerminatorOperandsToKeep(operandsToKeep,
- resultsToKeep, argsToKeep);
-
- // Update the terminator operands that need to be kept.
- for (Region ®ion : regionBranchOp->getRegions()) {
- if (region.empty())
- continue;
- updateOperandsOrTerminatorOperandsToKeep(
- terminatorOperandsToKeep[region.back().getTerminator()],
- resultsToKeep, argsToKeep, ®ion);
- }
- // Recompute the results and arguments that need to be kept.
- recomputeResultsAndArgsToKeep(
- resultsToKeep, argsToKeep, operandsToKeep,
- terminatorOperandsToKeep, resultsOrArgsToKeepChanged);
- }
- };
-
- // Scenario 2.
- // At this point, we know that every non-forwarded operand of `regionBranchOp`
- // is live.
-
- // Initializing the above variables...
-
- // Mark the values (results, arguments, operands, and terminator operands)
- // that we want to keep.
- markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep,
- terminatorOperandsToKeep);
+ for (const auto &[opOperand, successorInputs] : operandToSuccessorInputs) {
+ bool anyAlive = llvm::any_of(successorInputs, [&](Value input) { return valuesToKeep.contains(input); });
+ if (!anyAlive) {
+ // This is a forwarded operand but none of its successors is alive.
+ operandsToKeep[opOperand->getOwner()].reset(opOperand->getOperandNumber());
+ }
+ }
+ return false;
+ };
+ forAllRegionBranchPoints(computeOperandsToKeep);
- // Do (1).
- cl.operands.push_back({regionBranchOp, operandsToKeep.flip()});
+ DenseMap<Region *, BitVector> argsToKeep;
+ for (Region ®ion : regionBranchOp->getRegions()) {
+ if (region.empty())
+ continue;
+ argsToKeep[®ion] = BitVector(region.front().getArguments().size(), false);
+ for (BlockArgument argument : region.front().getArguments())
+ if (valuesToKeep.contains(argument))
+ argsToKeep[®ion].set(argument.getArgNumber());
+ }
+ BitVector resultsToKeep(regionBranchOp->getNumResults(), false);
+ for (OpResult result : regionBranchOp->getResults())
+ if (valuesToKeep.contains(result))
+ resultsToKeep.set(result.getResultNumber());
+
+ // Store results in `cl`.
+ for (auto [op, vec] : operandsToKeep) {
+ cl.operands.push_back({op, vec.flip()});
+ }
// Do (2.a) and (2.b).
for (Region ®ion : regionBranchOp->getRegions()) {
@@ -691,23 +609,6 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
argsToRemove);
}
- // Do (2.c).
- for (Region ®ion : regionBranchOp->getRegions()) {
- if (region.empty())
- continue;
- Operation *terminator = region.front().getTerminator();
- llvm::errs() << "(2.c) terminator: "
- << OpWithFlags(
- terminator,
- OpPrintingFlags().skipRegions().printGenericOpForm())
- << "\n";
- for (auto i : terminatorOperandsToKeep[terminator].set_bits()) {
- llvm::errs() << " keep operand: " << i << "\n";
- }
- cl.operands.push_back(
- {terminator, terminatorOperandsToKeep[terminator].flip()});
- }
-
// Do (3) and (4).
BitVector resultsToRemove = resultsToKeep.flip();
collectNonLiveValues(nonLiveSet, regionBranchOp.getOperation()->getResults(),
>From 88c7c3a85f60f10d915369937fd3c67754c9a96e Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 24 Dec 2025 18:50:35 +0000
Subject: [PATCH 3/3] draft: do not erase IR, just replace uses
---
mlir/lib/Transforms/RemoveDeadValues.cpp | 161 ++++---------------
mlir/test/Transforms/remove-dead-values.mlir | 111 +++++++------
2 files changed, 95 insertions(+), 177 deletions(-)
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 242b01d797966..e0355a1c197bb 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -434,14 +434,9 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
RDVFinalCleanupList &cl) {
- llvm::errs() << "=== Processing region branch op === "
- << OpWithFlags(
- regionBranchOp,
- OpPrintingFlags().skipRegions().printGenericOpForm())
- << "\n";
- // LDBG() << "Processing region branch op: "
- // << OpWithFlags(regionBranchOp,
- // OpPrintingFlags().skipRegions().printGenericOpForm());
+ LDBG() << "Processing region branch op: "
+ << OpWithFlags(regionBranchOp,
+ OpPrintingFlags().skipRegions().printGenericOpForm());
// Scenario 1. This is the only case where the entire `regionBranchOp`
// is removed. It will not happen in any other scenario. Note that in this
@@ -462,36 +457,28 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
valuesToKeep.insert(result);
}
for (Region ®ion : regionBranchOp->getRegions()) {
+ if (region.empty())
+ continue;
for (Value arg : region.front().getArguments()) {
if (hasLive(arg, nonLiveSet, la))
valuesToKeep.insert(arg);
}
}
- llvm::errs() << "*** INITIAL VALUES TO KEEP: \n";
- for (Value value : valuesToKeep) {
- if (auto arg = dyn_cast<BlockArgument>(value)) {
- llvm::errs() << " keep: " << arg << " (arg " << arg.getArgNumber() << " in region " << arg.getOwner()->getParent()->getRegionNumber() << ")\n";
- }
- }
-
- auto forAllRegionBranchPoints = [&](auto callback) -> bool {
- bool result = false;
- result |= callback(RegionBranchPoint::parent());
+ auto forAllRegionBranchPoints = [&](auto callback) {
+ callback(RegionBranchPoint::parent());
for (Region ®ion : regionBranchOp->getRegions()) {
if (region.empty())
continue;
- result |= callback(RegionBranchPoint(cast<RegionBranchTerminatorOpInterface>(region.front().getTerminator())));
+ callback(RegionBranchPoint(cast<RegionBranchTerminatorOpInterface>(
+ region.front().getTerminator())));
}
- return result;
};
- // TODO: Iterate to fixed point?
- auto fixupLiveSet = [&](RegionBranchPoint point) {
- bool changedLiveSet = false;
- // Mapping of operands to successor inputs. An operand can be forwarded to
- // multiple successors.
- DenseMap<OpOperand *, SmallVector<Value>> operandToSuccessorInputs;
+ // Mapping from operands to forwarded successor inputs. An operand can be
+ // forwarded to multiple successors.
+ DenseMap<OpOperand *, SmallVector<Value>> operandToSuccessorInputs;
+ auto helper = [&](RegionBranchPoint point) {
SmallVector<RegionSuccessor> successors;
regionBranchOp.getSuccessorRegions(point, successors);
for (const RegionSuccessor &successor : successors) {
@@ -502,120 +489,31 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
.getSuccessorOperands(successor);
assert(argsOrResults.size() == operands.size() && "expected the same number of successor inputs as forwarded operands");
- for (auto [opOperand, input] : llvm::zip(operandsToOpOperands(operands), argsOrResults)) {
+ for (auto [opOperand, input] :
+ llvm::zip_equal(operandsToOpOperands(operands), argsOrResults)) {
operandToSuccessorInputs[opOperand].push_back(input);
}
}
-
- for (auto [opOperand, successorInputs] : operandToSuccessorInputs) {
- // If one of the successor inputs is live, the respective operand must be
- // kept. In that case, all matching successor inputs must be kept.
- bool anyAlive = llvm::any_of(successorInputs, [&](Value input) { return valuesToKeep.contains(input); });
- if (!anyAlive) continue;
-
- //llvm::errs() << "LIVE SET:\n";
- //for (Value value : successorInputs) {
- // if (auto arg = dyn_cast<BlockArgument>(value)) {
- // llvm::errs() << " live: " << arg << " (arg " << arg.getArgNumber() << " in region " << arg.getOwner()->getParent()->getRegionNumber() << ")\n";
- // }
- //}
-
- for (Value input : successorInputs) {
- changedLiveSet |= valuesToKeep.insert(input).second;
- }
- OpOperand *opr = opOperand;
- Value v = opr->get();
- changedLiveSet |= valuesToKeep.insert(v).second;
- }
-
- return changedLiveSet;
};
-
- // Iterate to fixed point. TODO: Add example.
- while (forAllRegionBranchPoints(fixupLiveSet)) {}
-
- llvm::errs() << "*** FIXED UP VALUES TO KEEP: \n";
- for (Value value : valuesToKeep) {
- if (auto arg = dyn_cast<BlockArgument>(value)) {
- llvm::errs() << " keep: " << arg << " (arg " << arg.getArgNumber() << " in region " << arg.getOwner()->getParent()->getRegionNumber() << ")\n";
- }
- }
-
- // Now we know which values are needed. Compute the operands that must be
- // kept. First initialize all operands to "true".
- DenseMap<Operation *, BitVector> operandsToKeep;
- operandsToKeep[regionBranchOp.getOperation()] = BitVector(regionBranchOp->getNumOperands(), true);
- for (Region ®ion : regionBranchOp->getRegions()) {
- if (region.empty())
- continue;
- operandsToKeep[region.front().getTerminator()] = BitVector(region.front().getTerminator()->getNumOperands(), true);
- }
- auto computeOperandsToKeep = [&](RegionBranchPoint point) {
- // Mapping of operands to successor inputs. An operand can be forwarded to
- // multiple successors.
- DenseMap<OpOperand *, SmallVector<Value>> operandToSuccessorInputs;
- SmallVector<RegionSuccessor> successors;
- regionBranchOp.getSuccessorRegions(point, successors);
- for (const RegionSuccessor &successor : successors) {
- // Handle branch from point --> successor.
- ValueRange argsOrResults = successor.getSuccessorInputs();
- OperandRange operands =
- point.isParent() ? regionBranchOp.getEntrySuccessorOperands(successor) : cast<RegionBranchTerminatorOpInterface>(point.getTerminatorPredecessorOrNull())
- .getSuccessorOperands(successor);
- assert(argsOrResults.size() == operands.size() && "expected the same number of successor inputs as forwarded operands");
-
- for (auto [opOperand, input] : llvm::zip(operandsToOpOperands(operands), argsOrResults)) {
- operandToSuccessorInputs[opOperand].push_back(input);
- }
- }
- for (const auto &[opOperand, successorInputs] : operandToSuccessorInputs) {
- bool anyAlive = llvm::any_of(successorInputs, [&](Value input) { return valuesToKeep.contains(input); });
- if (!anyAlive) {
- // This is a forwarded operand but none of its successors is alive.
- operandsToKeep[opOperand->getOwner()].reset(opOperand->getOperandNumber());
- }
- }
- return false;
- };
- forAllRegionBranchPoints(computeOperandsToKeep);
+ // Iterate to fixed point. TODO: Add example.
+ forAllRegionBranchPoints(helper);
- DenseMap<Region *, BitVector> argsToKeep;
- for (Region ®ion : regionBranchOp->getRegions()) {
- if (region.empty())
+ for (auto [opOperand, successorInputs] : operandToSuccessorInputs) {
+ // If one of the successor inputs is live, the respective operand must be
+ // kept. In that case, all matching successor inputs must be kept.
+ bool anyAlive = llvm::any_of(successorInputs, [&](Value input) {
+ return valuesToKeep.contains(input);
+ });
+ if (anyAlive)
continue;
- argsToKeep[®ion] = BitVector(region.front().getArguments().size(), false);
- for (BlockArgument argument : region.front().getArguments())
- if (valuesToKeep.contains(argument))
- argsToKeep[®ion].set(argument.getArgNumber());
- }
- BitVector resultsToKeep(regionBranchOp->getNumResults(), false);
- for (OpResult result : regionBranchOp->getResults())
- if (valuesToKeep.contains(result))
- resultsToKeep.set(result.getResultNumber());
-
- // Store results in `cl`.
- for (auto [op, vec] : operandsToKeep) {
- cl.operands.push_back({op, vec.flip()});
- }
- // Do (2.a) and (2.b).
- for (Region ®ion : regionBranchOp->getRegions()) {
- if (region.empty())
- continue;
- BitVector argsToRemove = argsToKeep[®ion].flip();
- cl.blocks.push_back({®ion.front(), argsToRemove});
- collectNonLiveValues(nonLiveSet, region.front().getArguments(),
- argsToRemove);
+ // All successor inputs are dead: ub.poison can be passed as operand.
+ IRRewriter rewriter(opOperand->getOwner());
+ auto undefVal = ub::PoisonOp::create(rewriter, opOperand->get().getLoc(),
+ opOperand->get().getType());
+ opOperand->set(undefVal);
}
-
- // Do (3) and (4).
- BitVector resultsToRemove = resultsToKeep.flip();
- collectNonLiveValues(nonLiveSet, regionBranchOp.getOperation()->getResults(),
- resultsToRemove);
- cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove});
-
- llvm::errs() << "=== DONE PROCESSING REGION BRANCH OP ===\n";
}
/// Steps to process a `BranchOpInterface` operation:
@@ -879,7 +777,6 @@ void RemoveDeadValues::runOnOperation() {
});
cleanUpDeadVals(finalCleanupList);
- getOperation()->dump();
}
std::unique_ptr<Pass> mlir::createRemoveDeadValuesPass() {
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 71306676d48e9..32acf518f9871 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -60,14 +60,14 @@ func.func @cleanable_loop_iter_args_value(%arg0: index) -> index {
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
%non_live = arith.constant 0 : index
- // CHECK: [[RESULT:%.+]] = scf.for [[ARG_1:%.*]] = %c0 to %c10 step %c1 iter_args([[ARG_2:%.*]] = %arg0) -> (index) {
+ // CaHECK: [[RESULT:%.+]] = scf.for [[ARG_1:%.*]] = %c0 to %c10 step %c1 iter_args([[ARG_2:%.*]] = %arg0) -> (index) {
%result, %result_non_live = scf.for %i = %c0 to %c10 step %c1 iter_args(%live_arg = %arg0, %non_live_arg = %non_live) -> (index, index) {
- // CHECK: [[SUM:%.+]] = arith.addi [[ARG_2]], [[ARG_1]] : index
+ // CaHECK: [[SUM:%.+]] = arith.addi [[ARG_2]], [[ARG_1]] : index
%new_live = arith.addi %live_arg, %i : index
- // CHECK: scf.yield [[SUM:%.+]]
+ // CaHECK: scf.yield [[SUM:%.+]]
scf.yield %new_live, %non_live_arg : index, index
}
- // CHECK: return [[RESULT]] : index
+ // CaHECK: return [[RESULT]] : index
return %result : index
}
@@ -79,7 +79,8 @@ func.func @cleanable_loop_iter_args_value(%arg0: index) -> index {
#map = affine_map<(d0, d1, d2) -> (0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
- func.func @main() {
+ // CHECK-LABEL: @dead_linalg_generic
+ func.func @dead_linalg_generic() {
%cst_3 = arith.constant dense<54> : tensor<1x25x13xi32>
%cst_7 = arith.constant dense<11> : tensor<1x25x13xi32>
// CHECK-NOT: arith.constant
@@ -230,17 +231,17 @@ func.func @main() -> (i32, i32) {
//
// Note that this cleanup cannot be done by the `canonicalize` pass.
//
-// CHECK: func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 {
-// CHECK-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) {
-// CHECK-NEXT: %[[live_0:.*]] = arith.addi %[[arg4]], %[[arg4]]
-// CHECK-NEXT: scf.condition(%arg0) %[[live_0]], %[[arg4]] : i32, i32
-// CHECK-NEXT: } do {
-// CHECK-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
-// CHECK-NEXT: %[[live_1:.*]] = arith.addi %[[arg6]], %[[arg6]]
-// CHECK-NEXT: scf.yield %[[live_1]] : i32
-// CHECK-NEXT: }
-// CHECK-NEXT: return %[[live_and_non_live]]#0
-// CHECK-NEXT: }
+// CaHECK: func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 {
+// CaHECK-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) {
+// CaHECK-NEXT: %[[live_0:.*]] = arith.addi %[[arg4]], %[[arg4]]
+// CaHECK-NEXT: scf.condition(%arg0) %[[live_0]], %[[arg4]] : i32, i32
+// CaHECK-NEXT: } do {
+// CaHECK-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
+// CaHECK-NEXT: %[[live_1:.*]] = arith.addi %[[arg6]], %[[arg6]]
+// CaHECK-NEXT: scf.yield %[[live_1]] : i32
+// CaHECK-NEXT: }
+// CaHECK-NEXT: return %[[live_and_non_live]]#0
+// CaHECK-NEXT: }
func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%arg0: i1, %arg1: i32, %arg2: i32) -> (i32) {
%live, %non_live, %non_live_0 = scf.while (%arg3 = %arg1, %arg4 = %arg2) : (i32, i32) -> (i32, i32, i32) {
%live_0 = arith.addi %arg4, %arg4 : i32
@@ -284,21 +285,21 @@ func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_o
//
// Note that this cleanup cannot be done by the `canonicalize` pass.
//
-// CHECK: func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%[[arg2:.*]]: i1) -> i32 {
-// CHECK-NEXT: %[[c0:.*]] = arith.constant 0
-// CHECK-NEXT: %[[c1:.*]] = arith.constant 1
-// CHECK-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) {
-// CHECK-NEXT: func.call @identity() : () -> ()
-// CHECK-NEXT: scf.condition(%[[arg2]]) %[[arg4]], %[[arg3]] : i32, i32
-// CHECK-NEXT: } do {
-// CHECK-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
-// CHECK-NEXT: scf.yield %[[arg5]], %[[arg6]] : i32, i32
-// CHECK-NEXT: }
-// CHECK-NEXT: return %[[live_and_non_live]]#0 : i32
-// CHECK-NEXT: }
-// CHECK: func.func private @identity() {
-// CHECK-NEXT: return
-// CHECK-NEXT: }
+// CaHECK: func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%[[arg2:.*]]: i1) -> i32 {
+// CaHECK-NEXT: %[[c0:.*]] = arith.constant 0
+// CaHECK-NEXT: %[[c1:.*]] = arith.constant 1
+// CaHECK-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) {
+// CaHECK-NEXT: func.call @identity() : () -> ()
+// CaHECK-NEXT: scf.condition(%[[arg2]]) %[[arg4]], %[[arg3]] : i32, i32
+// CaHECK-NEXT: } do {
+// CaHECK-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
+// CaHECK-NEXT: scf.yield %[[arg5]], %[[arg6]] : i32, i32
+// CaHECK-NEXT: }
+// CaHECK-NEXT: return %[[live_and_non_live]]#0 : i32
+// CaHECK-NEXT: }
+// CaHECK: func.func private @identity() {
+// CaHECK-NEXT: return
+// CaHECK-NEXT: }
func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%arg2: i1) -> (i32) {
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
@@ -325,17 +326,17 @@ func.func private @identity(%arg1 : i32) -> (i32) {
//
// Note that this cleanup cannot be done by the `canonicalize` pass.
//
-// CHECK: func.func @clean_region_branch_op_remove_result(%[[arg0:.*]]: index, %[[arg1:.*]]: memref<i32>) {
-// CHECK-NEXT: scf.index_switch %[[arg0]]
-// CHECK-NEXT: case 1 {
-// CHECK-NEXT: %[[c10:.*]] = arith.constant 10
-// CHECK-NEXT: memref.store %[[c10]], %[[arg1]][]
-// CHECK-NEXT: scf.yield
-// CHECK-NEXT: }
-// CHECK-NEXT: default {
-// CHECK-NEXT: }
-// CHECK-NEXT: return
-// CHECK-NEXT: }
+// CaHECK: func.func @clean_region_branch_op_remove_result(%[[arg0:.*]]: index, %[[arg1:.*]]: memref<i32>) {
+// CaHECK-NEXT: scf.index_switch %[[arg0]]
+// CaHECK-NEXT: case 1 {
+// CaHECK-NEXT: %[[c10:.*]] = arith.constant 10
+// CaHECK-NEXT: memref.store %[[c10]], %[[arg1]][]
+// CaHECK-NEXT: scf.yield
+// CaHECK-NEXT: }
+// CaHECK-NEXT: default {
+// CaHECK-NEXT: }
+// CaHECK-NEXT: return
+// CaHECK-NEXT: }
func.func @clean_region_branch_op_remove_result(%arg0 : index, %arg1 : memref<i32>) {
%non_live = scf.index_switch %arg0 -> i32
case 1 {
@@ -540,9 +541,9 @@ module {
}
// CHECK-LABEL: func @test_zero_operands
-// CHECK: memref.alloca_scope
-// CHECK: memref.store
-// CHECK-NOT: memref.alloca_scope.return
+// CaHECK: memref.alloca_scope
+// CaHECK: memref.store
+// CaHECK-NOT: memref.alloca_scope.return
// -----
@@ -714,3 +715,23 @@ func.func private @remove_dead_branch_op(%c: i1, %arg0: i64, %arg1: i64) -> (i64
^bb2:
return %arg1 : i64
}
+
+// -----
+
+func.func @scf_while_dead_iter_args() -> i32 {
+ %c5 = arith.constant 5 : i32
+ %false = arith.constant false
+ %result:2 = scf.while (%arg0 = %c5) : (i32) -> (i32, i32) {
+ vector.print %arg0 : i32
+ %cmp2 = arith.cmpi slt, %arg0, %c5 : i32
+ scf.condition(%cmp2) {tag = "scf.condition"} %arg0, %arg0 : i32, i32
+ } do {
+ ^bb0(%arg1: i32, %arg2: i32):
+ %x = scf.execute_region -> i32 {
+ scf.yield %arg2 : i32
+ }
+ // TODO: not working yet when yielding %x instead of %arg2.
+ scf.yield {tag = "scf.yield"} %arg2 : i32
+ } attributes {tag = "scf.while"}
+ return %result#0 : i32
+}
More information about the Mlir-commits
mailing list