[Mlir-commits] [mlir] [mlir] Add RewriterBase operand/block-arg mutation helpers (PR #187992)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 23 01:27:56 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Hocky Yudhiono (hockyy)

<details>
<summary>Changes</summary>

- Add dedicated `RewriterBase` APIs (`setOperand`, `setOperands`, `eraseBlockArgument`, `eraseBlockArguments`) for in-place operand and block-argument updates.
- Migrate existing call sites from `modifyOpInPlace(...)` wrappers to the new helpers.

---

Patch is 20.78 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/187992.diff


24 Files Affected:

- (modified) mlir/include/mlir/IR/PatternMatch.h (+19) 
- (modified) mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp (+2-3) 
- (modified) mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp (+1-2) 
- (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+2-4) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp (+1-2) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp (+1-2) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Split.cpp (+4-6) 
- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+1-4) 
- (modified) mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp (+1-2) 
- (modified) mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp (+1-2) 
- (modified) mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp (+1-2) 
- (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+1-3) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp (+1-3) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+1-1) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+2-4) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp (+1-2) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+1-3) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+1-4) 
- (modified) mlir/lib/IR/PatternMatch.cpp (+46) 
- (modified) mlir/lib/Interfaces/ControlFlowInterfaces.cpp (+1-2) 
- (modified) mlir/lib/Transforms/Utils/CommutativityUtils.cpp (+1-1) 
- (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+1-2) 


``````````diff
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 83477c79ff582..86ae49318f683 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -545,12 +545,31 @@ class RewriterBase : public OpBuilder {
   /// This method erases all operations in a block.
   virtual void eraseBlock(Block *block);
 
+  /// Erase arguments from a block and notify listeners by marking the parent
+  /// operation as modified in-place.
+  /// TODO: Determine a better rollback mode integration for these helpers when
+  /// used by rewriters that support rollback semantics.
+  void eraseBlockArgument(Block *block, unsigned index);
+  void eraseBlockArguments(Block *block, unsigned start, unsigned num);
+  void eraseBlockArguments(Block *block, const BitVector &eraseIndices);
+  void eraseBlockArguments(Block *block,
+                           function_ref<bool(BlockArgument)> shouldEraseFn);
+
   /// Erase the specified results of the given operation. Results cannot be
   /// erased directly, so the implementation creates a new replacement
   /// operation and erases the original operation. The new operation is
   /// returned.
   Operation *eraseOpResults(Operation *op, const BitVector &eraseIndices);
 
+  /// Set operands on an operation and notify listeners by marking the
+  /// operation as modified in-place.
+  /// TODO: Determine a better rollback mode integration for these helpers when
+  /// used by rewriters that support rollback semantics.
+  void setOperands(Operation *op, ValueRange operands);
+  void setOperands(Operation *op, unsigned start, unsigned length,
+                   ValueRange operands);
+  void setOperand(Operation *op, unsigned index, Value value);
+
   /// Inline the operations of block 'source' into block 'dest' before the given
   /// position. The source block will be deleted and must have no uses.
   /// 'argValues' is used to replace the block arguments of 'source'.
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index d48a0db4d9de0..57c5d8a663b26 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -860,7 +860,7 @@ LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
   for (auto stream : streams)
     streamDestroyCallBuilder.create(loc, rewriter, {stream});
 
-  rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
+  rewriter.setOperands(yieldOp, newOperands);
   return success();
 }
 
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 677c0ba288d40..f98102dc2406a 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -178,9 +178,8 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
                                              replacement);
       }
       // Replace the specific use of the tensor::EmptyOp.
-      rewriter.modifyOpInPlace(user, [&]() {
-        user->setOperand(useToBeReplaced->getOperandNumber(), replacement);
-      });
+      rewriter.setOperand(user, useToBeReplaced->getOperandNumber(),
+                          replacement);
       state.resetCache();
     }
 
