[Mlir-commits] [mlir] [mlir][IR][NFC] Add `RewriterBase::eraseOpResults` convenience helper (PR #174152)

Matthias Springer llvmlistbot at llvm.org
Fri Jan 2 03:44:14 PST 2026


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/174152

>From 7fef1b484253e084655701846e2db4384f1852fe Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 1 Jan 2026 16:40:56 +0000
Subject: [PATCH 1/2] [mlir][IR][NFC] Add `RewriterBase::eraseOpResults`
 convenience helper

---
 mlir/include/mlir/IR/PatternMatch.h      |   6 +
 mlir/lib/Dialect/SCF/IR/SCF.cpp          | 199 +++++++----------------
 mlir/lib/IR/PatternMatch.cpp             |  36 ++++
 mlir/lib/Transforms/RemoveDeadValues.cpp |  31 +---
 4 files changed, 108 insertions(+), 164 deletions(-)

diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 1caab24ac7295..83477c79ff582 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -545,6 +545,12 @@ class RewriterBase : public OpBuilder {
   /// This method erases all operations in a block.
   virtual void eraseBlock(Block *block);
 
+  /// Erase the results of `op` whose bits are set in `eraseIndices`, which
+  /// must have one bit per result. Results cannot be erased in place, so a
+  /// new replacement operation is created and the original operation is
+  /// erased. The new operation is returned.
+  Operation *eraseOpResults(Operation *op, const BitVector &eraseIndices);
+
   /// Inline the operations of block 'source' into block 'dest' before the given
   /// position. The source block will be deleted and must have no uses.
   /// 'argValues' is used to replace the block arguments of 'source'.
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 0a123112cf68f..a85a3a84fb0ac 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1763,89 +1763,58 @@ struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
     //            ForallOp::getCombiningOps(iter_arg).
     //
     //         Based on the check we maintain the following :-
-    //         a. `resultToDelete` - i-th result of scf.forall that'll be
-    //            deleted.
-    //         b. `resultToReplace` - i-th result of the old scf.forall
-    //            whose uses will be replaced by the new scf.forall.
-    //         c. `newOuts` - the shared_outs' operand of the new scf.forall
-    //            corresponding to the i-th result with at least one use.
-    SetVector<OpResult> resultToDelete;
-    SmallVector<Value> resultToReplace;
+    //         a. op results, block arguments, outputs to delete
+    //         b. new outputs (i.e., outputs to retain)
+    SmallVector<Value> resultsToDelete;
+    SmallVector<Value> outsToDelete;
+    SmallVector<BlockArgument> blockArgsToDelete;
     SmallVector<Value> newOuts;
+    BitVector resultIndicesToDelete(forallOp.getNumResults(), false);
+    BitVector blockIndicesToDelete(forallOp.getBody()->getNumArguments(),
+                                   false);
     for (OpResult result : forallOp.getResults()) {
       OpOperand *opOperand = forallOp.getTiedOpOperand(result);
       BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
       if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
-        resultToDelete.insert(result);
+        resultsToDelete.push_back(result);
+        outsToDelete.push_back(opOperand->get());
+        blockArgsToDelete.push_back(blockArg);
+        resultIndicesToDelete[result.getResultNumber()] = true;
+        blockIndicesToDelete[blockArg.getArgNumber()] = true;
       } else {
-        resultToReplace.push_back(result);
         newOuts.push_back(opOperand->get());
       }
     }
 
     // Return early if all results of scf.forall have at least one use and being
     // modified within the loop.
-    if (resultToDelete.empty())
+    if (resultsToDelete.empty())
       return failure();
 
-    // Step 2: For the the i-th result, do the following :-
-    //         a. Fetch the corresponding BlockArgument.
-    //         b. Look for store ops (currently tensor.parallel_insert_slice)
-    //            with the BlockArgument as its destination operand.
-    //         c. Remove the operations fetched in b.
-    for (OpResult result : resultToDelete) {
-      OpOperand *opOperand = forallOp.getTiedOpOperand(result);
-      BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
+    // Step 2: Erase combining ops and replace uses of deleted results and
+    //         block arguments with the corresponding outputs.
+    for (auto blockArg : blockArgsToDelete) {
       SmallVector<Operation *> combiningOps =
           forallOp.getCombiningOps(blockArg);
       for (Operation *combiningOp : combiningOps)
         rewriter.eraseOp(combiningOp);
     }
-
-    // Step 3. Create a new scf.forall op with the new shared_outs' operands
-    //         fetched earlier
-    auto newForallOp = scf::ForallOp::create(
-        rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
-        forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
-        forallOp.getMapping(),
-        /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
-
-    // Step 4. Merge the block of the old scf.forall into the newly created
-    //         scf.forall using the new set of arguments.
-    Block *loopBody = forallOp.getBody();
-    Block *newLoopBody = newForallOp.getBody();
-    ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
-    // Form initial new bbArg list with just the control operands of the new
-    // scf.forall op.
-    SmallVector<Value> newBlockArgs =
-        llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
-                            [](BlockArgument b) -> Value { return b; });
-    Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
-    unsigned index = 0;
-    // Take the new corresponding bbArg if the old bbArg was used as a
-    // destination in the in_parallel op. For all other bbArgs, use the
-    // corresponding init_arg from the old scf.forall op.
-    for (OpResult result : forallOp.getResults()) {
-      if (resultToDelete.count(result)) {
-        newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
-      } else {
-        newBlockArgs.push_back(newSharedOutsArgs[index++]);
-      }
+    for (auto [blockArg, out, result] :
+         llvm::zip_equal(blockArgsToDelete, outsToDelete, resultsToDelete)) {
+      rewriter.replaceAllUsesWith(blockArg, out);
+      rewriter.replaceAllUsesWith(result, out);
     }
-    rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
-
-    // Step 5. Replace the uses of result of old scf.forall with that of the new
-    //         scf.forall.
-    for (auto &&[oldResult, newResult] :
-         llvm::zip(resultToReplace, newForallOp->getResults()))
-      rewriter.replaceAllUsesWith(oldResult, newResult);
-
-    // Step 6. Replace the uses of those values that either has no use or are
-    //         not being modified within the loop with the corresponding
-    //         OpOperand.
-    for (OpResult oldResult : resultToDelete)
-      rewriter.replaceAllUsesWith(oldResult,
-                                  forallOp.getTiedOpOperand(oldResult)->get());
+    // TODO: There is no rewriter API for erasing block arguments.
+    rewriter.modifyOpInPlace(forallOp, [&]() {
+      forallOp.getBody()->eraseArguments(blockIndicesToDelete);
+    });
+
+    // Step 3. Create a new scf.forall op with only the shared_outs/results
+    //         that should be retained.
+    auto newForallOp = cast<scf::ForallOp>(
+        rewriter.eraseOpResults(forallOp, resultIndicesToDelete));
+    newForallOp.getOutputsMutable().assign(newOuts);
+
     return success();
   }
 };
@@ -2413,53 +2382,27 @@ namespace {
 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
   using OpRewritePattern<IfOp>::OpRewritePattern;
 
-  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
-                    PatternRewriter &rewriter) const {
-    // Move all operations to the destination block.
-    rewriter.mergeBlocks(source, dest);
-    // Replace the yield op by one that returns only the used values.
-    auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
-    SmallVector<Value, 4> usedOperands;
-    llvm::transform(usedResults, std::back_inserter(usedOperands),
-                    [&](OpResult result) {
-                      return yieldOp.getOperand(result.getResultNumber());
-                    });
-    rewriter.modifyOpInPlace(yieldOp,
-                             [&]() { yieldOp->setOperands(usedOperands); });
-  }
-
   LogicalResult matchAndRewrite(IfOp op,
                                 PatternRewriter &rewriter) const override {
-    // Compute the list of used results.
-    SmallVector<OpResult, 4> usedResults;
-    llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
-                  [](OpResult result) { return !result.use_empty(); });
-
-    // Replace the operation if only a subset of its results have uses.
-    if (usedResults.size() == op.getNumResults())
-      return failure();
-
-    // Compute the result types of the replacement operation.
-    SmallVector<Type, 4> newTypes;
-    llvm::transform(usedResults, std::back_inserter(newTypes),
-                    [](OpResult result) { return result.getType(); });
+    // Compute the list of unused results.
+    BitVector toErase(op.getNumResults(), false);
+    for (auto [idx, result] : llvm::enumerate(op.getResults()))
+      if (result.use_empty())
+        toErase[idx] = true;
+    if (toErase.none())
+      return rewriter.notifyMatchFailure(op, "no results to erase");
+
+    // Erase results.
+    auto newOp = cast<scf::IfOp>(rewriter.eraseOpResults(op, toErase));
+
+    // Erase operands.
+    rewriter.modifyOpInPlace(newOp.thenYield(), [&]() {
+      newOp.thenYield()->eraseOperands(toErase);
+    });
+    rewriter.modifyOpInPlace(newOp.elseYield(), [&]() {
+      newOp.elseYield()->eraseOperands(toErase);
+    });
 
-    // Create a replacement operation with empty then and else regions.
-    auto newOp =
-        IfOp::create(rewriter, op.getLoc(), newTypes, op.getCondition());
-    rewriter.createBlock(&newOp.getThenRegion());
-    rewriter.createBlock(&newOp.getElseRegion());
-
-    // Move the bodies and replace the terminators (note there is a then and
-    // an else region since the operation returns results).
-    transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
-    transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
-
-    // Replace the operation by the new one.
-    SmallVector<Value, 4> repResults(op.getNumResults());
-    for (const auto &en : llvm::enumerate(usedResults))
-      repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
-    rewriter.replaceOp(op, repResults);
     return success();
   }
 };
