[Mlir-commits] [mlir] [mlir][IR][NFC] Add `RewriterBase::eraseOpResults` convenience helper (PR #174152)
Matthias Springer
llvmlistbot at llvm.org
Fri Jan 2 01:36:45 PST 2026
================
@@ -1763,89 +1763,55 @@ struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
// ForallOp::getCombiningOps(iter_arg).
//
// Based on the check we maintain the following :-
- // a. `resultToDelete` - i-th result of scf.forall that'll be
- // deleted.
- // b. `resultToReplace` - i-th result of the old scf.forall
- // whose uses will be replaced by the new scf.forall.
- // c. `newOuts` - the shared_outs' operand of the new scf.forall
- // corresponding to the i-th result with at least one use.
- SetVector<OpResult> resultToDelete;
- SmallVector<Value> resultToReplace;
+ // a. op results, block arguments, outputs to delete
+ // b. new outputs (i.e., outputs to retain)
+ SmallVector<Value> resultsToDelete;
+ SmallVector<Value> outsToDelete;
+ SmallVector<BlockArgument> blockArgsToDelete;
SmallVector<Value> newOuts;
+ BitVector resultIndicesToDelete(forallOp.getNumResults(), false);
+ BitVector blockIndicesToDelete(forallOp.getBody()->getNumArguments(),
+ false);
for (OpResult result : forallOp.getResults()) {
OpOperand *opOperand = forallOp.getTiedOpOperand(result);
BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
- resultToDelete.insert(result);
+ resultsToDelete.push_back(result);
+ outsToDelete.push_back(opOperand->get());
+ blockArgsToDelete.push_back(blockArg);
+ resultIndicesToDelete[result.getResultNumber()] = true;
+ blockIndicesToDelete[blockArg.getArgNumber()] = true;
} else {
- resultToReplace.push_back(result);
newOuts.push_back(opOperand->get());
}
}
// Return early if all results of scf.forall have at least one use and being
// modified within the loop.
- if (resultToDelete.empty())
+ if (resultsToDelete.empty())
return failure();
- // Step 2: For the the i-th result, do the following :-
- // a. Fetch the corresponding BlockArgument.
- // b. Look for store ops (currently tensor.parallel_insert_slice)
- // with the BlockArgument as its destination operand.
- // c. Remove the operations fetched in b.
- for (OpResult result : resultToDelete) {
- OpOperand *opOperand = forallOp.getTiedOpOperand(result);
- BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
+ // Step 2: Erase combining ops and replace uses of deleted results and
+ // block arguments with the corresponding outputs.
+ for (auto blockArg : blockArgsToDelete) {
SmallVector<Operation *> combiningOps =
forallOp.getCombiningOps(blockArg);
for (Operation *combiningOp : combiningOps)
rewriter.eraseOp(combiningOp);
}
-
- // Step 3. Create a new scf.forall op with the new shared_outs' operands
- // fetched earlier
- auto newForallOp = scf::ForallOp::create(
- rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
- forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
- forallOp.getMapping(),
- /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
-
- // Step 4. Merge the block of the old scf.forall into the newly created
- // scf.forall using the new set of arguments.
- Block *loopBody = forallOp.getBody();
- Block *newLoopBody = newForallOp.getBody();
- ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
- // Form initial new bbArg list with just the control operands of the new
- // scf.forall op.
- SmallVector<Value> newBlockArgs =
- llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
- [](BlockArgument b) -> Value { return b; });
- Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
- unsigned index = 0;
- // Take the new corresponding bbArg if the old bbArg was used as a
- // destination in the in_parallel op. For all other bbArgs, use the
- // corresponding init_arg from the old scf.forall op.
- for (OpResult result : forallOp.getResults()) {
- if (resultToDelete.count(result)) {
- newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
- } else {
- newBlockArgs.push_back(newSharedOutsArgs[index++]);
- }
+ for (auto [blockArg, out, result] :
+ llvm::zip_equal(blockArgsToDelete, outsToDelete, resultsToDelete)) {
+ rewriter.replaceAllUsesWith(blockArg, out);
+ rewriter.replaceAllUsesWith(result, out);
}
- rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
-
- // Step 5. Replace the uses of result of old scf.forall with that of the new
- // scf.forall.
- for (auto &&[oldResult, newResult] :
- llvm::zip(resultToReplace, newForallOp->getResults()))
- rewriter.replaceAllUsesWith(oldResult, newResult);
-
- // Step 6. Replace the uses of those values that either has no use or are
- // not being modified within the loop with the corresponding
- // OpOperand.
- for (OpResult oldResult : resultToDelete)
- rewriter.replaceAllUsesWith(oldResult,
- forallOp.getTiedOpOperand(oldResult)->get());
+ forallOp.getBody()->eraseArguments(blockIndicesToDelete);
----------------
matthias-springer wrote:
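To set a good example, I wrapped the `eraseArguments` call in `modifyOpInPlace`, which is what we typically do when modifying a block in place: it notifies the rewriter (and any attached listeners) that the owning op is being modified.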
https://github.com/llvm/llvm-project/pull/174152
More information about the Mlir-commits
mailing list