diff --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
index 216401a80c9f8..2fef3e692755c 100644
--- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
@@ -108,8 +108,7 @@ class BranchOpInterfaceTypeConversion
           newOperands[idx] = operands[idx];
       }
     }
-    rewriter.modifyOpInPlace(
-        op, [newOperands, op]() { op->setOperands(newOperands); });
+    rewriter.setOperands(op, newOperands);
     return success();
   }
 
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 5d409f71847c6..dd6b122b888fd 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2188,7 +2188,7 @@ struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
         continue;
       validOperands.push_back(operand);
     }
-    rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
+    rewriter.setOperands(op, validOperands);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5f530a585ddb9..eb703d106fae6 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1172,10 +1172,8 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
   rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
 
   // Replace the use in containingOp.
-  rewriter.modifyOpInPlace(containingOp, [&]() {
-    containingOp->setOperand(pUse->getOperandNumber(),
-                             destinationTensors.front());
-  });
+  rewriter.setOperand(containingOp, pUse->getOperandNumber(),
+                      destinationTensors.front());
 
   return tileAndFuseResult->tiledOps;
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp b/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
index 6f81702ee22c5..64e732d03f426 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
@@ -137,8 +137,7 @@ struct FoldAddIntoDest final : public OpRewritePattern<linalg::AddOp> {
     // Replace the additive-ident, i.e. zero, out arg of the dominated op by the
     // dominating summand. This makes the dominated op's result the sum of both
     // of addOp's arguments - therefore we replace addOp and it uses by it.
-    rewriter.modifyOpInPlace(
-        dominatedOp, [&]() { dominatedOp->setOperand(2, dominatingOperand); });
+    rewriter.setOperand(dominatedOp, 2, dominatingOperand);
     rewriter.replaceAllOpUsesWith(addOp, dominatedOp->getResult(0));
     return success();
   }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index e1dc40d6d37d9..fdc7361224dac 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -152,8 +152,7 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
       LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");
 
       rewriter.replaceAllUsesWith(newLoop.getResult(index), broadcast);
-      rewriter.modifyOpInPlace(
-          broadcast, [&] { broadcast.setOperand(newLoop.getResult(index)); });
+      rewriter.setOperand(broadcast, 0, newLoop.getResult(index));
 
       changed = true;
       return WalkResult::interrupt();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
index 25881701bc44d..42ca031ac5564 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
@@ -116,12 +116,10 @@ linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension,
   // Need to pretend that the original op now takes as operands firstResults,
   // otherwise tiling interface implementation will take the wrong value to
   // produce data tiles.
-  rewriter.modifyOpInPlace(op, [&]() {
-    unsigned numTotalOperands = op->getNumOperands();
-    unsigned numOutputOperands = firstResults.size();
-    op->setOperands(numTotalOperands - numOutputOperands, numOutputOperands,
-                    firstResults);
-  });
+  unsigned numTotalOperands = op->getNumOperands();
+  unsigned numOutputOperands = firstResults.size();
+  rewriter.setOperands(op, numTotalOperands - numOutputOperands,
+                       numOutputOperands, firstResults);
 
   // Create the second part.
   OpFoldResult totalOffset = affine::makeComposedFoldedAffineApply(
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 76467154e869f..04fecc80d920f 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1582,10 +1582,7 @@ struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
       rewriter.replaceAllUsesWith(blockArg, out);
       rewriter.replaceAllUsesWith(result, out);
     }
-    // TODO: There is no rewriter API for erasing block arguments.
-    rewriter.modifyOpInPlace(forallOp, [&]() {
-      forallOp.getBody()->eraseArguments(blockIndicesToDelete);
-    });
+    rewriter.eraseBlockArguments(forallOp.getBody(), blockIndicesToDelete);
 
     // Step 3. Create a new scf.forall op with only the shared_outs/results
     //         that should be retained.
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
index ddcbda86cf1f3..070c638fec69d 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -92,8 +92,7 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
     for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
       SmallVector<Value> yieldOperands = yieldOp.getOperands();
       yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult());