@@ -4720,43 +4663,27 @@ struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> {
                                 PatternRewriter &rewriter) const override {
     // Find dead results.
     BitVector deadResults(op.getNumResults(), false);
-    SmallVector<Type> newResultTypes;
-    for (auto [idx, result] : llvm::enumerate(op.getResults())) {
-      if (!result.use_empty()) {
-        newResultTypes.push_back(result.getType());
-      } else {
+    for (auto [idx, result] : llvm::enumerate(op.getResults()))
+      if (result.use_empty())
         deadResults[idx] = true;
-      }
-    }
     if (!deadResults.any())
       return rewriter.notifyMatchFailure(op, "no dead results to fold");
 
-    // Create new op without dead results and inline case regions.
-    auto newOp = IndexSwitchOp::create(rewriter, op.getLoc(), newResultTypes,
-                                       op.getArg(), op.getCases(),
-                                       op.getCaseRegions().size());
-    auto inlineCaseRegion = [&](Region &oldRegion, Region &newRegion) {
-      rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
-      // Remove respective operands from yield op.
-      Operation *terminator = newRegion.front().getTerminator();
+    // Erase dead results.
+    auto newOp =
+        cast<scf::IndexSwitchOp>(rewriter.eraseOpResults(op, deadResults));
+
+    // Erase operands from yield ops.
+    auto updateCaseRegion = [&](Region &region) {
+      Operation *terminator = region.front().getTerminator();
       assert(isa<YieldOp>(terminator) && "expected yield op");
       rewriter.modifyOpInPlace(
           terminator, [&]() { terminator->eraseOperands(deadResults); });
     };
-    for (auto [oldRegion, newRegion] :
-         llvm::zip_equal(op.getCaseRegions(), newOp.getCaseRegions()))
-      inlineCaseRegion(oldRegion, newRegion);
-    inlineCaseRegion(op.getDefaultRegion(), newOp.getDefaultRegion());
-
-    // Replace op with new op.
-    SmallVector<Value> newResults(op.getNumResults(), Value());
-    unsigned nextNewResult = 0;
-    for (unsigned idx = 0; idx < op.getNumResults(); ++idx) {
-      if (deadResults[idx])
-        continue;
-      newResults[idx] = newOp.getResult(nextNewResult++);
-    }
-    rewriter.replaceOp(op, newResults);
+    updateCaseRegion(newOp.getDefaultRegion());
+    for (Region &caseRegion : newOp.getCaseRegions())
+      updateCaseRegion(caseRegion);
+
     return success();
   }
 };
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 226e4e518d3e0..532697ea432f3 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -244,6 +244,42 @@ void RewriterBase::eraseBlock(Block *block) {
   block->erase();
 }
 
