[Mlir-commits] [mlir] [mlir][IR][NFC] Add `RewriterBase::eraseOpResults` convenience helper (PR #174152)
Matthias Springer
llvmlistbot at llvm.org
Fri Jan 2 03:44:14 PST 2026
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/174152
From 7fef1b484253e084655701846e2db4384f1852fe Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 1 Jan 2026 16:40:56 +0000
Subject: [PATCH 1/2] [mlir][IR][NFC] Add `RewriterBase::eraseOpResults`
convenience helper
---
mlir/include/mlir/IR/PatternMatch.h | 6 +
mlir/lib/Dialect/SCF/IR/SCF.cpp | 199 +++++++----------------
mlir/lib/IR/PatternMatch.cpp | 36 ++++
mlir/lib/Transforms/RemoveDeadValues.cpp | 31 +---
4 files changed, 108 insertions(+), 164 deletions(-)
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 1caab24ac7295..83477c79ff582 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -545,6 +545,12 @@ class RewriterBase : public OpBuilder {
/// This method erases all operations in a block.
virtual void eraseBlock(Block *block);
+ /// Erase the specified results of the given operation. Results cannot be
+ /// erased directly, so the implementation creates a new replacement
+ /// operation and erases the original operation. The new operation is
+ /// returned.
+ Operation *eraseOpResults(Operation *op, const BitVector &eraseIndices);
+
/// Inline the operations of block 'source' into block 'dest' before the given
/// position. The source block will be deleted and must have no uses.
/// 'argValues' is used to replace the block arguments of 'source'.
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 0a123112cf68f..a85a3a84fb0ac 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1763,89 +1763,58 @@ struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
// ForallOp::getCombiningOps(iter_arg).
//
// Based on the check we maintain the following :-
- // a. `resultToDelete` - i-th result of scf.forall that'll be
- // deleted.
- // b. `resultToReplace` - i-th result of the old scf.forall
- // whose uses will be replaced by the new scf.forall.
- // c. `newOuts` - the shared_outs' operand of the new scf.forall
- // corresponding to the i-th result with at least one use.
- SetVector<OpResult> resultToDelete;
- SmallVector<Value> resultToReplace;
+ // a. op results, block arguments, outputs to delete
+ // b. new outputs (i.e., outputs to retain)
+ SmallVector<Value> resultsToDelete;
+ SmallVector<Value> outsToDelete;
+ SmallVector<BlockArgument> blockArgsToDelete;
SmallVector<Value> newOuts;
+ BitVector resultIndicesToDelete(forallOp.getNumResults(), false);
+ BitVector blockIndicesToDelete(forallOp.getBody()->getNumArguments(),
+ false);
for (OpResult result : forallOp.getResults()) {
OpOperand *opOperand = forallOp.getTiedOpOperand(result);
BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
- resultToDelete.insert(result);
+ resultsToDelete.push_back(result);
+ outsToDelete.push_back(opOperand->get());
+ blockArgsToDelete.push_back(blockArg);
+ resultIndicesToDelete[result.getResultNumber()] = true;
+ blockIndicesToDelete[blockArg.getArgNumber()] = true;
} else {
- resultToReplace.push_back(result);
newOuts.push_back(opOperand->get());
}
}
// Return early if all results of scf.forall have at least one use and being
// modified within the loop.
- if (resultToDelete.empty())
+ if (resultsToDelete.empty())
return failure();
- // Step 2: For the the i-th result, do the following :-
- // a. Fetch the corresponding BlockArgument.
- // b. Look for store ops (currently tensor.parallel_insert_slice)
- // with the BlockArgument as its destination operand.
- // c. Remove the operations fetched in b.
- for (OpResult result : resultToDelete) {
- OpOperand *opOperand = forallOp.getTiedOpOperand(result);
- BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
+ // Step 2: Erase combining ops and replace uses of deleted results and
+ // block arguments with the corresponding outputs.
+ for (auto blockArg : blockArgsToDelete) {
SmallVector<Operation *> combiningOps =
forallOp.getCombiningOps(blockArg);
for (Operation *combiningOp : combiningOps)
rewriter.eraseOp(combiningOp);
}
-
- // Step 3. Create a new scf.forall op with the new shared_outs' operands
- // fetched earlier
- auto newForallOp = scf::ForallOp::create(
- rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
- forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
- forallOp.getMapping(),
- /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
-
- // Step 4. Merge the block of the old scf.forall into the newly created
- // scf.forall using the new set of arguments.
- Block *loopBody = forallOp.getBody();
- Block *newLoopBody = newForallOp.getBody();
- ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
- // Form initial new bbArg list with just the control operands of the new
- // scf.forall op.
- SmallVector<Value> newBlockArgs =
- llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
- [](BlockArgument b) -> Value { return b; });
- Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
- unsigned index = 0;
- // Take the new corresponding bbArg if the old bbArg was used as a
- // destination in the in_parallel op. For all other bbArgs, use the
- // corresponding init_arg from the old scf.forall op.
- for (OpResult result : forallOp.getResults()) {
- if (resultToDelete.count(result)) {
- newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
- } else {
- newBlockArgs.push_back(newSharedOutsArgs[index++]);
- }
+ for (auto [blockArg, out, result] :
+ llvm::zip_equal(blockArgsToDelete, outsToDelete, resultsToDelete)) {
+ rewriter.replaceAllUsesWith(blockArg, out);
+ rewriter.replaceAllUsesWith(result, out);
}
- rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
-
- // Step 5. Replace the uses of result of old scf.forall with that of the new
- // scf.forall.
- for (auto &&[oldResult, newResult] :
- llvm::zip(resultToReplace, newForallOp->getResults()))
- rewriter.replaceAllUsesWith(oldResult, newResult);
-
- // Step 6. Replace the uses of those values that either has no use or are
- // not being modified within the loop with the corresponding
- // OpOperand.
- for (OpResult oldResult : resultToDelete)
- rewriter.replaceAllUsesWith(oldResult,
- forallOp.getTiedOpOperand(oldResult)->get());
+ // TODO: There is no rewriter API for erasing block arguments.
+ rewriter.modifyOpInPlace(forallOp, [&]() {
+ forallOp.getBody()->eraseArguments(blockIndicesToDelete);
+ });
+
+ // Step 3. Create a new scf.forall op with only the shared_outs/results
+ // that should be retained.
+ auto newForallOp = cast<scf::ForallOp>(
+ rewriter.eraseOpResults(forallOp, resultIndicesToDelete));
+ newForallOp.getOutputsMutable().assign(newOuts);
+
return success();
}
};
@@ -2413,53 +2382,27 @@ namespace {
struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
using OpRewritePattern<IfOp>::OpRewritePattern;
- void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
- PatternRewriter &rewriter) const {
- // Move all operations to the destination block.
- rewriter.mergeBlocks(source, dest);
- // Replace the yield op by one that returns only the used values.
- auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
- SmallVector<Value, 4> usedOperands;
- llvm::transform(usedResults, std::back_inserter(usedOperands),
- [&](OpResult result) {
- return yieldOp.getOperand(result.getResultNumber());
- });
- rewriter.modifyOpInPlace(yieldOp,
- [&]() { yieldOp->setOperands(usedOperands); });
- }
-
LogicalResult matchAndRewrite(IfOp op,
PatternRewriter &rewriter) const override {
- // Compute the list of used results.
- SmallVector<OpResult, 4> usedResults;
- llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
- [](OpResult result) { return !result.use_empty(); });
-
- // Replace the operation if only a subset of its results have uses.
- if (usedResults.size() == op.getNumResults())
- return failure();
-
- // Compute the result types of the replacement operation.
- SmallVector<Type, 4> newTypes;
- llvm::transform(usedResults, std::back_inserter(newTypes),
- [](OpResult result) { return result.getType(); });
+ // Compute the list of unused results.
+ BitVector toErase(op.getNumResults(), false);
+ for (auto [idx, result] : llvm::enumerate(op.getResults()))
+ if (result.use_empty())
+ toErase[idx] = true;
+ if (toErase.none())
+ return rewriter.notifyMatchFailure(op, "no results to erase");
+
+ // Erase results.
+ auto newOp = cast<scf::IfOp>(rewriter.eraseOpResults(op, toErase));
+
+ // Erase operands.
+ rewriter.modifyOpInPlace(newOp.thenYield(), [&]() {
+ newOp.thenYield()->eraseOperands(toErase);
+ });
+ rewriter.modifyOpInPlace(newOp.elseYield(), [&]() {
+ newOp.elseYield()->eraseOperands(toErase);
+ });
- // Create a replacement operation with empty then and else regions.
- auto newOp =
- IfOp::create(rewriter, op.getLoc(), newTypes, op.getCondition());
- rewriter.createBlock(&newOp.getThenRegion());
- rewriter.createBlock(&newOp.getElseRegion());
-
- // Move the bodies and replace the terminators (note there is a then and
- // an else region since the operation returns results).
- transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
- transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
-
- // Replace the operation by the new one.
- SmallVector<Value, 4> repResults(op.getNumResults());
- for (const auto &en : llvm::enumerate(usedResults))
- repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
- rewriter.replaceOp(op, repResults);
return success();
}
};
@@ -4720,43 +4663,27 @@ struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> {
PatternRewriter &rewriter) const override {
// Find dead results.
BitVector deadResults(op.getNumResults(), false);
- SmallVector<Type> newResultTypes;
- for (auto [idx, result] : llvm::enumerate(op.getResults())) {
- if (!result.use_empty()) {
- newResultTypes.push_back(result.getType());
- } else {
+ for (auto [idx, result] : llvm::enumerate(op.getResults()))
+ if (result.use_empty())
deadResults[idx] = true;
- }
- }
if (!deadResults.any())
return rewriter.notifyMatchFailure(op, "no dead results to fold");
- // Create new op without dead results and inline case regions.
- auto newOp = IndexSwitchOp::create(rewriter, op.getLoc(), newResultTypes,
- op.getArg(), op.getCases(),
- op.getCaseRegions().size());
- auto inlineCaseRegion = [&](Region &oldRegion, Region &newRegion) {
- rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
- // Remove respective operands from yield op.
- Operation *terminator = newRegion.front().getTerminator();
+ // Erase dead results.
+ auto newOp =
+ cast<scf::IndexSwitchOp>(rewriter.eraseOpResults(op, deadResults));
+
+ // Erase operands from yield ops.
+ auto updateCaseRegion = [&](Region &region) {
+ Operation *terminator = region.front().getTerminator();
assert(isa<YieldOp>(terminator) && "expected yield op");
rewriter.modifyOpInPlace(
terminator, [&]() { terminator->eraseOperands(deadResults); });
};
- for (auto [oldRegion, newRegion] :
- llvm::zip_equal(op.getCaseRegions(), newOp.getCaseRegions()))
- inlineCaseRegion(oldRegion, newRegion);
- inlineCaseRegion(op.getDefaultRegion(), newOp.getDefaultRegion());
-
- // Replace op with new op.
- SmallVector<Value> newResults(op.getNumResults(), Value());
- unsigned nextNewResult = 0;
- for (unsigned idx = 0; idx < op.getNumResults(); ++idx) {
- if (deadResults[idx])
- continue;
- newResults[idx] = newOp.getResult(nextNewResult++);
- }
- rewriter.replaceOp(op, newResults);
+ updateCaseRegion(newOp.getDefaultRegion());
+ for (Region &caseRegion : newOp.getCaseRegions())
+ updateCaseRegion(caseRegion);
+
return success();
}
};
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 226e4e518d3e0..532697ea432f3 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -244,6 +244,42 @@ void RewriterBase::eraseBlock(Block *block) {
block->erase();
}
+Operation *RewriterBase::eraseOpResults(Operation *op,
+ const BitVector &eraseIndices) {
+ assert(op->getNumResults() == eraseIndices.size() &&
+ "number of op results and bitvector size must match");
+
+ // Gather new result types.
+ SmallVector<Type> newResultTypes;
+ newResultTypes.reserve(op->getNumResults() - eraseIndices.count());
+ for (OpResult result : op->getResults())
+ if (!eraseIndices[result.getResultNumber()])
+ newResultTypes.push_back(result.getType());
+
+ // Create a new operation and inline all regions.
+ InsertionGuard g(*this);
+ setInsertionPoint(op);
+ OperationState state(op->getLoc(), op->getName().getStringRef(),
+ op->getOperands(), newResultTypes, op->getAttrs());
+ for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
+ state.addRegion();
+ Operation *newOp = create(state);
+ for (const auto &[index, region] : llvm::enumerate(op->getRegions())) {
+ // Move all blocks of `region` into `newRegion`.
+ Region &newRegion = newOp->getRegion(index);
+ inlineRegionBefore(region, newRegion, newRegion.begin());
+ }
+
+ // Replace the original operation with the new operation.
+ SmallVector<Value> replacements(op->getNumResults(), Value());
+ unsigned nextResultIdx = 0;
+ for (auto i : llvm::seq<unsigned>(0, op->getNumResults()))
+ if (!eraseIndices[i])
+ replacements[i] = newOp->getResult(nextResultIdx++);
+ replaceOp(op, replacements);
+ return newOp;
+}
+
void RewriterBase::finalizeOpModification(Operation *op) {
// Notify the listener that the operation was modified.
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 45266bc7b34ea..07911c6111043 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -205,35 +205,10 @@ static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
assert(op->getNumResults() == toErase.size() &&
"expected the number of results in `op` and the size of `toErase` to "
"be the same");
-
- std::vector<Type> newResultTypes;
- for (OpResult result : op->getResults())
- if (!toErase[result.getResultNumber()])
- newResultTypes.push_back(result.getType());
+ for (auto idx : toErase.set_bits())
+ op->getResult(idx).dropAllUses();
IRRewriter rewriter(op);
- rewriter.setInsertionPointAfter(op);
- OperationState state(op->getLoc(), op->getName().getStringRef(),
- op->getOperands(), newResultTypes, op->getAttrs());
- for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
- state.addRegion();
- Operation *newOp = rewriter.create(state);
- for (const auto &[index, region] : llvm::enumerate(op->getRegions())) {
- // Move all blocks of `region` into `newRegion`.
- Region &newRegion = newOp->getRegion(index);
- rewriter.inlineRegionBefore(region, newRegion, newRegion.begin());
- }
-
- unsigned indexOfNextNewCallOpResultToReplace = 0;
- for (auto [index, result] : llvm::enumerate(op->getResults())) {
- assert(result && "expected result to be non-null");
- if (toErase[index]) {
- result.dropAllUses();
- } else {
- result.replaceAllUsesWith(
- newOp->getResult(indexOfNextNewCallOpResultToReplace++));
- }
- }
- op->erase();
+ rewriter.eraseOpResults(op, toErase);
}
/// Convert a list of `Operand`s to a list of `OpOperand`s.
From dbc28c360938e2bb0b495cf5d2fce8b3ae6d552e Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 2 Jan 2026 11:42:35 +0000
Subject: [PATCH 2/2] address comments
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 4 ++--
mlir/lib/IR/PatternMatch.cpp | 2 +-
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index a85a3a84fb0ac..8803a6d136f7a 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1799,8 +1799,8 @@ struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
for (Operation *combiningOp : combiningOps)
rewriter.eraseOp(combiningOp);
}
- for (auto [blockArg, out, result] :
- llvm::zip_equal(blockArgsToDelete, outsToDelete, resultsToDelete)) {
+ for (auto [blockArg, result, out] :
+ llvm::zip_equal(blockArgsToDelete, resultsToDelete, outsToDelete)) {
rewriter.replaceAllUsesWith(blockArg, out);
rewriter.replaceAllUsesWith(result, out);
}
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 532697ea432f3..cd067f2cc25b3 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -261,7 +261,7 @@ Operation *RewriterBase::eraseOpResults(Operation *op,
setInsertionPoint(op);
OperationState state(op->getLoc(), op->getName().getStringRef(),
op->getOperands(), newResultTypes, op->getAttrs());
- for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
+ for ([[maybe_unused]] auto i : llvm::seq<unsigned>(0, op->getNumRegions()))
state.addRegion();
Operation *newOp = create(state);
for (const auto &[index, region] : llvm::enumerate(op->getRegions())) {
More information about the Mlir-commits
mailing list