-      rewriter.modifyOpInPlace(yieldOp,
-                               [&]() { yieldOp->setOperands(yieldOperands); });
+      rewriter.setOperands(yieldOp, yieldOperands);
     }
 
     // We cannot do a direct replacement of the forOp since the while op returns
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 9468927021495..c40b0da40078e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -233,8 +233,7 @@ class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
   LogicalResult
   matchAndRewrite(ConditionOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.modifyOpInPlace(
-        op, [&]() { op->setOperands(flattenValues(adaptor.getOperands())); });
+    rewriter.setOperands(op, flattenValues(adaptor.getOperands()));
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
index a517ca946f3a4..1e91404f45a09 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
@@ -91,8 +91,7 @@ class SPIRVPassThroughConversion : public OpConversionPattern<OpT> {
   LogicalResult
   matchAndRewrite(OpT op, typename OpT::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.modifyOpInPlace(op,
-                             [&] { op->setOperands(adaptor.getOperands()); });
+    rewriter.setOperands(op, adaptor.getOperands());
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 2c9e9c040d460..08dff0c9d7c4c 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1181,9 +1181,7 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
         continue;
       if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
         size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
-        rewriter.modifyOpInPlace(&curOp, [&] {
-          curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
-        });
+        rewriter.setOperand(&curOp, 0, newFuncOp.getArgument(unrolledInputNo));
         ++unrolledInputIdx;
       }
     }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 0fc5cc76de39c..39ef597b19075 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -571,9 +571,7 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
       rewriter.setInsertionPoint(linalgOp);
       RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
       Value dst = ConvertOp::create(rewriter, tval.getLoc(), dstTp, tval);
-      rewriter.modifyOpInPlace(linalgOp, [&]() {
-        linalgOp->setOperand(t->getOperandNumber(), dst);
-      });
+      rewriter.setOperand(linalgOp, t->getOperandNumber(), dst);
 
       // Release the transposed form afterwards.
       // TODO: CSE when used in more than one following op?
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 89ed468d2e1b9..e8248df120f68 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1074,7 +1074,7 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
       auto denseTp =
           RankedTensorType::get(rtp.getShape(), rtp.getElementType());
       auto convert = ConvertOp::create(rewriter, loc, denseTp, op.getSrc());
-      rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); });
+      rewriter.setOperand(op, 0, convert);
       return success();
     }
     if (encDst) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 6004ab26f4663..18e1301d22c74 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -616,10 +616,8 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
     if (def->getBlock() == block) {
       rewriter.setInsertionPoint(def);
       for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
-        rewriter.modifyOpInPlace(def, [&]() {
-          def->setOperand(
-              i, relinkBranch(env, rewriter, block, def->getOperand(i)));
-        });
+        rewriter.setOperand(
+            def, i, relinkBranch(env, rewriter, block, def->getOperand(i)));
       }
     }
   }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 684c088eb9b0f..96af05c6d3d1b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -784,8 +784,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
       Operation *newRed = rewriter.clone(*redExp);
       // Replaces arguments of the reduction expression by using the block
       // arguments from scf.reduce.
-      rewriter.modifyOpInPlace(
-          newRed, [&]() { newRed->setOperands(redBlock->getArguments()); });
+      rewriter.setOperands(newRed, redBlock->getArguments());
       // Erases the out-dated reduction expression.
       rewriter.eraseOp(redExp);
       rewriter.setInsertionPointToEnd(redBlock);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index ce0f8540d884a..2039665c4aab7 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -821,9 +821,7 @@ struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
         auto castOp =
             CastOp::create(rewriter, concatOp->getLoc(), inferredOperandType,
                            concatOp.getOperand(operandIdx));
-        rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
-          concatOp->setOperand(operandIdx, castOp->getResult(0));
-        });
+        rewriter.setOperand(concatOp, operandIdx, castOp->getResult(0));
       }
     }
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index b622cbedec1dc..b5e84c37809d9 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -310,10 +310,7 @@ LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
   auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
   if (!notOp)
     return failure();