+Operation *RewriterBase::eraseOpResults(Operation *op,
+                                        const BitVector &eraseIndices) {
+  assert(op->getNumResults() == eraseIndices.size() &&
+         "number of op results and bitvector size must match");
+
+  // Gather new result types.
+  SmallVector<Type> newResultTypes;
+  newResultTypes.reserve(op->getNumResults() - eraseIndices.count());
+  for (OpResult result : op->getResults())
+    if (!eraseIndices[result.getResultNumber()])
+      newResultTypes.push_back(result.getType());
+
+  // Create a new operation and inline all regions.
+  InsertionGuard g(*this);
+  setInsertionPoint(op);
+  OperationState state(op->getLoc(), op->getName().getStringRef(),
+                       op->getOperands(), newResultTypes, op->getAttrs());
+  for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
+    state.addRegion();
+  Operation *newOp = create(state);
+  for (const auto &[index, region] : llvm::enumerate(op->getRegions())) {
+    // Move all blocks of `region` into `newRegion`.
+    Region &newRegion = newOp->getRegion(index);
+    inlineRegionBefore(region, newRegion, newRegion.begin());
+  }
+
+  // Replace the original operation with the new operation.
+  SmallVector<Value> replacements(op->getNumResults(), Value());
+  unsigned nextResultIdx = 0;
+  for (auto i : llvm::seq<unsigned>(0, op->getNumResults()))
+    if (!eraseIndices[i])
+      replacements[i] = newOp->getResult(nextResultIdx++);
+  replaceOp(op, replacements);
+  return newOp;
+}
+
 void RewriterBase::finalizeOpModification(Operation *op) {
   // Notify the listener that the operation was modified.
   if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 45266bc7b34ea..07911c6111043 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -205,35 +205,10 @@ static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
   assert(op->getNumResults() == toErase.size() &&
          "expected the number of results in `op` and the size of `toErase` to "
          "be the same");
-
-  std::vector<Type> newResultTypes;
-  for (OpResult result : op->getResults())
-    if (!toErase[result.getResultNumber()])
-      newResultTypes.push_back(result.getType());
+  for (auto idx : toErase.set_bits())
+    op->getResult(idx).dropAllUses();
   IRRewriter rewriter(op);
-  rewriter.setInsertionPointAfter(op);
-  OperationState state(op->getLoc(), op->getName().getStringRef(),
-                       op->getOperands(), newResultTypes, op->getAttrs());
-  for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
-    state.addRegion();
-  Operation *newOp = rewriter.create(state);
-  for (const auto &[index, region] : llvm::enumerate(op->getRegions())) {
-    // Move all blocks of `region` into `newRegion`.
-    Region &newRegion = newOp->getRegion(index);
-    rewriter.inlineRegionBefore(region, newRegion, newRegion.begin());
-  }
-
-  unsigned indexOfNextNewCallOpResultToReplace = 0;
-  for (auto [index, result] : llvm::enumerate(op->getResults())) {
-    assert(result && "expected result to be non-null");
-    if (toErase[index]) {
-      result.dropAllUses();
-    } else {
-      result.replaceAllUsesWith(
-          newOp->getResult(indexOfNextNewCallOpResultToReplace++));
-    }
-  }
-  op->erase();
+  rewriter.eraseOpResults(op, toErase);
 }
 
 /// Convert a list of `Operand`s to a list of `OpOperand`s.

