[Mlir-commits] [mlir] [mlir][IR][NFC] Add `RewriterBase::eraseOpResults` convenience helper (PR #174152)

Matthias Springer llvmlistbot at llvm.org
Fri Jan 2 01:36:45 PST 2026


================
@@ -1763,89 +1763,55 @@ struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
     //            ForallOp::getCombiningOps(iter_arg).
     //
     //         Based on the check we maintain the following :-
-    //         a. `resultToDelete` - i-th result of scf.forall that'll be
-    //            deleted.
-    //         b. `resultToReplace` - i-th result of the old scf.forall
-    //            whose uses will be replaced by the new scf.forall.
-    //         c. `newOuts` - the shared_outs' operand of the new scf.forall
-    //            corresponding to the i-th result with at least one use.
-    SetVector<OpResult> resultToDelete;
-    SmallVector<Value> resultToReplace;
+    //         a. op results, block arguments, outputs to delete
+    //         b. new outputs (i.e., outputs to retain)
+    SmallVector<Value> resultsToDelete;
+    SmallVector<Value> outsToDelete;
+    SmallVector<BlockArgument> blockArgsToDelete;
     SmallVector<Value> newOuts;
+    BitVector resultIndicesToDelete(forallOp.getNumResults(), false);
+    BitVector blockIndicesToDelete(forallOp.getBody()->getNumArguments(),
+                                   false);
     for (OpResult result : forallOp.getResults()) {
       OpOperand *opOperand = forallOp.getTiedOpOperand(result);
       BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
       if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
-        resultToDelete.insert(result);
+        resultsToDelete.push_back(result);
+        outsToDelete.push_back(opOperand->get());
+        blockArgsToDelete.push_back(blockArg);
+        resultIndicesToDelete[result.getResultNumber()] = true;
+        blockIndicesToDelete[blockArg.getArgNumber()] = true;
       } else {
-        resultToReplace.push_back(result);
         newOuts.push_back(opOperand->get());
       }
     }
 
     // Return early if all results of scf.forall have at least one use and being
     // modified within the loop.
-    if (resultToDelete.empty())
+    if (resultsToDelete.empty())
       return failure();
 
-    // Step 2: For the the i-th result, do the following :-
-    //         a. Fetch the corresponding BlockArgument.
-    //         b. Look for store ops (currently tensor.parallel_insert_slice)
-    //            with the BlockArgument as its destination operand.
-    //         c. Remove the operations fetched in b.
-    for (OpResult result : resultToDelete) {
-      OpOperand *opOperand = forallOp.getTiedOpOperand(result);
-      BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
+    // Step 2: Erase combining ops and replace uses of deleted results and
+    //         block arguments with the corresponding outputs.
+    for (auto blockArg : blockArgsToDelete) {
       SmallVector<Operation *> combiningOps =
           forallOp.getCombiningOps(blockArg);
       for (Operation *combiningOp : combiningOps)
         rewriter.eraseOp(combiningOp);
     }
-
-    // Step 3. Create a new scf.forall op with the new shared_outs' operands
-    //         fetched earlier
-    auto newForallOp = scf::ForallOp::create(
-        rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
-        forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
-        forallOp.getMapping(),
-        /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
-
-    // Step 4. Merge the block of the old scf.forall into the newly created
-    //         scf.forall using the new set of arguments.
-    Block *loopBody = forallOp.getBody();
-    Block *newLoopBody = newForallOp.getBody();
-    ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
-    // Form initial new bbArg list with just the control operands of the new
-    // scf.forall op.
-    SmallVector<Value> newBlockArgs =
-        llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
-                            [](BlockArgument b) -> Value { return b; });
-    Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
-    unsigned index = 0;
-    // Take the new corresponding bbArg if the old bbArg was used as a
-    // destination in the in_parallel op. For all other bbArgs, use the
-    // corresponding init_arg from the old scf.forall op.
-    for (OpResult result : forallOp.getResults()) {
-      if (resultToDelete.count(result)) {
-        newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
-      } else {
-        newBlockArgs.push_back(newSharedOutsArgs[index++]);
-      }
+    for (auto [blockArg, out, result] :
+         llvm::zip_equal(blockArgsToDelete, outsToDelete, resultsToDelete)) {
+      rewriter.replaceAllUsesWith(blockArg, out);
+      rewriter.replaceAllUsesWith(result, out);
     }
-    rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
-
-    // Step 5. Replace the uses of result of old scf.forall with that of the new
-    //         scf.forall.
-    for (auto &&[oldResult, newResult] :
-         llvm::zip(resultToReplace, newForallOp->getResults()))
-      rewriter.replaceAllUsesWith(oldResult, newResult);
-
-    // Step 6. Replace the uses of those values that either has no use or are
-    //         not being modified within the loop with the corresponding
-    //         OpOperand.
-    for (OpResult oldResult : resultToDelete)
-      rewriter.replaceAllUsesWith(oldResult,
-                                  forallOp.getTiedOpOperand(oldResult)->get());
+    forallOp.getBody()->eraseArguments(blockIndicesToDelete);
----------------
matthias-springer wrote:

To set a good example, I wrapped the `eraseArguments` call in `modifyOpInPlace`, which is what we typically do when modifying blocks in place: it notifies the rewriter (and any attached listeners) that the enclosing op was modified.

https://github.com/llvm/llvm-project/pull/174152


More information about the Mlir-commits mailing list