-  rewriter.modifyOpInPlace(op, [&]() {
-    op.getOperation()->setOperands(
-        {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
-  });
+  rewriter.setOperands(op, {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
   return success();
 }
 
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index cd067f2cc25b3..de0cb5cb44c5f 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -244,6 +244,38 @@ void RewriterBase::eraseBlock(Block *block) {
   block->erase();
 }
 
+void RewriterBase::eraseBlockArgument(Block *block, unsigned index) {
+  eraseBlockArguments(block, index, /*num=*/1);
+}
+
+void RewriterBase::eraseBlockArguments(Block *block, unsigned start,
+                                       unsigned num) {
+  if (Operation *parentOp = block->getParentOp()) {
+    modifyOpInPlace(parentOp, [&]() { block->eraseArguments(start, num); });
+    return;
+  }
+  block->eraseArguments(start, num);
+}
+
+void RewriterBase::eraseBlockArguments(Block *block,
+                                       const BitVector &eraseIndices) {
+  if (Operation *parentOp = block->getParentOp()) {
+    modifyOpInPlace(parentOp, [&]() { block->eraseArguments(eraseIndices); });
+    return;
+  }
+  block->eraseArguments(eraseIndices);
+}
+
+void RewriterBase::eraseBlockArguments(
+    Block *block, function_ref<bool(BlockArgument)> shouldEraseFn) {
+  if (Operation *parentOp = block->getParentOp()) {
+    modifyOpInPlace(parentOp,
+                    [&]() { block->eraseArguments(shouldEraseFn); });
+    return;
+  }
+  block->eraseArguments(shouldEraseFn);
+}
+
 Operation *RewriterBase::eraseOpResults(Operation *op,
                                         const BitVector &eraseIndices) {
   assert(op->getNumResults() == eraseIndices.size() &&
@@ -280,6 +312,20 @@ Operation *RewriterBase::eraseOpResults(Operation *op,
   return newOp;
 }
 
+void RewriterBase::setOperands(Operation *op, ValueRange operands) {
+  modifyOpInPlace(op, [&]() { op->setOperands(operands); });
+}
+
+void RewriterBase::setOperands(Operation *op, unsigned start, unsigned length,
+                               ValueRange operands) {
+  modifyOpInPlace(
+      op, [&]() { op->setOperands(start, length, operands); });
+}
+
+void RewriterBase::setOperand(Operation *op, unsigned index, Value value) {
+  modifyOpInPlace(op, [&]() { op->setOperand(index, value); });
+}
+
 void RewriterBase::finalizeOpModification(Operation *op) {
   // Notify the listener that the operation was modified.
   if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index c3fb73acf5ef0..6a58edeafe1e8 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -930,8 +930,7 @@ struct RemoveDeadRegionBranchOpSuccessorInputs : public RewritePattern {
     for (auto &pair : blockArgsToRemove) {
       Block *block = pair.first;
       BitVector &blockArg = pair.second;
-      rewriter.modifyOpInPlace(block->getParentOp(),
-                               [&]() { block->eraseArguments(blockArg); });
+      rewriter.eraseBlockArguments(block, blockArg);
     }
 
     // Erase op results.
diff --git a/mlir/lib/Transforms/Utils/CommutativityUtils.cpp b/mlir/lib/Transforms/Utils/CommutativityUtils.cpp
index 8b132b5e484bb..bc20742171fcb 100644
--- a/mlir/lib/Transforms/Utils/CommutativityUtils.cpp
+++ b/mlir/lib/Transforms/Utils/CommutativityUtils.cpp
@@ -303,7 +303,7 @@ class SortCommutativeOperands : public RewritePattern {
       sortedOperands.push_back(commOperand->operand);
     if (sortedOperands == operands)
       return failure();
-  ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/187992


More information about the Mlir-commits mailing list