>From dbc28c360938e2bb0b495cf5d2fce8b3ae6d552e Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 2 Jan 2026 11:42:35 +0000
Subject: [PATCH 2/2] address comments

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp | 4 ++--
 mlir/lib/IR/PatternMatch.cpp    | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index a85a3a84fb0ac..8803a6d136f7a 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1799,8 +1799,8 @@ struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
       for (Operation *combiningOp : combiningOps)
         rewriter.eraseOp(combiningOp);
     }
-    for (auto [blockArg, out, result] :
-         llvm::zip_equal(blockArgsToDelete, outsToDelete, resultsToDelete)) {
+    for (auto [blockArg, result, out] :
+         llvm::zip_equal(blockArgsToDelete, resultsToDelete, outsToDelete)) {
       rewriter.replaceAllUsesWith(blockArg, out);
       rewriter.replaceAllUsesWith(result, out);
     }
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 532697ea432f3..cd067f2cc25b3 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -261,7 +261,7 @@ Operation *RewriterBase::eraseOpResults(Operation *op,
   setInsertionPoint(op);
   OperationState state(op->getLoc(), op->getName().getStringRef(),
                        op->getOperands(), newResultTypes, op->getAttrs());
-  for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
+  for ([[maybe_unused]] auto i : llvm::seq<unsigned>(0, op->getNumRegions()))
     state.addRegion();
   Operation *newOp = create(state);
   for (const auto &[index, region] : llvm::enumerate(op->getRegions())) {



More information about the Mlir-commits mailing list