[Mlir-commits] [mlir] [mlir][NFC] Add RewriterBase operand/block-arg mutation helpers (PR #187992)
Hocky Yudhiono
llvmlistbot at llvm.org
Mon Mar 23 01:31:27 PDT 2026
https://github.com/hockyy updated https://github.com/llvm/llvm-project/pull/187992
>From d6414df14c7971a7b33f4a78d13f0d6d626dd26d Mon Sep 17 00:00:00 2001
From: Hocky Yudhiono <hocky.yudhiono at gmail.com>
Date: Mon, 23 Mar 2026 16:23:28 +0800
Subject: [PATCH] [mlir][NFC] Add RewriterBase operand/block-arg mutation
helpers
---
mlir/include/mlir/IR/PatternMatch.h | 19 ++++++++
.../GPUCommon/GPUToLLVMConversion.cpp | 2 +-
.../Transforms/EmptyTensorElimination.cpp | 5 +-
.../Func/Transforms/FuncConversions.cpp | 3 +-
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 2 +-
.../TransformOps/LinalgTransformOps.cpp | 6 +--
.../Linalg/Transforms/FoldAddIntoDest.cpp | 3 +-
.../Dialect/Linalg/Transforms/Hoisting.cpp | 3 +-
mlir/lib/Dialect/Linalg/Transforms/Split.cpp | 10 ++--
mlir/lib/Dialect/SCF/IR/SCF.cpp | 5 +-
.../lib/Dialect/SCF/Transforms/ForToWhile.cpp | 3 +-
.../Transforms/StructuralTypeConversions.cpp | 3 +-
.../DecorateCompositeTypeLayoutPass.cpp | 3 +-
.../SPIRV/Transforms/SPIRVConversion.cpp | 4 +-
.../Transforms/SparseReinterpretMap.cpp | 4 +-
.../Transforms/SparseTensorRewriting.cpp | 2 +-
.../Transforms/Sparsification.cpp | 6 +--
.../Transforms/Utils/LoopEmitter.cpp | 3 +-
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 4 +-
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 5 +-
mlir/lib/IR/PatternMatch.cpp | 46 +++++++++++++++++++
mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 3 +-
.../Transforms/Utils/CommutativityUtils.cpp | 2 +-
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 3 +-
24 files changed, 93 insertions(+), 56 deletions(-)
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 83477c79ff582..86ae49318f683 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -545,12 +545,31 @@ class RewriterBase : public OpBuilder {
/// This method erases all operations in a block.
virtual void eraseBlock(Block *block);
+ /// Erase arguments from `block`. If the block has a parent operation, the
+ /// erasure is wrapped in `modifyOpInPlace` on that operation so listeners
+ /// are notified; a block without a parent is mutated without notification.
+ /// TODO: Determine rollback integration for rollback-capable rewriters.
+ void eraseBlockArgument(Block *block, unsigned index);
+ void eraseBlockArguments(Block *block, unsigned start, unsigned num);
+ void eraseBlockArguments(Block *block, const BitVector &eraseIndices);
+ void eraseBlockArguments(Block *block,
+ function_ref<bool(BlockArgument)> shouldEraseFn);
+
/// Erase the specified results of the given operation. Results cannot be
/// erased directly, so the implementation creates a new replacement
/// operation and erases the original operation. The new operation is
/// returned.
Operation *eraseOpResults(Operation *op, const BitVector &eraseIndices);
+ /// Set one or more operands of `op`. The update is wrapped in
+ /// `modifyOpInPlace` so listeners are notified of the in-place change; each
+ /// overload forwards to the matching `Operation::setOperand(s)` method.
+ /// TODO: Determine rollback integration for rollback-capable rewriters.
+ void setOperands(Operation *op, ValueRange operands);
+ void setOperands(Operation *op, unsigned start, unsigned length,
+ ValueRange operands);
+ void setOperand(Operation *op, unsigned index, Value value);
+
/// Inline the operations of block 'source' into block 'dest' before the given
/// position. The source block will be deleted and must have no uses.
/// 'argValues' is used to replace the block arguments of 'source'.
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index d48a0db4d9de0..57c5d8a663b26 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -860,7 +860,7 @@ LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
for (auto stream : streams)
streamDestroyCallBuilder.create(loc, rewriter, {stream});
- rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
+ rewriter.setOperands(yieldOp, newOperands);
return success();
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 677c0ba288d40..f98102dc2406a 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -178,9 +178,8 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
replacement);
}
// Replace the specific use of the tensor::EmptyOp.
- rewriter.modifyOpInPlace(user, [&]() {
- user->setOperand(useToBeReplaced->getOperandNumber(), replacement);
- });
+ rewriter.setOperand(user, useToBeReplaced->getOperandNumber(),
+ replacement);
state.resetCache();
}
diff --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
index 216401a80c9f8..2fef3e692755c 100644
--- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
@@ -108,8 +108,7 @@ class BranchOpInterfaceTypeConversion
newOperands[idx] = operands[idx];
}
}
- rewriter.modifyOpInPlace(
- op, [newOperands, op]() { op->setOperands(newOperands); });
+ rewriter.setOperands(op, newOperands);
return success();
}
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 5d409f71847c6..dd6b122b888fd 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2188,7 +2188,7 @@ struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
continue;
validOperands.push_back(operand);
}
- rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
+ rewriter.setOperands(op, validOperands);
return success();
}
};
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5f530a585ddb9..eb703d106fae6 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1172,10 +1172,8 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
// Replace the use in containingOp.
- rewriter.modifyOpInPlace(containingOp, [&]() {
- containingOp->setOperand(pUse->getOperandNumber(),
- destinationTensors.front());
- });
+ rewriter.setOperand(containingOp, pUse->getOperandNumber(),
+ destinationTensors.front());
return tileAndFuseResult->tiledOps;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp b/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
index 6f81702ee22c5..64e732d03f426 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
@@ -137,8 +137,7 @@ struct FoldAddIntoDest final : public OpRewritePattern<linalg::AddOp> {
// Replace the additive-ident, i.e. zero, out arg of the dominated op by the
// dominating summand. This makes the dominated op's result the sum of both
// of addOp's arguments - therefore we replace addOp and it uses by it.
- rewriter.modifyOpInPlace(
- dominatedOp, [&]() { dominatedOp->setOperand(2, dominatingOperand); });
+ rewriter.setOperand(dominatedOp, 2, dominatingOperand);
rewriter.replaceAllOpUsesWith(addOp, dominatedOp->getResult(0));
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index e1dc40d6d37d9..fdc7361224dac 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -152,8 +152,7 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");
rewriter.replaceAllUsesWith(newLoop.getResult(index), broadcast);
- rewriter.modifyOpInPlace(
- broadcast, [&] { broadcast.setOperand(newLoop.getResult(index)); });
+ rewriter.setOperand(broadcast, 0, newLoop.getResult(index));
changed = true;
return WalkResult::interrupt();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
index 25881701bc44d..42ca031ac5564 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
@@ -116,12 +116,10 @@ linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension,
// Need to pretend that the original op now takes as operands firstResults,
// otherwise tiling interface implementation will take the wrong value to
// produce data tiles.
- rewriter.modifyOpInPlace(op, [&]() {
- unsigned numTotalOperands = op->getNumOperands();
- unsigned numOutputOperands = firstResults.size();
- op->setOperands(numTotalOperands - numOutputOperands, numOutputOperands,
- firstResults);
- });
+ unsigned numTotalOperands = op->getNumOperands();
+ unsigned numOutputOperands = firstResults.size();
+ rewriter.setOperands(op, numTotalOperands - numOutputOperands,
+ numOutputOperands, firstResults);
// Create the second part.
OpFoldResult totalOffset = affine::makeComposedFoldedAffineApply(
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 76467154e869f..04fecc80d920f 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1582,10 +1582,7 @@ struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
rewriter.replaceAllUsesWith(blockArg, out);
rewriter.replaceAllUsesWith(result, out);
}
- // TODO: There is no rewriter API for erasing block arguments.
- rewriter.modifyOpInPlace(forallOp, [&]() {
- forallOp.getBody()->eraseArguments(blockIndicesToDelete);
- });
+ rewriter.eraseBlockArguments(forallOp.getBody(), blockIndicesToDelete);
// Step 3. Create a new scf.forall op with only the shared_outs/results
// that should be retained.
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
index ddcbda86cf1f3..070c638fec69d 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -92,8 +92,7 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
SmallVector<Value> yieldOperands = yieldOp.getOperands();
yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult());
- rewriter.modifyOpInPlace(yieldOp,
- [&]() { yieldOp->setOperands(yieldOperands); });
+ rewriter.setOperands(yieldOp, yieldOperands);
}
// We cannot do a direct replacement of the forOp since the while op returns
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 9468927021495..c40b0da40078e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -233,8 +233,7 @@ class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
LogicalResult
matchAndRewrite(ConditionOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.modifyOpInPlace(
- op, [&]() { op->setOperands(flattenValues(adaptor.getOperands())); });
+ rewriter.setOperands(op, flattenValues(adaptor.getOperands()));
return success();
}
};
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
index a517ca946f3a4..1e91404f45a09 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
@@ -91,8 +91,7 @@ class SPIRVPassThroughConversion : public OpConversionPattern<OpT> {
LogicalResult
matchAndRewrite(OpT op, typename OpT::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.modifyOpInPlace(op,
- [&] { op->setOperands(adaptor.getOperands()); });
+ rewriter.setOperands(op, adaptor.getOperands());
return success();
}
};
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 2c9e9c040d460..08dff0c9d7c4c 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1181,9 +1181,7 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
continue;
if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
- rewriter.modifyOpInPlace(&curOp, [&] {
- curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
- });
+ rewriter.setOperand(&curOp, 0, newFuncOp.getArgument(unrolledInputNo));
++unrolledInputIdx;
}
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 0fc5cc76de39c..39ef597b19075 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -571,9 +571,7 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
rewriter.setInsertionPoint(linalgOp);
RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
Value dst = ConvertOp::create(rewriter, tval.getLoc(), dstTp, tval);
- rewriter.modifyOpInPlace(linalgOp, [&]() {
- linalgOp->setOperand(t->getOperandNumber(), dst);
- });
+ rewriter.setOperand(linalgOp, t->getOperandNumber(), dst);
// Release the transposed form afterwards.
// TODO: CSE when used in more than one following op?
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 89ed468d2e1b9..e8248df120f68 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1074,7 +1074,7 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
auto denseTp =
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
auto convert = ConvertOp::create(rewriter, loc, denseTp, op.getSrc());
- rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); });
+ rewriter.setOperand(op, 0, convert);
return success();
}
if (encDst) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 6004ab26f4663..18e1301d22c74 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -616,10 +616,8 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
if (def->getBlock() == block) {
rewriter.setInsertionPoint(def);
for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
- rewriter.modifyOpInPlace(def, [&]() {
- def->setOperand(
- i, relinkBranch(env, rewriter, block, def->getOperand(i)));
- });
+ rewriter.setOperand(
+ def, i, relinkBranch(env, rewriter, block, def->getOperand(i)));
}
}
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 684c088eb9b0f..96af05c6d3d1b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -784,8 +784,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
Operation *newRed = rewriter.clone(*redExp);
// Replaces arguments of the reduction expression by using the block
// arguments from scf.reduce.
- rewriter.modifyOpInPlace(
- newRed, [&]() { newRed->setOperands(redBlock->getArguments()); });
+ rewriter.setOperands(newRed, redBlock->getArguments());
// Erases the out-dated reduction expression.
rewriter.eraseOp(redExp);
rewriter.setInsertionPointToEnd(redBlock);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index ce0f8540d884a..2039665c4aab7 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -821,9 +821,7 @@ struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
auto castOp =
CastOp::create(rewriter, concatOp->getLoc(), inferredOperandType,
concatOp.getOperand(operandIdx));
- rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
- concatOp->setOperand(operandIdx, castOp->getResult(0));
- });
+ rewriter.setOperand(concatOp, operandIdx, castOp->getResult(0));
}
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index b622cbedec1dc..b5e84c37809d9 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -310,10 +310,7 @@ LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
if (!notOp)
return failure();
- rewriter.modifyOpInPlace(op, [&]() {
- op.getOperation()->setOperands(
- {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
- });
+ rewriter.setOperands(op, {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
return success();
}
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index cd067f2cc25b3..de0cb5cb44c5f 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -244,6 +244,38 @@ void RewriterBase::eraseBlock(Block *block) {
block->erase();
}
+void RewriterBase::eraseBlockArgument(Block *block, unsigned index) {
+  eraseBlockArguments(block, /*start=*/index, /*num=*/1); // Erase one arg.
+}
+
+/// Runs `mutate` on the arguments of `block`, reporting the change as an
+/// in-place modification of the block's parent operation when it has one.
+static void modifyBlockArgsInPlace(RewriterBase &rewriter, Block *block,
+                                   function_ref<void()> mutate) {
+  if (Operation *parentOp = block->getParentOp())
+    rewriter.modifyOpInPlace(parentOp, mutate);
+  else
+    mutate();
+}
+
+void RewriterBase::eraseBlockArguments(Block *block, unsigned start,
+                                       unsigned num) {
+  modifyBlockArgsInPlace(*this, block,
+                         [&]() { block->eraseArguments(start, num); });
+}
+
+void RewriterBase::eraseBlockArguments(Block *block,
+                                       const BitVector &eraseIndices) {
+  modifyBlockArgsInPlace(*this, block,
+                         [&]() { block->eraseArguments(eraseIndices); });
+}
+
+void RewriterBase::eraseBlockArguments(
+    Block *block, function_ref<bool(BlockArgument)> shouldEraseFn) {
+  modifyBlockArgsInPlace(*this, block,
+                         [&]() { block->eraseArguments(shouldEraseFn); });
+}
+
Operation *RewriterBase::eraseOpResults(Operation *op,
const BitVector &eraseIndices) {
assert(op->getNumResults() == eraseIndices.size() &&
@@ -280,6 +312,20 @@ Operation *RewriterBase::eraseOpResults(Operation *op,
return newOp;
}
+void RewriterBase::setOperands(Operation *op, ValueRange operands) {
+  modifyOpInPlace(op, [op, operands]() { op->setOperands(operands); });
+}
+
+void RewriterBase::setOperands(Operation *op, unsigned start, unsigned length,
+                               ValueRange operands) {
+  // Forward to `Operation::setOperands` inside modifyOpInPlace to notify.
+  modifyOpInPlace(op, [&]() { op->setOperands(start, length, operands); });
+}
+
+void RewriterBase::setOperand(Operation *op, unsigned index, Value value) {
+  modifyOpInPlace(op, [op, index, value]() { op->setOperand(index, value); });
+}
+
void RewriterBase::finalizeOpModification(Operation *op) {
// Notify the listener that the operation was modified.
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index c3fb73acf5ef0..6a58edeafe1e8 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -930,8 +930,7 @@ struct RemoveDeadRegionBranchOpSuccessorInputs : public RewritePattern {
for (auto &pair : blockArgsToRemove) {
Block *block = pair.first;
BitVector &blockArg = pair.second;
- rewriter.modifyOpInPlace(block->getParentOp(),
- [&]() { block->eraseArguments(blockArg); });
+ rewriter.eraseBlockArguments(block, blockArg);
}
// Erase op results.
diff --git a/mlir/lib/Transforms/Utils/CommutativityUtils.cpp b/mlir/lib/Transforms/Utils/CommutativityUtils.cpp
index 8b132b5e484bb..bc20742171fcb 100644
--- a/mlir/lib/Transforms/Utils/CommutativityUtils.cpp
+++ b/mlir/lib/Transforms/Utils/CommutativityUtils.cpp
@@ -303,7 +303,7 @@ class SortCommutativeOperands : public RewritePattern {
sortedOperands.push_back(commOperand->operand);
if (sortedOperands == operands)
return failure();
- rewriter.modifyOpInPlace(op, [&] { op->setOperands(sortedOperands); });
+ rewriter.setOperands(op, sortedOperands);
return success();
}
};
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 6c44ace831e96..235eed256bec4 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -2035,8 +2035,7 @@ struct TestTypeConsumerForward
LogicalResult
matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
- rewriter.modifyOpInPlace(op,
- [&] { op->setOperands(adaptor.getOperands()); });
+ rewriter.setOperands(op, adaptor.getOperands());
return success();
}
};
More information about the Mlir-commits
mailing list