[Mlir-commits] [mlir] [mlir][Transforms] `remove-dead-values`: Rely on canonicalizer for region simplification (PR #173505)
Matthias Springer
llvmlistbot at llvm.org
Thu Jan 1 10:16:23 PST 2026
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/173505
>From c6ed9f774e860d8cfe42449595bb133bd72b1748 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 1 Jan 2026 16:40:56 +0000
Subject: [PATCH 1/2] [mlir][IR][NFC] Add `RewriterBase::eraseOpResults`
convenience helper
---
mlir/include/mlir/IR/PatternMatch.h | 6 +
mlir/lib/Dialect/SCF/IR/SCF.cpp | 196 +++++++----------------
mlir/lib/IR/PatternMatch.cpp | 36 +++++
mlir/lib/Transforms/RemoveDeadValues.cpp | 31 +---
4 files changed, 105 insertions(+), 164 deletions(-)
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 1caab24ac7295..83477c79ff582 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -545,6 +545,12 @@ class RewriterBase : public OpBuilder {
/// This method erases all operations in a block.
virtual void eraseBlock(Block *block);
+ /// Erase the specified results of the given operation. Results cannot be
+ /// erased directly, so the implementation creates a new replacement
+ /// operation and erases the original operation. The new operation is
+ /// returned.
+ Operation *eraseOpResults(Operation *op, const BitVector &eraseIndices);
+
/// Inline the operations of block 'source' into block 'dest' before the given
/// position. The source block will be deleted and must have no uses.
/// 'argValues' is used to replace the block arguments of 'source'.
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 0a123112cf68f..d4e341416fd1b 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1763,89 +1763,55 @@ struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
// ForallOp::getCombiningOps(iter_arg).
//
// Based on the check we maintain the following :-
- // a. `resultToDelete` - i-th result of scf.forall that'll be
- // deleted.
- // b. `resultToReplace` - i-th result of the old scf.forall
- // whose uses will be replaced by the new scf.forall.
- // c. `newOuts` - the shared_outs' operand of the new scf.forall
- // corresponding to the i-th result with at least one use.
- SetVector<OpResult> resultToDelete;
- SmallVector<Value> resultToReplace;
+ // a. op results, block arguments, outputs to delete
+ // b. new outputs (i.e., outputs to retain)
+ SmallVector<Value> resultsToDelete;
+ SmallVector<Value> outsToDelete;
+ SmallVector<BlockArgument> blockArgsToDelete;
SmallVector<Value> newOuts;
+ BitVector resultIndicesToDelete(forallOp.getNumResults(), false);
+ BitVector blockIndicesToDelete(forallOp.getBody()->getNumArguments(),
+ false);
for (OpResult result : forallOp.getResults()) {
OpOperand *opOperand = forallOp.getTiedOpOperand(result);
BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
- resultToDelete.insert(result);
+ resultsToDelete.push_back(result);
+ outsToDelete.push_back(opOperand->get());
+ blockArgsToDelete.push_back(blockArg);
+ resultIndicesToDelete[result.getResultNumber()] = true;
+ blockIndicesToDelete[blockArg.getArgNumber()] = true;
} else {
- resultToReplace.push_back(result);
newOuts.push_back(opOperand->get());
}
}
// Return early if all results of scf.forall have at least one use and being
// modified within the loop.
- if (resultToDelete.empty())
+ if (resultsToDelete.empty())
return failure();
- // Step 2: For the the i-th result, do the following :-
- // a. Fetch the corresponding BlockArgument.
- // b. Look for store ops (currently tensor.parallel_insert_slice)
- // with the BlockArgument as its destination operand.
- // c. Remove the operations fetched in b.
- for (OpResult result : resultToDelete) {
- OpOperand *opOperand = forallOp.getTiedOpOperand(result);
- BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
+ // Step 2: Erase combining ops and replace uses of deleted results and
+ // block arguments with the corresponding outputs.
+ for (auto blockArg : blockArgsToDelete) {
SmallVector<Operation *> combiningOps =
forallOp.getCombiningOps(blockArg);
for (Operation *combiningOp : combiningOps)
rewriter.eraseOp(combiningOp);
}
-
- // Step 3. Create a new scf.forall op with the new shared_outs' operands
- // fetched earlier
- auto newForallOp = scf::ForallOp::create(
- rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
- forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
- forallOp.getMapping(),
- /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
-
- // Step 4. Merge the block of the old scf.forall into the newly created
- // scf.forall using the new set of arguments.
- Block *loopBody = forallOp.getBody();
- Block *newLoopBody = newForallOp.getBody();
- ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
- // Form initial new bbArg list with just the control operands of the new
- // scf.forall op.
- SmallVector<Value> newBlockArgs =
- llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
- [](BlockArgument b) -> Value { return b; });
- Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
- unsigned index = 0;
- // Take the new corresponding bbArg if the old bbArg was used as a
- // destination in the in_parallel op. For all other bbArgs, use the
- // corresponding init_arg from the old scf.forall op.
- for (OpResult result : forallOp.getResults()) {
- if (resultToDelete.count(result)) {
- newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
- } else {
- newBlockArgs.push_back(newSharedOutsArgs[index++]);
- }
+ for (auto [blockArg, out, result] :
+ llvm::zip_equal(blockArgsToDelete, outsToDelete, resultsToDelete)) {
+ rewriter.replaceAllUsesWith(blockArg, out);
+ rewriter.replaceAllUsesWith(result, out);
}
- rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
-
- // Step 5. Replace the uses of result of old scf.forall with that of the new
- // scf.forall.
- for (auto &&[oldResult, newResult] :
- llvm::zip(resultToReplace, newForallOp->getResults()))
- rewriter.replaceAllUsesWith(oldResult, newResult);
-
- // Step 6. Replace the uses of those values that either has no use or are
- // not being modified within the loop with the corresponding
- // OpOperand.
- for (OpResult oldResult : resultToDelete)
- rewriter.replaceAllUsesWith(oldResult,
- forallOp.getTiedOpOperand(oldResult)->get());
+ forallOp.getBody()->eraseArguments(blockIndicesToDelete);
+
+ // Step 3. Create a new scf.forall op with only the shared_outs/results
+ // that should be retained.
+ auto newForallOp = cast<scf::ForallOp>(
+ rewriter.eraseOpResults(forallOp, resultIndicesToDelete));
+ newForallOp.getOutputsMutable().assign(newOuts);
+
return success();
}
};
@@ -2413,53 +2379,27 @@ namespace {
struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
using OpRewritePattern<IfOp>::OpRewritePattern;
- void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
- PatternRewriter &rewriter) const {
- // Move all operations to the destination block.
- rewriter.mergeBlocks(source, dest);
- // Replace the yield op by one that returns only the used values.
- auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
- SmallVector<Value, 4> usedOperands;
- llvm::transform(usedResults, std::back_inserter(usedOperands),
- [&](OpResult result) {
- return yieldOp.getOperand(result.getResultNumber());
- });
- rewriter.modifyOpInPlace(yieldOp,
- [&]() { yieldOp->setOperands(usedOperands); });
- }
-
LogicalResult matchAndRewrite(IfOp op,
PatternRewriter &rewriter) const override {
- // Compute the list of used results.
- SmallVector<OpResult, 4> usedResults;
- llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
- [](OpResult result) { return !result.use_empty(); });
-
- // Replace the operation if only a subset of its results have uses.
- if (usedResults.size() == op.getNumResults())
- return failure();
-
- // Compute the result types of the replacement operation.
- SmallVector<Type, 4> newTypes;
- llvm::transform(usedResults, std::back_inserter(newTypes),
- [](OpResult result) { return result.getType(); });
+ // Compute the list of unused results.
+ BitVector toErase(op.getNumResults(), false);
+ for (auto [idx, result] : llvm::enumerate(op.getResults()))
+ if (result.use_empty())
+ toErase[idx] = true;
+ if (toErase.none())
+ return rewriter.notifyMatchFailure(op, "no results to erase");
+
+ // Erase results.
+ auto newOp = cast<scf::IfOp>(rewriter.eraseOpResults(op, toErase));
+
+ // Erase operands.
+ rewriter.modifyOpInPlace(newOp.thenYield(), [&]() {
+ newOp.thenYield()->eraseOperands(toErase);
+ });
+ rewriter.modifyOpInPlace(newOp.elseYield(), [&]() {
+ newOp.elseYield()->eraseOperands(toErase);
+ });
- // Create a replacement operation with empty then and else regions.
- auto newOp =
- IfOp::create(rewriter, op.getLoc(), newTypes, op.getCondition());
- rewriter.createBlock(&newOp.getThenRegion());
- rewriter.createBlock(&newOp.getElseRegion());
-
- // Move the bodies and replace the terminators (note there is a then and
- // an else region since the operation returns results).
- transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
- transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
-
- // Replace the operation by the new one.
- SmallVector<Value, 4> repResults(op.getNumResults());
- for (const auto &en : llvm::enumerate(usedResults))
- repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
- rewriter.replaceOp(op, repResults);
return success();
}
};
@@ -4720,43 +4660,27 @@ struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> {
PatternRewriter &rewriter) const override {
// Find dead results.
BitVector deadResults(op.getNumResults(), false);
- SmallVector<Type> newResultTypes;
- for (auto [idx, result] : llvm::enumerate(op.getResults())) {
- if (!result.use_empty()) {
- newResultTypes.push_back(result.getType());
- } else {
+ for (auto [idx, result] : llvm::enumerate(op.getResults()))
+ if (result.use_empty())
deadResults[idx] = true;
- }
- }
if (!deadResults.any())
return rewriter.notifyMatchFailure(op, "no dead results to fold");
- // Create new op without dead results and inline case regions.
- auto newOp = IndexSwitchOp::create(rewriter, op.getLoc(), newResultTypes,
- op.getArg(), op.getCases(),
- op.getCaseRegions().size());
- auto inlineCaseRegion = [&](Region &oldRegion, Region &newRegion) {
- rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
- // Remove respective operands from yield op.
- Operation *terminator = newRegion.front().getTerminator();
+ // Erase dead results.
+ auto newOp =
+ cast<scf::IndexSwitchOp>(rewriter.eraseOpResults(op, deadResults));
+
+ // Erase operands from yield ops.
+ auto updateCaseRegion = [&](Region &region) {
+ Operation *terminator = region.front().getTerminator();
assert(isa<YieldOp>(terminator) && "expected yield op");
rewriter.modifyOpInPlace(
terminator, [&]() { terminator->eraseOperands(deadResults); });
};
- for (auto [oldRegion, newRegion] :
- llvm::zip_equal(op.getCaseRegions(), newOp.getCaseRegions()))
- inlineCaseRegion(oldRegion, newRegion);
- inlineCaseRegion(op.getDefaultRegion(), newOp.getDefaultRegion());
-
- // Replace op with new op.
- SmallVector<Value> newResults(op.getNumResults(), Value());
- unsigned nextNewResult = 0;
- for (unsigned idx = 0; idx < op.getNumResults(); ++idx) {
- if (deadResults[idx])
- continue;
- newResults[idx] = newOp.getResult(nextNewResult++);
- }
- rewriter.replaceOp(op, newResults);
+ updateCaseRegion(newOp.getDefaultRegion());
+ for (Region &caseRegion : newOp.getCaseRegions())
+ updateCaseRegion(caseRegion);
+
return success();
}
};
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 226e4e518d3e0..913063c87e1fa 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -244,6 +244,42 @@ void RewriterBase::eraseBlock(Block *block) {
block->erase();
}
+Operation *RewriterBase::eraseOpResults(Operation *op,
+ const BitVector &eraseIndices) {
+ assert(op->getNumResults() == eraseIndices.size() &&
+ "number of op results and bitvector size must match");
+
+ // Gather new result types.
+ SmallVector<Type> newResultTypes;
+ newResultTypes.reserve(op->getNumResults() - eraseIndices.count());
+ for (OpResult result : op->getResults())
+ if (!eraseIndices[result.getResultNumber()])
+ newResultTypes.push_back(result.getType());
+
+ // Create a new operation and inline all regions.
+ InsertionGuard g(*this);
+ setInsertionPoint(op);
+ OperationState state(op->getLoc(), op->getName().getStringRef(),
+ op->getOperands(), newResultTypes, op->getAttrs());
+ for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
+ state.addRegion();
+ Operation *newOp = create(state);
+ for (const auto &[index, region] : llvm::enumerate(op->getRegions())) {
+ // Move all blocks of `region` into `newRegion`.
+ Region &newRegion = newOp->getRegion(index);
+ inlineRegionBefore(region, newRegion, newRegion.begin());
+ }
+
+ // Replace the original operation with the new operation.
+ SmallVector<Value> replacements(op->getNumResults(), Value());
+ unsigned nextResultIdx = 0;
+ for (unsigned i = 0, e = op->getNumResults(); i < e; ++i)
+ if (!eraseIndices[i])
+ replacements[i] = newOp->getResult(nextResultIdx++);
+ replaceOp(op, replacements);
+ return newOp;
+}
+
void RewriterBase::finalizeOpModification(Operation *op) {
// Notify the listener that the operation was modified.
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 45266bc7b34ea..07911c6111043 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -205,35 +205,10 @@ static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
assert(op->getNumResults() == toErase.size() &&
"expected the number of results in `op` and the size of `toErase` to "
"be the same");
-
- std::vector<Type> newResultTypes;
- for (OpResult result : op->getResults())
- if (!toErase[result.getResultNumber()])
- newResultTypes.push_back(result.getType());
+ for (auto idx : toErase.set_bits())
+ op->getResult(idx).dropAllUses();
IRRewriter rewriter(op);
- rewriter.setInsertionPointAfter(op);
- OperationState state(op->getLoc(), op->getName().getStringRef(),
- op->getOperands(), newResultTypes, op->getAttrs());
- for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
- state.addRegion();
- Operation *newOp = rewriter.create(state);
- for (const auto &[index, region] : llvm::enumerate(op->getRegions())) {
- // Move all blocks of `region` into `newRegion`.
- Region &newRegion = newOp->getRegion(index);
- rewriter.inlineRegionBefore(region, newRegion, newRegion.begin());
- }
-
- unsigned indexOfNextNewCallOpResultToReplace = 0;
- for (auto [index, result] : llvm::enumerate(op->getResults())) {
- assert(result && "expected result to be non-null");
- if (toErase[index]) {
- result.dropAllUses();
- } else {
- result.replaceAllUsesWith(
- newOp->getResult(indexOfNextNewCallOpResultToReplace++));
- }
- }
- op->erase();
+ rewriter.eraseOpResults(op, toErase);
}
/// Convert a list of `Operand`s to a list of `OpOperand`s.
>From 777c7299cc16ff8fa69d1588612da31ebf4bd176 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 24 Dec 2025 13:26:00 +0000
Subject: [PATCH 2/2] tmp commit
simple test working
draft: do not erase IR, just replace uses
---
mlir/include/mlir/Transforms/Passes.h | 1 +
mlir/include/mlir/Transforms/Passes.td | 10 +
mlir/lib/Transforms/RemoveDeadValues.cpp | 511 +++++++------------
mlir/test/Transforms/remove-dead-values.mlir | 155 ++++--
4 files changed, 300 insertions(+), 377 deletions(-)
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 724da009e70f1..9983944d374c5 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -42,6 +42,7 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_MEM2REG
#define GEN_PASS_DECL_PRINTIRPASS
#define GEN_PASS_DECL_PRINTOPSTATS
+#define GEN_PASS_DECL_REMOVEDEADVALUES
#define GEN_PASS_DECL_SCCP
#define GEN_PASS_DECL_SROA
#define GEN_PASS_DECL_STRIPDEBUGINFO
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 55addfdb693e4..fc2d60d198cd6 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -246,7 +246,17 @@ def RemoveDeadValues : Pass<"remove-dead-values"> {
do = square_and_double_of_y(5)
print(do)
```
+
+ Note: If `canonicalize` is set to "false", this pass does not remove any
+ block arguments / op results from ops that implement the
+ RegionBranchOpInterface. Instead, it just sets dead operands to
+ "ub.poison".
}];
+
+ let options = [
+ Option<"canonicalize", "canonicalize", "bool", /*default=*/"true",
+ "Canonicalize region branch ops">,
+ ];
let constructor = "mlir::createRemoveDeadValuesPass()";
let dependentDialects = ["ub::UBDialect"];
}
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 07911c6111043..94b25f78786f9 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -94,8 +94,11 @@ struct ResultsToCleanup {
struct OperandsToCleanup {
Operation *op;
BitVector nonLive;
- Operation *callee =
- nullptr; // Optional: For CallOpInterface ops, stores the callee function
+ // Optional: For CallOpInterface ops, stores the callee function.
+ Operation *callee = nullptr;
+ // Determines whether the operand should be replaced with a ub.poison result
+ // or erased entirely.
+ bool replaceWithPoison = false;
};
struct BlockArgsToCleanup {
@@ -199,27 +202,6 @@ static void collectNonLiveValues(DenseSet<Value> &nonLiveSet, ValueRange range,
}
}
-/// Drop the uses of the i-th result of `op` and then erase it iff toErase[i]
-/// is 1.
-static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
- assert(op->getNumResults() == toErase.size() &&
- "expected the number of results in `op` and the size of `toErase` to "
- "be the same");
- for (auto idx : toErase.set_bits())
- op->getResult(idx).dropAllUses();
- IRRewriter rewriter(op);
- rewriter.eraseOpResults(op, toErase);
-}
-
-/// Convert a list of `Operand`s to a list of `OpOperand`s.
-static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
- OpOperand *values = operands.getBase();
- SmallVector<OpOperand *> opOperands;
- for (unsigned i = 0, e = operands.size(); i < e; i++)
- opOperands.push_back(&values[i]);
- return opOperands;
-}
-
/// Process a simple operation `op` using the liveness analysis `la`.
/// If the operation has no memory effects and none of its results are live:
/// 1. Add the operation to a list for future removal, and
@@ -379,30 +361,20 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
///
/// Scenario 1: If the operation has no memory effects and none of its results
/// are live:
-/// (1') Enqueue all its uses for deletion.
-/// (2') Enqueue the branch itself for deletion.
+/// 1.1. Enqueue all its uses for deletion.
+/// 1.2. Enqueue the branch itself for deletion.
///
/// Scenario 2: Otherwise:
-/// (1) Collect its unnecessary operands (operands forwarded to unnecessary
-/// results or arguments).
-/// (2) Process each of its regions.
-/// (3) Collect the uses of its unnecessary results (results forwarded from
-/// unnecessary operands
-/// or terminator operands).
-/// (4) Add these results to the deletion list.
-///
-/// Processing a region includes:
-/// (a) Collecting the uses of its unnecessary arguments (arguments forwarded
-/// from unnecessary operands
-/// or terminator operands).
-/// (b) Collecting these unnecessary arguments.
-/// (c) Collecting its unnecessary terminator operands (terminator operands
-/// forwarded to unnecessary results
-/// or arguments).
+/// 2.1. Collect block arguments and op results that we would like to keep,
+/// based on their liveness.
+/// 2.2. Find all operands that are forwarded to only dead region successor
+/// inputs. I.e., forwarded to block arguments / op results that we do
+/// not want to keep.
+/// 2.3. Enqueue all such operands for replacement with ub.poison.
///
-/// Value Flow Note: In this operation, values flow as follows:
-/// - From operands and terminator operands (successor operands)
-/// - To arguments and results (successor inputs).
+/// Note: In scenario 2, block arguments and op results are not removed.
+/// However, the IR is simplified such that canonicalization patterns can
+/// remove them later.
static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
@@ -416,282 +388,76 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// case, a non-forwarded operand of `regionBranchOp` could be live/non-live.
// It could never be live because of this op but its liveness could have been
// attributed to something else.
- // Do (1') and (2').
if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
!hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
cl.operations.push_back(regionBranchOp.getOperation());
return;
}
- // Mark live results of `regionBranchOp` in `liveResults`.
- auto markLiveResults = [&](BitVector &liveResults) {
- liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
- };
-
- // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
- auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) {
- for (Region &region : regionBranchOp->getRegions()) {
- if (region.empty())
- continue;
- SmallVector<Value> arguments(region.front().getArguments());
- BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
- liveArgs[&region] = regionLiveArgs;
- }
- };
-
- // Return the successors of `region` if the latter is not null. Else return
- // the successors of `regionBranchOp`.
- auto getSuccessors = [&](RegionBranchPoint point) {
- SmallVector<RegionSuccessor> successors;
- regionBranchOp.getSuccessorRegions(point, successors);
- return successors;
- };
-
- // Return the operands of `terminator` that are forwarded to `successor` if
- // the former is not null. Else return the operands of `regionBranchOp`
- // forwarded to `successor`.
- auto getForwardedOpOperands = [&](RegionBranchPoint src,
- const RegionSuccessor &successor) {
- SmallVector<OpOperand *> opOperands = operandsToOpOperands(
- regionBranchOp.getSuccessorOperands(src, successor));
- return opOperands;
- };
-
- // Mark the non-forwarded operands of `regionBranchOp` in
- // `nonForwardedOperands`.
- auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
- nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true);
- for (const RegionSuccessor &successor :
- getSuccessors(RegionBranchPoint::parent())) {
- for (OpOperand *opOperand :
- getForwardedOpOperands(RegionBranchPoint::parent(), successor))
- nonForwardedOperands.reset(opOperand->getOperandNumber());
- }
- };
-
- // Mark the non-forwarded terminator operands of the various regions of
- // `regionBranchOp` in `nonForwardedRets`.
- auto markNonForwardedReturnValues =
- [&](DenseMap<Operation *, BitVector> &nonForwardedRets) {
- for (Region &region : regionBranchOp->getRegions()) {
- if (region.empty())
- continue;
- // TODO: this isn't correct in face of multiple terminators.
- auto terminator = cast<RegionBranchTerminatorOpInterface>(
- region.front().getTerminator());
- nonForwardedRets[terminator] =
- BitVector(terminator->getNumOperands(), true);
- for (const RegionSuccessor &successor : getSuccessors(terminator)) {
- for (OpOperand *opOperand : getForwardedOpOperands(
- RegionBranchPoint(terminator), successor))
- nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
- }
- }
- };
-
- // Update `valuesToKeep` (which is expected to correspond to operands or
- // terminator operands) based on `resultsToKeep` and `argsToKeep`, given
- // `region`. When `valuesToKeep` correspond to operands, `region` is null.
- // Else, `region` is the parent region of the terminator.
- auto updateOperandsOrTerminatorOperandsToKeep =
- [&](BitVector &valuesToKeep, BitVector &resultsToKeep,
- DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) {
- Operation *terminator =
- region ? region->front().getTerminator() : nullptr;
- RegionBranchPoint point =
- terminator
- ? RegionBranchPoint(
- cast<RegionBranchTerminatorOpInterface>(terminator))
- : RegionBranchPoint::parent();
-
- for (const RegionSuccessor &successor : getSuccessors(point)) {
- Region *successorRegion = successor.getSuccessor();
- for (auto [opOperand, input] :
- llvm::zip(getForwardedOpOperands(point, successor),
- successor.getSuccessorInputs())) {
- size_t operandNum = opOperand->getOperandNumber();
- bool updateBasedOn =
- successorRegion
- ? argsToKeep[successorRegion]
- [cast<BlockArgument>(input).getArgNumber()]
- : resultsToKeep[cast<OpResult>(input).getResultNumber()];
- valuesToKeep[operandNum] = valuesToKeep[operandNum] | updateBasedOn;
- }
- }
- };
-
- // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep` and
- // `terminatorOperandsToKeep`. Store true in `resultsOrArgsToKeepChanged` if a
- // value is modified, else, false.
- auto recomputeResultsAndArgsToKeep =
- [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
- BitVector &operandsToKeep,
- DenseMap<Operation *, BitVector> &terminatorOperandsToKeep,
- bool &resultsOrArgsToKeepChanged) {
- resultsOrArgsToKeepChanged = false;
-
- // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`.
- for (const RegionSuccessor &successor :
- getSuccessors(RegionBranchPoint::parent())) {
- Region *successorRegion = successor.getSuccessor();
- for (auto [opOperand, input] :
- llvm::zip(getForwardedOpOperands(RegionBranchPoint::parent(),
- successor),
- successor.getSuccessorInputs())) {
- bool recomputeBasedOn =
- operandsToKeep[opOperand->getOperandNumber()];
- bool toRecompute =
- successorRegion
- ? argsToKeep[successorRegion]
- [cast<BlockArgument>(input).getArgNumber()]
- : resultsToKeep[cast<OpResult>(input).getResultNumber()];
- if (!toRecompute && recomputeBasedOn)
- resultsOrArgsToKeepChanged = true;
- if (successorRegion) {
- argsToKeep[successorRegion][cast<BlockArgument>(input)
- .getArgNumber()] =
- argsToKeep[successorRegion]
- [cast<BlockArgument>(input).getArgNumber()] |
- recomputeBasedOn;
- } else {
- resultsToKeep[cast<OpResult>(input).getResultNumber()] =
- resultsToKeep[cast<OpResult>(input).getResultNumber()] |
- recomputeBasedOn;
- }
- }
- }
-
- // Recompute `resultsToKeep` and `argsToKeep` based on
- // `terminatorOperandsToKeep`.
- for (Region &region : regionBranchOp->getRegions()) {
- if (region.empty())
- continue;
- auto terminator = cast<RegionBranchTerminatorOpInterface>(
- region.front().getTerminator());
- for (const RegionSuccessor &successor : getSuccessors(terminator)) {
- Region *successorRegion = successor.getSuccessor();
- for (auto [opOperand, input] :
- llvm::zip(getForwardedOpOperands(RegionBranchPoint(terminator),
- successor),
- successor.getSuccessorInputs())) {
- bool recomputeBasedOn =
- terminatorOperandsToKeep[region.back().getTerminator()]
- [opOperand->getOperandNumber()];
- bool toRecompute =
- successorRegion
- ? argsToKeep[successorRegion]
- [cast<BlockArgument>(input).getArgNumber()]
- : resultsToKeep[cast<OpResult>(input).getResultNumber()];
- if (!toRecompute && recomputeBasedOn)
- resultsOrArgsToKeepChanged = true;
- if (successorRegion) {
- argsToKeep[successorRegion][cast<BlockArgument>(input)
- .getArgNumber()] =
- argsToKeep[successorRegion]
- [cast<BlockArgument>(input).getArgNumber()] |
- recomputeBasedOn;
- } else {
- resultsToKeep[cast<OpResult>(input).getResultNumber()] =
- resultsToKeep[cast<OpResult>(input).getResultNumber()] |
- recomputeBasedOn;
- }
- }
- }
- }
- };
-
- // Mark the values that we want to keep in `resultsToKeep`, `argsToKeep`,
- // `operandsToKeep`, and `terminatorOperandsToKeep`.
- auto markValuesToKeep =
- [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
- BitVector &operandsToKeep,
- DenseMap<Operation *, BitVector> &terminatorOperandsToKeep) {
- bool resultsOrArgsToKeepChanged = true;
- // We keep updating and recomputing the values until we reach a point
- // where they stop changing.
- while (resultsOrArgsToKeepChanged) {
- // Update the operands that need to be kept.
- updateOperandsOrTerminatorOperandsToKeep(operandsToKeep,
- resultsToKeep, argsToKeep);
-
- // Update the terminator operands that need to be kept.
- for (Region &region : regionBranchOp->getRegions()) {
- if (region.empty())
- continue;
- updateOperandsOrTerminatorOperandsToKeep(
- terminatorOperandsToKeep[region.back().getTerminator()],
- resultsToKeep, argsToKeep, &region);
- }
-
- // Recompute the results and arguments that need to be kept.
- recomputeResultsAndArgsToKeep(
- resultsToKeep, argsToKeep, operandsToKeep,
- terminatorOperandsToKeep, resultsOrArgsToKeepChanged);
- }
- };
-
- // Scenario 2.
- // At this point, we know that every non-forwarded operand of `regionBranchOp`
- // is live.
-
- // Stores the results of `regionBranchOp` that we want to keep.
- BitVector resultsToKeep;
- // Stores the mapping from regions of `regionBranchOp` to their arguments that
- // we want to keep.
- DenseMap<Region *, BitVector> argsToKeep;
- // Stores the operands of `regionBranchOp` that we want to keep.
- BitVector operandsToKeep;
- // Stores the mapping from region terminators in `regionBranchOp` to their
- // operands that we want to keep.
- DenseMap<Operation *, BitVector> terminatorOperandsToKeep;
-
- // Initializing the above variables...
-
- // The live results of `regionBranchOp` definitely need to be kept.
- markLiveResults(resultsToKeep);
- // Similarly, the live arguments of the regions in `regionBranchOp` definitely
- // need to be kept.
- markLiveArgs(argsToKeep);
- // The non-forwarded operands of `regionBranchOp` definitely need to be kept.
- // A live forwarded operand can be removed but no non-forwarded operand can be
- // removed since it "controls" the flow of data in this control flow op.
- markNonForwardedOperands(operandsToKeep);
- // Similarly, the non-forwarded terminator operands of the regions in
- // `regionBranchOp` definitely need to be kept.
- markNonForwardedReturnValues(terminatorOperandsToKeep);
-
- // Mark the values (results, arguments, operands, and terminator operands)
- // that we want to keep.
- markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep,
- terminatorOperandsToKeep);
-
- // Do (1).
- cl.operands.push_back({regionBranchOp, operandsToKeep.flip()});
-
- // Do (2.a) and (2.b).
+ // Compute values that are alive.
+ DenseSet<Value> valuesToKeep;
+ for (Value result : regionBranchOp->getResults()) {
+ if (hasLive(result, nonLiveSet, la))
+ valuesToKeep.insert(result);
+ }
for (Region &region : regionBranchOp->getRegions()) {
if (region.empty())
continue;
- BitVector argsToRemove = argsToKeep[&region].flip();
- cl.blocks.push_back({&region.front(), argsToRemove});
- collectNonLiveValues(nonLiveSet, region.front().getArguments(),
- argsToRemove);
+ for (Value arg : region.front().getArguments()) {
+ if (hasLive(arg, nonLiveSet, la))
+ valuesToKeep.insert(arg);
+ }
}
- // Do (2.c).
- for (Region &region : regionBranchOp->getRegions()) {
- if (region.empty())
+ // Mapping from operands to forwarded successor inputs. An operand can be
+ // forwarded to multiple successors.
+ //
+ // Example:
+ //
+ // %0 = scf.while : () -> i32 {
+ // scf.condition(...) %forwarded_value : i32
+ // } do {
+ // ^bb0(%arg0: i32):
+ // scf.yield
+ // }
+ // // No uses of %0.
+ //
+ // In the above example, %forwarded_value is forwarded to %arg0 and %0. Both
+ // %arg0 and %0 are dead, so %forwarded_value can be replaced with a
+ // ub.poison result.
+ //
+ // operandToSuccessorInputs[%forwarded_value] = {%arg0, %0}
+ //
+ RegionBranchSuccessorMapping operandToSuccessorInputs;
+ regionBranchOp.getSuccessorOperandInputMapping(operandToSuccessorInputs);
+
+ DenseMap<Operation *, BitVector> deadOperandsPerOp;
+ for (auto [opOperand, successorInputs] : operandToSuccessorInputs) {
+ // If one of the successor inputs is live, the respective operand must be
+ // kept.
+ bool anyAlive = llvm::any_of(successorInputs, [&](Value input) {
+ return valuesToKeep.contains(input);
+ });
+ if (anyAlive)
continue;
- Operation *terminator = region.front().getTerminator();
- cl.operands.push_back(
- {terminator, terminatorOperandsToKeep[terminator].flip()});
+
+ // All successor inputs are dead: ub.poison can be passed as operand.
+ // Create an entry in `deadOperandsPerOp` (initialized to "false", i.e.,
+ // no "dead" op operands) if it's the first time that we are seeing an op
+ // operand for this op. Otherwise, just take the existing bit vector from
+ // the map.
+ BitVector &deadOperands =
+ deadOperandsPerOp
+ .try_emplace(opOperand->getOwner(),
+ opOperand->getOwner()->getNumOperands(), false)
+ .first->second;
+ deadOperands.set(opOperand->getOperandNumber());
}
- // Do (3) and (4).
- BitVector resultsToRemove = resultsToKeep.flip();
- collectNonLiveValues(nonLiveSet, regionBranchOp.getOperation()->getResults(),
- resultsToRemove);
- cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove});
+ for (auto [op, deadOperands] : deadOperandsPerOp) {
+ cl.operands.push_back(
+ {op, deadOperands, nullptr, /*replaceWithPoison=*/true});
+ }
}
/// Steps to process a `BranchOpInterface` operation:
@@ -751,11 +517,44 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
}
}
+/// Create a ub.poison op for each given value that has uses. For a value
+/// without uses, no op is created and a null Value is returned in its place.
+static SmallVector<Value> createPoisonedValues(OpBuilder &b,
+ ValueRange values) {
+ return llvm::map_to_vector(values, [&](Value value) {
+ if (value.use_empty())
+ return Value();
+ return ub::PoisonOp::create(b, value.getLoc(), value.getType()).getResult();
+ });
+}
+
+namespace {
+/// A listener that tracks ub.poison ops that were inserted and not yet erased.
+struct TrackingListener : public RewriterBase::Listener {
+ void notifyOperationErased(Operation *op) override {
+ if (auto poisonOp = dyn_cast<ub::PoisonOp>(op))
+ poisonOps.erase(poisonOp);
+ }
+ void notifyOperationInserted(Operation *op,
+ OpBuilder::InsertPoint previous) override {
+ if (auto poisonOp = dyn_cast<ub::PoisonOp>(op))
+ poisonOps.insert(poisonOp);
+ }
+ DenseSet<ub::PoisonOp> poisonOps;
+};
+} // namespace
+
/// Removes dead values collected in RDVFinalCleanupList.
/// To be run once when all dead values have been collected.
-static void cleanUpDeadVals(RDVFinalCleanupList &list) {
+static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) {
LDBG() << "Starting cleanup of dead values...";
+ // New ub.poison ops may be inserted during cleanup. Some of these ops may no
+ // longer be needed after the cleanup. A tracking listener keeps track of all
+ // new ub.poison ops, so that they can be removed again after the cleanup.
+ TrackingListener listener;
+ IRRewriter rewriter(ctx, &listener);
+
// 1. Blocks, We must remove the block arguments and successor operands before
// deleting the operation, as they may reside in the region operation.
LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists";
@@ -773,10 +572,12 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
});
// Note: Iterate from the end to make sure that the indices of not yet
// processed arguments do not change.
+ rewriter.setInsertionPointToStart(b.b);
for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
if (!b.nonLiveArgs[i])
continue;
- b.b->getArgument(i).dropAllUses();
+ b.b->getArgument(i).replaceAllUsesWith(
+ createPoisonedValues(rewriter, b.b->getArgument(i)).front());
b.b->eraseArgument(i);
}
}
@@ -805,22 +606,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
}
}
- // 3. Operations
- LDBG() << "Cleaning up " << list.operations.size() << " operations";
- for (Operation *op : list.operations) {
- LDBG() << "Erasing operation: "
- << OpWithFlags(op,
- OpPrintingFlags().skipRegions().printGenericOpForm());
- if (op->hasTrait<OpTrait::IsTerminator>()) {
- // When erasing a terminator, insert an unreachable op in its place.
- OpBuilder b(op);
- ub::UnreachableOp::create(b, op->getLoc());
- }
- op->dropAllUses();
- op->erase();
- }
-
- // 4. Functions
+ // 3. Functions
LDBG() << "Cleaning up " << list.functions.size() << " functions";
// Record which function arguments were erased so we can shrink call-site
// argument segments for CallOpInterface operations (e.g. ops using
@@ -837,12 +623,18 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
llvm::interleaveComma(f.nonLiveRets.set_bits(), os);
os << "]";
});
- // Drop all uses of the dead arguments.
- for (auto deadIdx : f.nonLiveArgs.set_bits())
- f.funcOp.getArgument(deadIdx).dropAllUses();
// Some functions may not allow erasing arguments or results. These calls
// return failure in such cases without modifying the function, so it's okay
// to proceed.
+ bool hasBody = !f.funcOp.getFunctionBody().empty();
+ if (hasBody) {
+ rewriter.setInsertionPointToStart(&f.funcOp.getFunctionBody().front());
+ for (auto deadIdx : f.nonLiveArgs.set_bits()) {
+ f.funcOp.getArgument(deadIdx).replaceAllUsesWith(
+ createPoisonedValues(rewriter, f.funcOp.getArgument(deadIdx))
+ .front());
+ }
+ }
if (succeeded(f.funcOp.eraseArguments(f.nonLiveArgs))) {
// Record only if we actually erased something.
if (f.nonLiveArgs.any())
@@ -851,7 +643,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
(void)f.funcOp.eraseResults(f.nonLiveRets);
}
- // 5. Operands
+ // 4. Operands
LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
for (OperandsToCleanup &o : list.operands) {
// Handle call-specific cleanup only when we have a cached callee reference.
@@ -896,11 +688,20 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
<< OpWithFlags(o.op,
OpPrintingFlags().skipRegions().printGenericOpForm());
});
- o.op->eraseOperands(o.nonLive);
+ if (o.replaceWithPoison) {
+ rewriter.setInsertionPoint(o.op);
+ for (auto deadIdx : o.nonLive.set_bits()) {
+ o.op->setOperand(
+ deadIdx, createPoisonedValues(rewriter, o.op->getOperand(deadIdx))
+ .front());
+ }
+ } else {
+ o.op->eraseOperands(o.nonLive);
+ }
}
}
- // 6. Results
+ // 5. Results
LDBG() << "Cleaning up " << list.results.size() << " result lists";
for (auto &r : list.results) {
LDBG_OS([&](raw_ostream &os) {
@@ -910,8 +711,34 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
<< OpWithFlags(r.op,
OpPrintingFlags().skipRegions().printGenericOpForm());
});
- dropUsesAndEraseResults(r.op, r.nonLive);
+ rewriter.setInsertionPoint(r.op);
+ for (auto deadIdx : r.nonLive.set_bits()) {
+ r.op->getResult(deadIdx).replaceAllUsesWith(
+ createPoisonedValues(rewriter, r.op->getResult(deadIdx)).front());
+ }
+ rewriter.eraseOpResults(r.op, r.nonLive);
}
+
+ // 6. Operations
+ LDBG() << "Cleaning up " << list.operations.size() << " operations";
+ for (Operation *op : list.operations) {
+ LDBG() << "Erasing operation: "
+ << OpWithFlags(op,
+ OpPrintingFlags().skipRegions().printGenericOpForm());
+ rewriter.setInsertionPoint(op);
+ if (op->hasTrait<OpTrait::IsTerminator>()) {
+ // When erasing a terminator, insert an unreachable op in its place.
+ ub::UnreachableOp::create(rewriter, op->getLoc());
+ }
+ rewriter.replaceOp(op, createPoisonedValues(rewriter, op->getResults()));
+ }
+
+ // 7. Remove all dead poison ops.
+ for (ub::PoisonOp poisonOp : listener.poisonOps) {
+ if (poisonOp.use_empty())
+ poisonOp.erase();
+ }
+
LDBG() << "Finished cleanup of dead values";
}
@@ -950,7 +777,27 @@ void RemoveDeadValues::runOnOperation() {
}
});
- cleanUpDeadVals(finalCleanupList);
+ MLIRContext *context = module->getContext();
+ cleanUpDeadVals(context, finalCleanupList);
+
+ if (!canonicalize)
+ return;
+
+ // Canonicalize all region branch ops.
+ SmallVector<Operation *> opsToCanonicalize;
+ module->walk([&](RegionBranchOpInterface regionBranchOp) {
+ opsToCanonicalize.push_back(regionBranchOp.getOperation());
+ });
+ RewritePatternSet owningPatterns(context);
+ for (auto *dialect : context->getLoadedDialects())
+ dialect->getCanonicalizationPatterns(owningPatterns);
+ for (RegisteredOperationName op : context->getRegisteredOperations())
+ op.getCanonicalizationPatterns(owningPatterns, context);
+ if (failed(applyOpPatternsGreedily(opsToCanonicalize,
+ std::move(owningPatterns)))) {
+ module->emitError("greedy pattern rewrite failed to converge");
+ signalPassFailure();
+ }
}
std::unique_ptr<Pass> mlir::createRemoveDeadValuesPass() {
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 71306676d48e9..b9a883dbd524e 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -remove-dead-values -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -remove-dead-values="canonicalize=0" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -remove-dead-values="canonicalize=1" -split-input-file | FileCheck %s --check-prefix=CHECK-CANONICALIZE
// The IR is updated regardless of memref.global private constant
//
@@ -55,19 +56,20 @@ func.func @acceptable_ir_has_cleanable_loop_of_conditional_and_branch_op(%arg0:
// Checking that iter_args are properly handled
//
+// CHECK-CANONICALIZE-LABEL: func @cleanable_loop_iter_args_value
func.func @cleanable_loop_iter_args_value(%arg0: index) -> index {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
%non_live = arith.constant 0 : index
- // CHECK: [[RESULT:%.+]] = scf.for [[ARG_1:%.*]] = %c0 to %c10 step %c1 iter_args([[ARG_2:%.*]] = %arg0) -> (index) {
+ // CHECK-CANONICALIZE: [[RESULT:%.+]] = scf.for [[ARG_1:%.*]] = %c0 to %c10 step %c1 iter_args([[ARG_2:%.*]] = %arg0) -> (index) {
%result, %result_non_live = scf.for %i = %c0 to %c10 step %c1 iter_args(%live_arg = %arg0, %non_live_arg = %non_live) -> (index, index) {
- // CHECK: [[SUM:%.+]] = arith.addi [[ARG_2]], [[ARG_1]] : index
+ // CHECK-CANONICALIZE: [[SUM:%.+]] = arith.addi [[ARG_2]], [[ARG_1]] : index
%new_live = arith.addi %live_arg, %i : index
- // CHECK: scf.yield [[SUM:%.+]]
+ // CHECK-CANONICALIZE: scf.yield [[SUM:%.+]]
scf.yield %new_live, %non_live_arg : index, index
}
- // CHECK: return [[RESULT]] : index
+ // CHECK-CANONICALIZE: return [[RESULT]] : index
return %result : index
}
@@ -79,7 +81,8 @@ func.func @cleanable_loop_iter_args_value(%arg0: index) -> index {
#map = affine_map<(d0, d1, d2) -> (0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
- func.func @main() {
+ // CHECK-LABEL: @dead_linalg_generic
+ func.func @dead_linalg_generic() {
%cst_3 = arith.constant dense<54> : tensor<1x25x13xi32>
%cst_7 = arith.constant dense<11> : tensor<1x25x13xi32>
// CHECK-NOT: arith.constant
@@ -229,18 +232,34 @@ func.func @main() -> (i32, i32) {
// anywhere else. Thus, %arg7 is also not kept in the `scf.yield` op.
//
// Note that this cleanup cannot be done by the `canonicalize` pass.
-//
-// CHECK: func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 {
-// CHECK-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) {
-// CHECK-NEXT: %[[live_0:.*]] = arith.addi %[[arg4]], %[[arg4]]
-// CHECK-NEXT: scf.condition(%arg0) %[[live_0]], %[[arg4]] : i32, i32
+
+// CHECK-LABEL: func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(
+// CHECK-SAME: %[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 {
+// CHECK-NEXT: %[[p0:.*]] = ub.poison : i32
+// CHECK-NEXT: %[[while:.*]]:3 = scf.while (%{{.*}} = %[[p0]], %[[arg4:.*]] = %[[arg2]]) : (i32, i32) -> (i32, i32, i32) {
+// CHECK-NEXT: %[[add1:.*]] = arith.addi %[[arg4]], %[[arg4]] : i32
+// CHECK-NEXT: %[[p1:.*]] = ub.poison : i32
+// CHECK-NEXT: scf.condition(%[[arg0]]) %[[add1]], %[[arg4]], %[[p1]] : i32, i32, i32
// CHECK-NEXT: } do {
-// CHECK-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
-// CHECK-NEXT: %[[live_1:.*]] = arith.addi %[[arg6]], %[[arg6]]
-// CHECK-NEXT: scf.yield %[[live_1]] : i32
+// CHECK-NEXT: ^bb0(%{{.*}}: i32, %[[arg6:.*]]: i32, %{{.*}}: i32):
+// CHECK-NEXT: %[[add2:.*]] = arith.addi %[[arg6]], %[[arg6]] : i32
+// CHECK-NEXT: %[[p2:.*]] = ub.poison : i32
+// CHECK-NEXT: scf.yield %[[p2]], %[[add2]] : i32, i32
// CHECK-NEXT: }
-// CHECK-NEXT: return %[[live_and_non_live]]#0
+// CHECK-NEXT: return %[[while]]#0 : i32
// CHECK-NEXT: }
+
+// CHECK-CANONICALIZE: func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 {
+// CHECK-CANONICALIZE-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) {
+// CHECK-CANONICALIZE-NEXT: %[[live_0:.*]] = arith.addi %[[arg4]], %[[arg4]]
+// CHECK-CANONICALIZE-NEXT: scf.condition(%arg0) %[[live_0]], %[[arg4]] : i32, i32
+// CHECK-CANONICALIZE-NEXT: } do {
+// CHECK-CANONICALIZE-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
+// CHECK-CANONICALIZE-NEXT: %[[live_1:.*]] = arith.addi %[[arg6]], %[[arg6]]
+// CHECK-CANONICALIZE-NEXT: scf.yield %[[live_1]] : i32
+// CHECK-CANONICALIZE-NEXT: }
+// CHECK-CANONICALIZE-NEXT: return %[[live_and_non_live]]#0
+// CHECK-CANONICALIZE-NEXT: }
func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%arg0: i1, %arg1: i32, %arg2: i32) -> (i32) {
%live, %non_live, %non_live_0 = scf.while (%arg3 = %arg1, %arg4 = %arg2) : (i32, i32) -> (i32, i32, i32) {
%live_0 = arith.addi %arg4, %arg4 : i32
@@ -284,21 +303,21 @@ func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_o
//
// Note that this cleanup cannot be done by the `canonicalize` pass.
//
-// CHECK: func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%[[arg2:.*]]: i1) -> i32 {
-// CHECK-NEXT: %[[c0:.*]] = arith.constant 0
-// CHECK-NEXT: %[[c1:.*]] = arith.constant 1
-// CHECK-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) {
-// CHECK-NEXT: func.call @identity() : () -> ()
-// CHECK-NEXT: scf.condition(%[[arg2]]) %[[arg4]], %[[arg3]] : i32, i32
-// CHECK-NEXT: } do {
-// CHECK-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
-// CHECK-NEXT: scf.yield %[[arg5]], %[[arg6]] : i32, i32
-// CHECK-NEXT: }
-// CHECK-NEXT: return %[[live_and_non_live]]#0 : i32
-// CHECK-NEXT: }
-// CHECK: func.func private @identity() {
-// CHECK-NEXT: return
-// CHECK-NEXT: }
+// CHECK-CANONICALIZE: func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%[[arg2:.*]]: i1) -> i32 {
+// CHECK-CANONICALIZE-NEXT: %[[c0:.*]] = arith.constant 0
+// CHECK-CANONICALIZE-NEXT: %[[c1:.*]] = arith.constant 1
+// CHECK-CANONICALIZE-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) {
+// CHECK-CANONICALIZE-NEXT: func.call @identity() : () -> ()
+// CHECK-CANONICALIZE-NEXT: scf.condition(%[[arg2]]) %[[arg3]], %[[arg4]] : i32, i32
+// CHECK-CANONICALIZE-NEXT: } do {
+// CHECK-CANONICALIZE-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
+// CHECK-CANONICALIZE-NEXT: scf.yield %[[arg6]], %[[arg5]] : i32, i32
+// CHECK-CANONICALIZE-NEXT: }
+// CHECK-CANONICALIZE-NEXT: return %[[live_and_non_live]]#1 : i32
+// CHECK-CANONICALIZE-NEXT: }
+// CHECK-CANONICALIZE: func.func private @identity() {
+// CHECK-CANONICALIZE-NEXT: return
+// CHECK-CANONICALIZE-NEXT: }
func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%arg2: i1) -> (i32) {
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
@@ -325,17 +344,17 @@ func.func private @identity(%arg1 : i32) -> (i32) {
//
// Note that this cleanup cannot be done by the `canonicalize` pass.
//
-// CHECK: func.func @clean_region_branch_op_remove_result(%[[arg0:.*]]: index, %[[arg1:.*]]: memref<i32>) {
-// CHECK-NEXT: scf.index_switch %[[arg0]]
-// CHECK-NEXT: case 1 {
-// CHECK-NEXT: %[[c10:.*]] = arith.constant 10
-// CHECK-NEXT: memref.store %[[c10]], %[[arg1]][]
-// CHECK-NEXT: scf.yield
-// CHECK-NEXT: }
-// CHECK-NEXT: default {
-// CHECK-NEXT: }
-// CHECK-NEXT: return
-// CHECK-NEXT: }
+// CHECK-CANONICALIZE: func.func @clean_region_branch_op_remove_result(%[[arg0:.*]]: index, %[[arg1:.*]]: memref<i32>) {
+// CHECK-CANONICALIZE-NEXT: scf.index_switch %[[arg0]]
+// CHECK-CANONICALIZE-NEXT: case 1 {
+// CHECK-CANONICALIZE-NEXT: %[[c10:.*]] = arith.constant 10
+// CHECK-CANONICALIZE-NEXT: memref.store %[[c10]], %[[arg1]][]
+// CHECK-CANONICALIZE: scf.yield
+// CHECK-CANONICALIZE-NEXT: }
+// CHECK-CANONICALIZE-NEXT: default {
+// CHECK-CANONICALIZE: }
+// CHECK-CANONICALIZE-NEXT: return
+// CHECK-CANONICALIZE-NEXT: }
func.func @clean_region_branch_op_remove_result(%arg0 : index, %arg1 : memref<i32>) {
%non_live = scf.index_switch %arg0 -> i32
case 1 {
@@ -539,10 +558,10 @@ module {
}
}
-// CHECK-LABEL: func @test_zero_operands
-// CHECK: memref.alloca_scope
-// CHECK: memref.store
-// CHECK-NOT: memref.alloca_scope.return
+// CHECK-CANONICALIZE-LABEL: func @test_zero_operands
+// CHECK-CANONICALIZE-NEXT: %[[c0:.*]] = arith.constant 0
+// CHECK-CANONICALIZE-NEXT: memref.store %[[c0]]
+// CHECK-CANONICALIZE-NOT: memref.alloca_scope.return
// -----
@@ -714,3 +733,49 @@ func.func private @remove_dead_branch_op(%c: i1, %arg0: i64, %arg1: i64) -> (i64
^bb2:
return %arg1 : i64
}
+
+// -----
+
+// CHECK-LABEL: func @scf_while_dead_iter_args()
+// CHECK: %[[c5:.*]] = arith.constant 5 : i32
+// CHECK: %[[while:.*]]:2 = scf.while (%[[arg0:.*]] = %[[c5]]) : (i32) -> (i32, i32) {
+// CHECK: vector.print %[[arg0]]
+// CHECK: %[[cmpi:.*]] = arith.cmpi
+// CHECK: %[[p0:.*]] = ub.poison : i32
+// CHECK: scf.condition(%[[cmpi]]) %[[arg0]], %[[p0]]
+// CHECK: } do {
+// CHECK: ^bb0(%[[arg1:.*]]: i32, %[[arg2:.*]]: i32):
+// CHECK: %[[p1:.*]] = ub.poison : i32
+// CHECK: scf.yield %[[p1]]
+// CHECK: }
+// CHECK: return %[[while]]#0
+
+// CHECK-CANONICALIZE-LABEL: func @scf_while_dead_iter_args()
+// CHECK-CANONICALIZE: %[[c5:.*]] = arith.constant 5 : i32
+// CHECK-CANONICALIZE: %[[while:.*]] = scf.while (%[[arg0:.*]] = %[[c5]]) : (i32) -> i32 {
+// CHECK-CANONICALIZE: vector.print %[[arg0]]
+// CHECK-CANONICALIZE: %[[cmpi:.*]] = arith.cmpi
+// CHECK-CANONICALIZE: scf.condition(%[[cmpi]]) %[[arg0]]
+// CHECK-CANONICALIZE: } do {
+// CHECK-CANONICALIZE: ^bb0(%[[arg1:.*]]: i32):
+// CHECK-CANONICALIZE: %[[p0:.*]] = ub.poison : i32
+// CHECK-CANONICALIZE: scf.yield %[[p0]]
+// CHECK-CANONICALIZE: }
+// CHECK-CANONICALIZE: return %[[while]]
+func.func @scf_while_dead_iter_args() -> i32 {
+ %c5 = arith.constant 5 : i32
+ %result:2 = scf.while (%arg0 = %c5) : (i32) -> (i32, i32) {
+ vector.print %arg0 : i32
+ // Note: This condition is always "false". (And the liveness analysis
+ // can figure that out.)
+ %cmp2 = arith.cmpi slt, %arg0, %c5 : i32
+ scf.condition(%cmp2) %arg0, %arg0 : i32, i32
+ } do {
+ ^bb0(%arg1: i32, %arg2: i32):
+ %x = scf.execute_region -> i32 {
+ scf.yield %arg2 : i32
+ }
+ scf.yield %x : i32
+ }
+ return %result#0 : i32
+}
More information about the Mlir-commits
mailing list