[Mlir-commits] [flang] [mlir] [mlir][IR] Rename "update root" to "modify op" in rewriter API (PR #78260)
Matthias Springer
llvmlistbot at llvm.org
Tue Jan 16 03:46:45 PST 2024
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/78260
This commit renames 4 pattern rewriter API functions:
* `updateRootInPlace` -> `modifyOpInPlace`
* `startRootUpdate` -> `startOpModification`
* `finalizeRootUpdate` -> `finalizeOpModification`
* `cancelRootUpdate` -> `cancelOpModification`
The term "root" is a misnomer. The root is the op that a rewrite pattern matches against (https://mlir.llvm.org/docs/PatternRewriter/#root-operation-name-optional). A rewriter must be notified of all in-place op modifications, not just in-place modifications of the root (https://mlir.llvm.org/docs/PatternRewriter/#pattern-rewriter). The old function names were confusing and have contributed to various broken rewrite patterns.
Note: The new function names use the term "modify" instead of "update" for consistency with the `RewriterBase::Listener` terminology (`notifyOperationModified`).
>From e525c8360c034125a787870c16b9d33e24224428 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Tue, 16 Jan 2024 11:39:05 +0000
Subject: [PATCH] [mlir][IR] Rename "update root" to "modify op" in rewriter
API
This commit renames 4 pattern rewriter API functions:
* `updateRootInPlace` -> `modifyOpInPlace`
* `startRootUpdate` -> `startOpModification`
* `finalizeRootUpdate` -> `finalizeOpModification`
* `cancelRootUpdate` -> `cancelOpModification`
The term "root" is a misnomer. The root is the op that a rewrite pattern matches against (https://mlir.llvm.org/docs/PatternRewriter/#root-operation-name-optional). A rewriter must be notified of all in-place op modifications, not just in-place modifications of the root (https://mlir.llvm.org/docs/PatternRewriter/#pattern-rewriter). The old function names were confusing and have contributed to various broken rewrite patterns.
Note: The new function names use the term "modify" instead of "update" for consistency with the `RewriterBase::Listener` terminology (`notifyOperationModified`).
---
.../lib/Optimizer/CodeGen/BoxedProcedure.cpp | 20 ++++-----
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 8 ++--
.../HLFIR/Transforms/BufferizeHLFIR.cpp | 8 ++--
.../Optimizer/Transforms/AffineDemotion.cpp | 4 +-
.../Optimizer/Transforms/AffinePromotion.cpp | 12 ++---
.../Transforms/ExternalNameConversion.cpp | 8 ++--
mlir/docs/PatternRewriter.md | 10 ++---
.../lib/Standalone/StandalonePasses.cpp | 2 +-
.../toy/Ch5/mlir/LowerToAffineLoops.cpp | 4 +-
.../toy/Ch6/mlir/LowerToAffineLoops.cpp | 4 +-
.../toy/Ch7/mlir/LowerToAffineLoops.cpp | 4 +-
mlir/include/mlir/IR/PatternMatch.h | 44 ++++++++++---------
.../mlir/Transforms/DialectConversion.h | 14 +++---
.../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 2 +-
.../GPUCommon/GPUToLLVMConversion.cpp | 3 +-
.../Conversion/OpenACCToSCF/OpenACCToSCF.cpp | 5 +--
.../Conversion/VectorToSCF/VectorToSCF.cpp | 6 +--
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 6 +--
.../Affine/Transforms/DecomposeAffineOps.cpp | 4 +-
.../ArmSME/Transforms/TileAllocation.cpp | 4 +-
.../Transforms/LegalizeForLLVMExport.cpp | 4 +-
.../IR/BufferizableOpInterface.cpp | 4 +-
.../Bufferization/IR/BufferizationOps.cpp | 2 +-
.../BufferDeallocationSimplification.cpp | 2 +-
.../Dialect/ControlFlow/IR/ControlFlowOps.cpp | 8 ++--
.../Transforms/DecomposeCallGraphTypes.cpp | 2 +-
.../Func/Transforms/FuncConversions.cpp | 6 +--
.../Func/Transforms/OneToNFuncConversions.cpp | 4 +-
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 2 +-
mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 4 +-
.../LLVMIR/Transforms/TypeConsistency.cpp | 14 +++---
.../TransformOps/LinalgTransformOps.cpp | 4 +-
.../Transforms/ConvertToDestinationStyle.cpp | 8 ++--
.../Dialect/Linalg/Transforms/Detensorize.cpp | 6 +--
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 6 +--
.../Transforms/EliminateEmptyTensors.cpp | 2 +-
.../EraseUnusedOperandsAndResults.cpp | 4 +-
.../Linalg/Transforms/HoistPadding.cpp | 4 +-
.../Dialect/Linalg/Transforms/Interchange.cpp | 6 +--
mlir/lib/Dialect/Linalg/Transforms/Split.cpp | 2 +-
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 2 +-
.../Linalg/Transforms/Vectorization.cpp | 2 +-
.../Dialect/MemRef/IR/MemRefMemorySlot.cpp | 4 +-
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 8 ++--
.../Transforms/ExpandStridedMetadata.cpp | 2 +-
.../Transforms/IndependenceTransforms.cpp | 2 +-
.../Dialect/MemRef/Transforms/MultiBuffer.cpp | 4 +-
.../NVGPU/Transforms/MmaSyncTF32Transform.cpp | 2 +-
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 4 +-
mlir/lib/Dialect/SCF/IR/SCF.cpp | 28 ++++++------
.../BufferizableOpInterfaceImpl.cpp | 4 +-
.../lib/Dialect/SCF/Transforms/ForToWhile.cpp | 4 +-
.../SCF/Transforms/LoopCanonicalization.cpp | 4 +-
.../SCF/Transforms/LoopSpecialization.cpp | 10 ++---
.../SCF/Transforms/OneToNTypeConversion.cpp | 4 +-
.../Transforms/StructuralTypeConversions.cpp | 2 +-
.../SCF/Transforms/TileUsingInterface.cpp | 2 +-
.../DecorateCompositeTypeLayoutPass.cpp | 4 +-
.../Transforms/LowerABIAttributesPass.cpp | 2 +-
.../IR/SparseTensorInterfaces.cpp | 2 +-
.../Transforms/SparseReinterpretMap.cpp | 20 ++++-----
.../Transforms/SparseTensorRewriting.cpp | 6 +--
.../Transforms/SparseVectorization.cpp | 2 +-
.../Transforms/Sparsification.cpp | 2 +-
.../Transforms/Utils/LoopEmitter.cpp | 2 +-
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 14 +++---
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 2 +-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 +-
.../BufferizableOpInterfaceImpl.cpp | 2 +-
.../Vector/Transforms/VectorDistribute.cpp | 14 +++---
.../VectorTransferSplitRewritePatterns.cpp | 4 +-
.../Vector/Transforms/VectorTransforms.cpp | 2 +-
mlir/lib/IR/PatternMatch.cpp | 4 +-
mlir/lib/Transforms/Mem2Reg.cpp | 4 +-
.../Transforms/Utils/CommutativityUtils.cpp | 2 +-
.../Transforms/Utils/DialectConversion.cpp | 10 ++---
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 30 ++++++-------
77 files changed, 243 insertions(+), 243 deletions(-)
diff --git a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
index 24cf2f39fc9a09..7d73af4d7103dc 100644
--- a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
+++ b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
@@ -215,14 +215,14 @@ class BoxedProcedurePass
rewriter.replaceOpWithNewOp<ConvertOp>(
addr, typeConverter.convertType(addr.getType()), addr.getVal());
} else if (typeConverter.needsConversion(resTy)) {
- rewriter.startRootUpdate(op);
+ rewriter.startOpModification(op);
op->getResult(0).setType(typeConverter.convertType(resTy));
- rewriter.finalizeRootUpdate(op);
+ rewriter.finalizeOpModification(op);
}
} else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
mlir::FunctionType ty = func.getFunctionType();
if (typeConverter.needsConversion(ty)) {
- rewriter.startRootUpdate(func);
+ rewriter.startOpModification(func);
auto toTy =
typeConverter.convertType(ty).cast<mlir::FunctionType>();
if (!func.empty())
@@ -235,7 +235,7 @@ class BoxedProcedurePass
block.eraseArgument(i + 1);
}
func.setType(toTy);
- rewriter.finalizeRootUpdate(func);
+ rewriter.finalizeOpModification(func);
}
} else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) {
// Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
@@ -273,10 +273,10 @@ class BoxedProcedurePass
} else if (auto global = mlir::dyn_cast<GlobalOp>(op)) {
auto ty = global.getType();
if (typeConverter.needsConversion(ty)) {
- rewriter.startRootUpdate(global);
+ rewriter.startOpModification(global);
auto toTy = typeConverter.convertType(ty);
global.setType(toTy);
- rewriter.finalizeRootUpdate(global);
+ rewriter.finalizeOpModification(global);
}
} else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) {
auto ty = mem.getType();
@@ -339,17 +339,17 @@ class BoxedProcedurePass
mem, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
}
} else if (op->getDialect() == firDialect) {
- rewriter.startRootUpdate(op);
+ rewriter.startOpModification(op);
for (auto i : llvm::enumerate(op->getResultTypes()))
if (typeConverter.needsConversion(i.value())) {
auto toTy = typeConverter.convertType(i.value());
op->getResult(i.index()).setType(toTy);
}
- rewriter.finalizeRootUpdate(op);
+ rewriter.finalizeOpModification(op);
}
// Ensure block arguments are updated if needed.
if (op->getNumRegions() != 0) {
- rewriter.startRootUpdate(op);
+ rewriter.startOpModification(op);
for (mlir::Region &region : op->getRegions())
for (mlir::Block &block : region.getBlocks())
for (mlir::BlockArgument blockArg : block.getArguments())
@@ -358,7 +358,7 @@ class BoxedProcedurePass
typeConverter.convertType(blockArg.getType());
blockArg.setType(toTy);
}
- rewriter.finalizeRootUpdate(op);
+ rewriter.finalizeOpModification(op);
}
});
}
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index e07732d57880c5..f2c731d47909a9 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3763,13 +3763,13 @@ class RenameMSVCLibmCallees
mlir::LogicalResult
matchAndRewrite(mlir::LLVM::CallOp op,
mlir::PatternRewriter &rewriter) const override {
- rewriter.startRootUpdate(op);
+ rewriter.startOpModification(op);
auto callee = op.getCallee();
if (callee)
if (callee->equals("hypotf"))
op.setCalleeAttr(mlir::SymbolRefAttr::get(op.getContext(), "_hypotf"));
- rewriter.finalizeRootUpdate(op);
+ rewriter.finalizeOpModification(op);
return mlir::success();
}
};
@@ -3782,10 +3782,10 @@ class RenameMSVCLibmFuncs
mlir::LogicalResult
matchAndRewrite(mlir::LLVM::LLVMFuncOp op,
mlir::PatternRewriter &rewriter) const override {
- rewriter.startRootUpdate(op);
+ rewriter.startOpModification(op);
if (op.getSymName().equals("hypotf"))
op.setSymNameAttr(rewriter.getStringAttr("_hypotf"));
- rewriter.finalizeRootUpdate(op);
+ rewriter.finalizeOpModification(op);
return mlir::success();
}
};
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
index 97127f57cc3eb9..641854bd201f0b 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
@@ -256,9 +256,9 @@ struct AssignOpConversion : public mlir::OpConversionPattern<hlfir::AssignOp> {
llvm::SmallVector<mlir::Value> newOperands;
for (mlir::Value operand : adaptor.getOperands())
newOperands.push_back(getBufferizedExprStorage(operand));
- rewriter.startRootUpdate(assign);
+ rewriter.startOpModification(assign);
assign->setOperands(newOperands);
- rewriter.finalizeRootUpdate(assign);
+ rewriter.finalizeOpModification(assign);
return mlir::success();
}
};
@@ -834,9 +834,9 @@ struct ElementalOpConversion
// Explicitly delete the body of the elemental to get rid
// of any users of hlfir.expr values inside the body as early
// as possible.
- rewriter.startRootUpdate(elemental);
+ rewriter.startOpModification(elemental);
rewriter.eraseBlock(elemental.getBody());
- rewriter.finalizeRootUpdate(elemental);
+ rewriter.finalizeOpModification(elemental);
rewriter.replaceOp(elemental, bufferizedExpr);
return mlir::success();
}
diff --git a/flang/lib/Optimizer/Transforms/AffineDemotion.cpp b/flang/lib/Optimizer/Transforms/AffineDemotion.cpp
index 0c256deeca4161..da29ae880700e6 100644
--- a/flang/lib/Optimizer/Transforms/AffineDemotion.cpp
+++ b/flang/lib/Optimizer/Transforms/AffineDemotion.cpp
@@ -114,9 +114,9 @@ class ConvertConversion : public mlir::OpRewritePattern<fir::ConvertOp> {
op.getValue());
return success();
}
- rewriter.startRootUpdate(op->getParentOp());
+ rewriter.startOpModification(op->getParentOp());
op.getResult().replaceAllUsesWith(op.getValue());
- rewriter.finalizeRootUpdate(op->getParentOp());
+ rewriter.finalizeOpModification(op->getParentOp());
rewriter.eraseOp(op);
}
return success();
diff --git a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp
index af2200f6a7b02d..d1831cf1c200cc 100644
--- a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp
+++ b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp
@@ -464,15 +464,15 @@ class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
auto affineFor = loopAndIndex.first;
auto inductionVar = loopAndIndex.second;
- rewriter.startRootUpdate(affineFor.getOperation());
+ rewriter.startOpModification(affineFor.getOperation());
affineFor.getBody()->getOperations().splice(
std::prev(affineFor.getBody()->end()), loopOps, loopOps.begin(),
std::prev(loopOps.end()));
- rewriter.finalizeRootUpdate(affineFor.getOperation());
+ rewriter.finalizeOpModification(affineFor.getOperation());
- rewriter.startRootUpdate(loop.getOperation());
+ rewriter.startOpModification(loop.getOperation());
loop.getInductionVar().replaceAllUsesWith(inductionVar);
- rewriter.finalizeRootUpdate(loop.getOperation());
+ rewriter.finalizeOpModification(loop.getOperation());
rewriteMemoryOps(affineFor.getBody(), rewriter);
@@ -561,7 +561,7 @@ class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> {
auto affineIf = rewriter.create<affine::AffineIfOp>(
op.getLoc(), affineCondition.getIntegerSet(),
affineCondition.getAffineArgs(), !op.getElseRegion().empty());
- rewriter.startRootUpdate(affineIf);
+ rewriter.startOpModification(affineIf);
affineIf.getThenBlock()->getOperations().splice(
std::prev(affineIf.getThenBlock()->end()), ifOps, ifOps.begin(),
std::prev(ifOps.end()));
@@ -571,7 +571,7 @@ class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> {
std::prev(affineIf.getElseBlock()->end()), otherOps, otherOps.begin(),
std::prev(otherOps.end()));
}
- rewriter.finalizeRootUpdate(affineIf);
+ rewriter.finalizeOpModification(affineIf);
rewriteMemoryOps(affineIf.getBody(), rewriter);
LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: if converted to:\n";
diff --git a/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp b/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp
index 221e93ff85e18e..bc5be3f196b81a 100644
--- a/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp
@@ -76,7 +76,7 @@ struct MangleNameOnFuncOp : public mlir::OpRewritePattern<mlir::func::FuncOp> {
matchAndRewrite(mlir::func::FuncOp op,
mlir::PatternRewriter &rewriter) const override {
mlir::LogicalResult ret = success();
- rewriter.startRootUpdate(op);
+ rewriter.startOpModification(op);
llvm::StringRef oldName = op.getSymName();
auto result = fir::NameUniquer::deconstruct(oldName);
if (fir::NameUniquer::isExternalFacingUniquedName(result)) {
@@ -95,7 +95,7 @@ struct MangleNameOnFuncOp : public mlir::OpRewritePattern<mlir::func::FuncOp> {
}
updateEarlyOutliningParentName(op, appendUnderscore);
- rewriter.finalizeRootUpdate(op);
+ rewriter.finalizeOpModification(op);
return ret;
}
@@ -114,7 +114,7 @@ struct MangleNameForCommonBlock : public mlir::OpRewritePattern<fir::GlobalOp> {
mlir::LogicalResult
matchAndRewrite(fir::GlobalOp op,
mlir::PatternRewriter &rewriter) const override {
- rewriter.startRootUpdate(op);
+ rewriter.startOpModification(op);
auto result = fir::NameUniquer::deconstruct(
op.getSymref().getRootReference().getValue());
if (fir::NameUniquer::isExternalFacingUniquedName(result)) {
@@ -122,7 +122,7 @@ struct MangleNameForCommonBlock : public mlir::OpRewritePattern<fir::GlobalOp> {
op.setSymrefAttr(mlir::SymbolRefAttr::get(op.getContext(), newName));
SymbolTable::setSymbolName(op, newName);
}
- rewriter.finalizeRootUpdate(op);
+ rewriter.finalizeOpModification(op);
return success();
}
diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md
index 8fe5ef35a76039..011cd14175634b 100644
--- a/mlir/docs/PatternRewriter.md
+++ b/mlir/docs/PatternRewriter.md
@@ -213,15 +213,15 @@ user is determined by the specific pattern driver.
This method replaces an operation's results with a set of provided values, and
erases the operation.
-* Update an Operation in-place : `(start|cancel|finalize)RootUpdate`
+* Update an Operation in-place : `(start|cancel|finalize)OpModification`
This is a collection of methods that provide a transaction-like API for updating
the attributes, location, operands, or successors of an operation in-place
within a pattern. An in-place update transaction is started with
-`startRootUpdate`, and may either be canceled or finalized with
-`cancelRootUpdate` and `finalizeRootUpdate` respectively. A convenience wrapper,
-`updateRootInPlace`, is provided that wraps a `start` and `finalize` around a
-callback.
+`startOpModification`, and may either be canceled or finalized with
+`cancelOpModification` and `finalizeOpModification` respectively. A convenience
+wrapper, `modifyOpInPlace`, is provided that wraps a `start` and `finalize`
+around a callback.
* OpBuilder API
diff --git a/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp b/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp
index d438cb46ecdada..a23d0420f04350 100644
--- a/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp
+++ b/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp
@@ -24,7 +24,7 @@ class StandaloneSwitchBarFooRewriter : public OpRewritePattern<func::FuncOp> {
LogicalResult matchAndRewrite(func::FuncOp op,
PatternRewriter &rewriter) const final {
if (op.getSymName() == "bar") {
- rewriter.updateRootInPlace(op, [&op]() { op.setSymName("foo"); });
+ rewriter.modifyOpInPlace(op, [&op]() { op.setSymName("foo"); });
return success();
}
return failure();
diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
index 240b9f9338665a..ae4bd980c34b53 100644
--- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
@@ -260,8 +260,8 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
ConversionPatternRewriter &rewriter) const final {
// We don't lower "toy.print" in this pass, but we need to update its
// operands.
- rewriter.updateRootInPlace(op,
- [&] { op->setOperands(adaptor.getOperands()); });
+ rewriter.modifyOpInPlace(op,
+ [&] { op->setOperands(adaptor.getOperands()); });
return success();
}
};
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
index 240b9f9338665a..ae4bd980c34b53 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
@@ -260,8 +260,8 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
ConversionPatternRewriter &rewriter) const final {
// We don't lower "toy.print" in this pass, but we need to update its
// operands.
- rewriter.updateRootInPlace(op,
- [&] { op->setOperands(adaptor.getOperands()); });
+ rewriter.modifyOpInPlace(op,
+ [&] { op->setOperands(adaptor.getOperands()); });
return success();
}
};
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
index 240b9f9338665a..ae4bd980c34b53 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -260,8 +260,8 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
ConversionPatternRewriter &rewriter) const final {
// We don't lower "toy.print" in this pass, but we need to update its
// operands.
- rewriter.updateRootInPlace(op,
- [&] { op->setOperands(adaptor.getOperands()); });
+ rewriter.modifyOpInPlace(op,
+ [&] { op->setOperands(adaptor.getOperands()); });
return success();
}
};
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 9b4fa65bff49e1..b065d4e8d37689 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -585,28 +585,30 @@ class RewriterBase : public OpBuilder {
/// This method is used to notify the rewriter that an in-place operation
/// modification is about to happen. A call to this function *must* be
- /// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
- /// This is a minor efficiency win (it avoids creating a new operation and
- /// removing the old one) but also often allows simpler code in the client.
- virtual void startRootUpdate(Operation *op) {}
-
- /// This method is used to signal the end of a root update on the given
- /// operation. This can only be called on operations that were provided to a
- /// call to `startRootUpdate`.
- virtual void finalizeRootUpdate(Operation *op);
-
- /// This method cancels a pending root update. This can only be called on
- /// operations that were provided to a call to `startRootUpdate`.
- virtual void cancelRootUpdate(Operation *op) {}
-
- /// This method is a utility wrapper around a root update of an operation. It
- /// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
- /// callable.
+ /// followed by a call to either `finalizeOpModification` or
+ /// `cancelOpModification`. This is a minor efficiency win (it avoids creating
+ /// a new operation and removing the old one) but also often allows simpler
+ /// code in the client.
+ virtual void startOpModification(Operation *op) {}
+
+ /// This method is used to signal the end of an in-place modification of the
+ /// given operation. This can only be called on operations that were provided
+ /// to a call to `startOpModification`.
+ virtual void finalizeOpModification(Operation *op);
+
+ /// This method cancels a pending in-place modification. This can only be
+ /// called on operations that were provided to a call to
+ /// `startOpModification`.
+ virtual void cancelOpModification(Operation *op) {}
+
+ /// This method is a utility wrapper around an in-place modification of an
+ /// operation. It wraps calls to `startOpModification` and
+ /// `finalizeOpModification` around the given callable.
template <typename CallableT>
- void updateRootInPlace(Operation *root, CallableT &&callable) {
- startRootUpdate(root);
+ void modifyOpInPlace(Operation *root, CallableT &&callable) {
+ startOpModification(root);
callable();
- finalizeRootUpdate(root);
+ finalizeOpModification(root);
}
/// Find uses of `from` and replace them with `to`. It also marks every
@@ -619,7 +621,7 @@ class RewriterBase : public OpBuilder {
void replaceAllUsesWith(IRObjectWithUseList<OperandType> *from, ValueT &&to) {
for (OperandType &operand : llvm::make_early_inc_range(from->getUses())) {
Operation *op = operand.getOwner();
- updateRootInPlace(op, [&]() { operand.set(to); });
+ modifyOpInPlace(op, [&]() { operand.set(to); });
}
}
void replaceAllUsesWith(ValueRange from, ValueRange to) {
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index c5725e9c856256..9568540789df3f 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -739,17 +739,17 @@ class ConversionPatternRewriter final : public PatternRewriter,
/// PatternRewriter hook for inserting a new operation.
void notifyOperationInserted(Operation *op) override;
- /// PatternRewriter hook for updating the root operation in-place.
- /// Note: These methods only track updates to the top-level operation itself,
+ /// PatternRewriter hook for updating the given operation in-place.
+ /// Note: These methods only track updates to the given operation itself,
/// and not nested regions. Updates to regions will still require notification
/// through other more specific hooks above.
- void startRootUpdate(Operation *op) override;
+ void startOpModification(Operation *op) override;
- /// PatternRewriter hook for updating the root operation in-place.
- void finalizeRootUpdate(Operation *op) override;
+ /// PatternRewriter hook for updating the given operation in-place.
+ void finalizeOpModification(Operation *op) override;
- /// PatternRewriter hook for updating the root operation in-place.
- void cancelRootUpdate(Operation *op) override;
+ /// PatternRewriter hook for updating the given operation in-place.
+ void cancelOpModification(Operation *op) override;
/// PatternRewriter hook for notifying match failure reasons.
LogicalResult
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 16214d72fcddc2..bbef3b996e40b8 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -255,7 +255,7 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
// Step 2. Assign the op a real tile ID.
// For simplicity, we always use tile 0 (which always exists).
auto zeroTileId = rewriter.getI32IntegerAttr(0);
- rewriter.updateRootInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });
+ rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });
VectorType tileVectorType = tileOp.getTileType();
auto sliceType = VectorType::Builder(tileVectorType).dropDim(0);
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 94df3765a67e74..f853d5c47b623c 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -918,8 +918,7 @@ LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
for (auto stream : streams)
streamDestroyCallBuilder.create(loc, rewriter, {stream});
- rewriter.updateRootInPlace(yieldOp,
- [&] { yieldOp->setOperands(newOperands); });
+ rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
return success();
}
diff --git a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp
index 8c1a7d9c6b2a43..54e6bec12b897c 100644
--- a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp
+++ b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp
@@ -43,14 +43,13 @@ class ExpandIfCondition : public OpRewritePattern<OpTy> {
if (!matchPattern(op.getIfCond(), m_Constant(&constAttr))) {
auto ifOp = rewriter.create<scf::IfOp>(op.getLoc(), TypeRange(),
op.getIfCond(), false);
- rewriter.updateRootInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
+ rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
auto thenBodyBuilder = ifOp.getThenBodyBuilder(rewriter.getListener());
thenBodyBuilder.clone(*op.getOperation());
rewriter.eraseOp(op);
} else {
if (constAttr.getInt())
- rewriter.updateRootInPlace(op,
- [&]() { op.getIfCondMutable().erase(0); });
+ rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
else
rewriter.eraseOp(op);
}
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 44fbac1935fed7..f8485e02a2208e 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -645,13 +645,13 @@ struct PrepareTransferWriteConversion
rewriter.create<memref::StoreOp>(loc, xferOp.getVector(),
buffers.dataBuffer);
auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
- rewriter.updateRootInPlace(xferOp, [&]() {
+ rewriter.modifyOpInPlace(xferOp, [&]() {
xferOp.getVectorMutable().assign(loadedVec);
xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
});
if (xferOp.getMask()) {
- rewriter.updateRootInPlace(xferOp, [&]() {
+ rewriter.modifyOpInPlace(xferOp, [&]() {
xferOp.getMaskMutable().assign(buffers.maskBuffer);
});
}
@@ -966,7 +966,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
loadIndices, iv);
auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
loadIndices);
- rewriter.updateRootInPlace(newXfer, [&]() {
+ rewriter.modifyOpInPlace(newXfer, [&]() {
newXfer.getMaskMutable().assign(mask);
});
}
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index d5be2e906989fa..c260e68d509e98 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2493,7 +2493,7 @@ FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
newYieldValuesFn(rewriter, getLoc(), newIterArgs);
assert(newInitOperands.size() == newYieldedValues.size() &&
"expected as many new yield values as new iter operands");
- rewriter.updateRootInPlace(yieldOp, [&]() {
+ rewriter.modifyOpInPlace(yieldOp, [&]() {
yieldOp.getOperandsMutable().append(newYieldedValues);
});
}
@@ -2686,9 +2686,9 @@ struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
!llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
return failure();
- rewriter.startRootUpdate(ifOp);
+ rewriter.startOpModification(ifOp);
rewriter.eraseBlock(ifOp.getElseBlock());
- rewriter.finalizeRootUpdate(ifOp);
+ rewriter.finalizeOpModification(ifOp);
return success();
}
};
diff --git a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
index e5501e848c1646..f28fb3acb7db7f 100644
--- a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
@@ -71,10 +71,10 @@ void mlir::affine::reorderOperandsByHoistability(RewriterBase &rewriter,
op->getContext());
canonicalizeMapAndOperands(&map, &operands);
- rewriter.startRootUpdate(op);
+ rewriter.startOpModification(op);
op.setMap(map);
op->setOperands(operands);
- rewriter.finalizeRootUpdate(op);
+ rewriter.finalizeOpModification(op);
}
/// Build an affine.apply that is a subexpression `expr` of `originalOp`s affine
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index 49ea6bb5f8614e..6b224fab3d98ce 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -218,7 +218,7 @@ struct AssignTileIDsPattern
return defaultVal;
};
auto setDiscardableIntAttr = [&](StringRef name, auto value) {
- rewriter.updateRootInPlace(tileOp, [&] {
+ rewriter.modifyOpInPlace(tileOp, [&] {
func->setDiscardableAttr(name,
rewriter.getI32IntegerAttr((unsigned)value));
});
@@ -263,7 +263,7 @@ struct AssignTileIDsPattern
SetVector<Operation *> dependantOps;
findDependantOps(tileOp->getResult(0), dependantOps);
auto tileIDAttr = rewriter.getI32IntegerAttr(*tileId);
- rewriter.updateRootInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); });
+ rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); });
for (auto *op : dependantOps) {
if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
auto currentTileId = dependantTileOp.getTileId();
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 92278c0d74d574..32c87c1b824074 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -30,8 +30,8 @@ class ForwardOperands : public OpConversionPattern<OpTy> {
if (adaptor.getOperands().getTypes() == op->getOperands().getTypes())
return rewriter.notifyMatchFailure(op, "operand types already match");
- rewriter.updateRootInPlace(
- op, [&]() { op->setOperands(adaptor.getOperands()); });
+ rewriter.modifyOpInPlace(op,
+ [&]() { op->setOperands(adaptor.getOperands()); });
return success();
}
};
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index a0bb8715f2c561..4b1dfee4a2b926 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -253,7 +253,7 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
copiedOpOperands.contains(opOperand));
if (failed(copy))
return failure();
- rewriter.updateRootInPlace(op, [&]() { opOperand->set(*copy); });
+ rewriter.modifyOpInPlace(op, [&]() { opOperand->set(*copy); });
}
// Insert copies of Values.
@@ -274,7 +274,7 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
// dynamic extents. Do not update these either.
if (isa<tensor::DimOp>(use->getOwner()))
continue;
- rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(*copy); });
+ rewriter.modifyOpInPlace(use->getOwner(), [&]() { use->set(*copy); });
}
}
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 94bc2bcea63be9..253fcf2525121b 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -895,7 +895,7 @@ static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
deallocOp.getConditions() == conditions)
return failure();
- rewriter.updateRootInPlace(deallocOp, [&]() {
+ rewriter.modifyOpInPlace(deallocOp, [&]() {
deallocOp.getMemrefsMutable().assign(memrefs);
deallocOp.getConditionsMutable().assign(conditions);
});
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 42653517249d66..75d65193809f10 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -42,7 +42,7 @@ static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
deallocOp.getConditions() == conditions)
return failure();
- rewriter.updateRootInPlace(deallocOp, [&]() {
+ rewriter.modifyOpInPlace(deallocOp, [&]() {
deallocOp.getMemrefsMutable().assign(memrefs);
deallocOp.getConditionsMutable().assign(conditions);
});
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index 999c04e48ee168..d242d75bd51fa7 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -403,8 +403,8 @@ struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
constantTrue = rewriter.create<arith::ConstantOp>(
condbr.getLoc(), ty, rewriter.getBoolAttr(true));
- rewriter.updateRootInPlace(use.getOwner(),
- [&] { use.set(constantTrue); });
+ rewriter.modifyOpInPlace(use.getOwner(),
+ [&] { use.set(constantTrue); });
}
}
}
@@ -418,8 +418,8 @@ struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
constantFalse = rewriter.create<arith::ConstantOp>(
condbr.getLoc(), ty, rewriter.getBoolAttr(false));
- rewriter.updateRootInPlace(use.getOwner(),
- [&] { use.set(constantFalse); });
+ rewriter.modifyOpInPlace(use.getOwner(),
+ [&] { use.set(constantFalse); });
}
}
}
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
index 98ae826b6497fb..fa030cb18e035d 100644
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
@@ -86,7 +86,7 @@ struct DecomposeCallGraphTypesForFuncArgs
if (failed(typeConverter->convertTypes(functionType.getResults(),
newResultTypes)))
return failure();
- rewriter.updateRootInPlace(op, [&] {
+ rewriter.modifyOpInPlace(op, [&] {
op.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
newResultTypes));
});
diff --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
index 742830ec722f17..d1f3b56dbed738 100644
--- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
@@ -84,7 +84,7 @@ class BranchOpInterfaceTypeConversion
newOperands[idx] = operands[idx];
}
}
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
op, [newOperands, op]() { op->setOperands(newOperands); });
return success();
}
@@ -107,8 +107,8 @@ class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
ConversionPatternRewriter &rewriter) const final {
// For a return, all operands go to the results of the parent, so
// rewrite them all.
- rewriter.updateRootInPlace(op,
- [&] { op->setOperands(adaptor.getOperands()); });
+ rewriter.modifyOpInPlace(op,
+ [&] { op->setOperands(adaptor.getOperands()); });
return success();
}
};
diff --git a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
index 70056932411215..c04986cad84f9d 100644
--- a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
@@ -80,7 +80,7 @@ class ConvertTypesInFuncFuncOp : public OneToNOpConversionPattern<FuncOp> {
auto newType = FunctionType::get(rewriter.getContext(),
argumentMapping.getConvertedTypes(),
funcResultMapping.getConvertedTypes());
- rewriter.updateRootInPlace(op, [&] { op.setType(newType); });
+ rewriter.modifyOpInPlace(op, [&] { op.setType(newType); });
// Update block signatures.
if (!op.isExternal()) {
@@ -105,7 +105,7 @@ class ConvertTypesInFuncReturnOp : public OneToNOpConversionPattern<ReturnOp> {
return failure();
// Convert operands.
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
op, [&] { op->setOperands(adaptor.getFlatOperands()); });
return success();
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 514b3e9a6e8a56..30b6cd74147e6f 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2030,7 +2030,7 @@ struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
continue;
validOperands.push_back(operand);
}
- rewriter.updateRootInPlace(op, [&]() { op->setOperands(validOperands); });
+ rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
return success();
}
};
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 96a0ef591c1cfe..bf24194d03ddb2 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -301,7 +301,7 @@ DeletionKind LLVM::DbgValueOp::removeBlockingUses(
// the variable has been optimized out.
auto undef =
rewriter.create<UndefOp>(getValue().getLoc(), getValue().getType());
- rewriter.updateRootInPlace(*this, [&] { getValueMutable().assign(undef); });
+ rewriter.modifyOpInPlace(*this, [&] { getValueMutable().assign(undef); });
return DeletionKind::Keep;
}
@@ -394,7 +394,7 @@ DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
return DeletionKind::Delete;
}
- rewriter.updateRootInPlace(*this, [&]() {
+ rewriter.modifyOpInPlace(*this, [&]() {
// Rewire the indices by popping off the second index.
// Start with a single zero, then add the indices beyond the second.
SmallVector<int32_t> newIndices(1);
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
index cf900ac0be8fd2..72f9295749a66b 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
@@ -83,8 +83,8 @@ static void insertFieldIndirection(MemOp op, PatternRewriter &rewriter,
op->getLoc(), LLVM::LLVMPointerType::get(op.getContext()), elemType,
op.getAddr(), firstTypeIndices);
- rewriter.updateRootInPlace(op,
- [&]() { op.getAddrMutable().assign(properPtr); });
+ rewriter.modifyOpInPlace(op,
+ [&]() { op.getAddrMutable().assign(properPtr); });
}
template <>
@@ -111,8 +111,8 @@ LogicalResult AddFieldGetterToStructDirectUse<LoadOp>::matchAndRewrite(
rewriter.setInsertionPointAfterValue(load.getResult());
BitcastOp bitcast = rewriter.create<BitcastOp>(
load->getLoc(), load.getResult().getType(), load.getResult());
- rewriter.updateRootInPlace(load,
- [&]() { load.getResult().setType(firstType); });
+ rewriter.modifyOpInPlace(load,
+ [&]() { load.getResult().setType(firstType); });
rewriter.replaceAllUsesExcept(load.getResult(), bitcast.getResult(),
bitcast);
}
@@ -141,7 +141,7 @@ LogicalResult AddFieldGetterToStructDirectUse<StoreOp>::matchAndRewrite(
insertFieldIndirection<StoreOp>(store, rewriter, inconsistentElementType);
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
store, [&]() { store.getValueMutable().assign(store.getValue()); });
return success();
@@ -630,8 +630,8 @@ LogicalResult BitcastStores::matchAndRewrite(StoreOp store,
auto bitcastOp =
rewriter.create<BitcastOp>(store.getLoc(), typeHint, store.getValue());
- rewriter.updateRootInPlace(
- store, [&] { store.getValueMutable().assign(bitcastOp); });
+ rewriter.modifyOpInPlace(store,
+ [&] { store.getValueMutable().assign(bitcastOp); });
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 139566d350fe83..f7cfe8abddb2e8 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -785,7 +785,7 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
// Replace the use in containingOp.
- rewriter.updateRootInPlace(containingOp, [&]() {
+ rewriter.modifyOpInPlace(containingOp, [&]() {
containingOp->setOperand(pUse->getOperandNumber(),
destinationTensors.front());
});
@@ -835,7 +835,7 @@ static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(use->getOwner());
fusedOp = rewriter.clone(*producerOp);
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
return fusedOp;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index d8df5d82e28759..ff13aaf9b4abca 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -311,7 +311,7 @@ Value linalg::bufferizeToAllocation(
auto toTensorOp =
resultUse->get().getDefiningOp<bufferization::ToTensorOp>();
assert(toTensorOp && "expected to_tensor op");
- rewriter.updateRootInPlace(toTensorOp, [&]() {
+ rewriter.modifyOpInPlace(toTensorOp, [&]() {
toTensorOp.setRestrict(true);
toTensorOp.setWritable(true);
});
@@ -559,11 +559,11 @@ Value linalg::bufferizeToAllocation(
// tensor is uninitialized.
createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options);
}
- rewriter.updateRootInPlace(op, [&]() {
+ rewriter.modifyOpInPlace(op, [&]() {
auto toTensorOp = rewriter.create<ToTensorOp>(op->getLoc(), alloc);
operand->set(toTensorOp);
if (options.bufferizeDestinationOnly) {
- rewriter.updateRootInPlace(toTensorOp, [&]() {
+ rewriter.modifyOpInPlace(toTensorOp, [&]() {
toTensorOp.setRestrict(true);
toTensorOp.setWritable(true);
});
@@ -584,7 +584,7 @@ Value linalg::bufferizeToAllocation(
for (OpOperand *resultUse : resultUses) {
auto toTensorOp = resultUse->get().getDefiningOp<ToTensorOp>();
assert(toTensorOp && "expected to_tensor op");
- rewriter.updateRootInPlace(toTensorOp, [&]() {
+ rewriter.modifyOpInPlace(toTensorOp, [&]() {
toTensorOp.setRestrict(true);
toTensorOp.setWritable(true);
});
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index bf91a708ae1589..98cd0444760ece 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -104,7 +104,7 @@ struct FunctionNonEntryBlockConversion
LogicalResult
matchAndRewrite(FunctionOpInterface op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- rewriter.startRootUpdate(op);
+ rewriter.startOpModification(op);
Region ®ion = op.getFunctionBody();
SmallVector<TypeConverter::SignatureConversion, 2> conversions;
@@ -125,11 +125,11 @@ struct FunctionNonEntryBlockConversion
if (failed(rewriter.convertNonEntryRegionTypes(®ion, *typeConverter,
conversions))) {
- rewriter.cancelRootUpdate(op);
+ rewriter.cancelOpModification(op);
return failure();
}
- rewriter.finalizeRootUpdate(op);
+ rewriter.finalizeOpModification(op);
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 031f5c7a5d4783..e4cb2f223f3c7e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1816,7 +1816,7 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
LogicalResult matchAndRewrite(GenericOp op,
PatternRewriter &rewriter) const override {
- rewriter.startRootUpdate(op);
+ rewriter.startOpModification(op);
bool modifiedOutput = false;
Location loc = op.getLoc();
for (OpOperand &opOperand : op.getDpsInitsMutable()) {
@@ -1843,10 +1843,10 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
}
}
if (!modifiedOutput) {
- rewriter.cancelRootUpdate(op);
+ rewriter.cancelOpModification(op);
return failure();
}
- rewriter.finalizeRootUpdate(op);
+ rewriter.finalizeOpModification(op);
return success();
}
};
diff --git a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
index f28f8f0d34a4da..81669a1807796c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
@@ -87,7 +87,7 @@ LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
}
// Turn the "in" into an "out".
- rewriter.updateRootInPlace(op, [&]() {
+ rewriter.modifyOpInPlace(op, [&]() {
out->set(in->get());
// The original "in" could be removed entirely here (because it will no
// longer have any uses in the payload), but we delegate this to
diff --git a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
index 3378eda2bd6734..16ab45ea8bee63 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
@@ -354,7 +354,7 @@ struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
// Directly replace the cycle with the blockArg such that
// Deduplicate pattern can eliminate it along with unused yield.
rewriter.replaceOp(cycleOp, outputArg);
- rewriter.updateRootInPlace(genericOp, [] {});
+ rewriter.modifyOpInPlace(genericOp, [] {});
hasRemovedCycles = true;
}
@@ -404,7 +404,7 @@ struct FoldDuplicateInputBbArgs : public OpRewritePattern<GenericOp> {
return failure();
// Rewrite the op.
- rewriter.updateRootInPlace(genericOp, [&]() {
+ rewriter.modifyOpInPlace(genericOp, [&]() {
for (auto [before, after] : replacements) {
BlockArgument bbArg = genericOp.getBody()->getArgument(before);
BlockArgument replacement = genericOp.getBody()->getArgument(after);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 805c9d4ed3b79f..b32ea8eebaecb9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -854,10 +854,10 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
LLVM_DEBUG(DBGS() << "with result #"
<< numOriginalForOpResults + iterArgNumber
<< " of forOp, giving us: " << extracted << "\n");
- rewriter.startRootUpdate(extracted);
+ rewriter.startOpModification(extracted);
extracted.getSourceMutable().assign(
newForOp.getResult(numOriginalForOpResults + iterArgNumber));
- rewriter.finalizeRootUpdate(extracted);
+ rewriter.finalizeOpModification(extracted);
LLVM_DEBUG(DBGS() << "replace uses of: " << paddedValueBeforeHoisting
<< "\n");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
index f46ba71599b3fd..a0faeb524c57db 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
@@ -60,9 +60,9 @@ mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp,
assert(permutationMap && "unexpected null map");
// Start a guarded inplace update.
- rewriter.startRootUpdate(genericOp);
- auto guard =
- llvm::make_scope_exit([&]() { rewriter.finalizeRootUpdate(genericOp); });
+ rewriter.startOpModification(genericOp);
+ auto guard = llvm::make_scope_exit(
+ [&]() { rewriter.finalizeOpModification(genericOp); });
// 2. Compute the interchanged indexing maps.
SmallVector<AffineMap> newIndexingMaps;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
index bbe3a542f66b88..0174db45a83db2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
@@ -113,7 +113,7 @@ linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension,
// Need to pretend that the original op now takes as operands firstResults,
// otherwise tiling interface implementation will take the wrong value to
// produce data tiles.
- rewriter.updateRootInPlace(op, [&]() {
+ rewriter.modifyOpInPlace(op, [&]() {
unsigned numTotalOperands = op->getNumOperands();
unsigned numOutputOperands = firstResults.size();
op->setOperands(numTotalOperands - numOutputOperands, numOutputOperands,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 7f3ab1f1a24b2f..ebf80e3c5dc685 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -722,7 +722,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
// We cannot use a IRMapping here because it can replace
// different OpOperands with the same value.
Operation *clonedOp = b.clone(*op.getOperation());
- b.updateRootInPlace(clonedOp, [&]() {
+ b.modifyOpInPlace(clonedOp, [&]() {
for (auto [initOperandPtr, tiledInitValue] : llvm::zip_equal(
cast<DestinationStyleOpInterface>(clonedOp).getDpsInitsMutable(),
tiledDpsInitOperands)) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index dc348ea827cde1..0610f24ddaf471 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1952,7 +1952,7 @@ struct PadOpVectorizationWithTransferReadPattern
if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
return failure();
- rewriter.updateRootInPlace(xferOp, [&]() {
+ rewriter.modifyOpInPlace(xferOp, [&]() {
SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
xferOp->setAttr(xferOp.getInBoundsAttrName(),
rewriter.getBoolArrayAttr(inBounds));
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index be301c191d5139..561b8619032cce 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -227,7 +227,7 @@ DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot,
Attribute index = getAttributeIndexFromIndexOperands(
getContext(), getIndices(), getMemRefType());
const MemorySlot &memorySlot = subslots.at(index);
- rewriter.updateRootInPlace(*this, [&]() {
+ rewriter.modifyOpInPlace(*this, [&]() {
setMemRef(memorySlot.ptr);
getIndicesMutable().clear();
});
@@ -280,7 +280,7 @@ DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot,
Attribute index = getAttributeIndexFromIndexOperands(
getContext(), getIndices(), getMemRefType());
const MemorySlot &memorySlot = subslots.at(index);
- rewriter.updateRootInPlace(*this, [&]() {
+ rewriter.modifyOpInPlace(*this, [&]() {
setMemRef(memorySlot.ptr);
getIndicesMutable().clear();
});
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 394640f9ebac89..b79ab8f3d671e0 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -792,7 +792,7 @@ struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
if (fromType && toType) {
if (fromType.getShape() == toType.getShape() &&
fromType.getElementType() == toType.getElementType()) {
- rewriter.updateRootInPlace(copyOp, [&] {
+ rewriter.modifyOpInPlace(copyOp, [&] {
copyOp.getSourceMutable().assign(castOp.getSource());
});
modified = true;
@@ -808,7 +808,7 @@ struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
if (fromType && toType) {
if (fromType.getShape() == toType.getShape() &&
fromType.getElementType() == toType.getElementType()) {
- rewriter.updateRootInPlace(copyOp, [&] {
+ rewriter.modifyOpInPlace(copyOp, [&] {
copyOp.getTargetMutable().assign(castOp.getSource());
});
modified = true;
@@ -1366,7 +1366,7 @@ static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
loc, llvm::cast<IntegerAttr>(maybeConstant.template get<Attribute>())
.getInt());
for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
- // updateRootInplace: lambda cannot capture structured bindings in C++17
+ // modifyOpInPlace: lambda cannot capture structured bindings in C++17
// yet.
op->replaceUsesOfWith(result, constantVal);
atLeastOneReplacement = true;
@@ -2436,7 +2436,7 @@ struct CollapseShapeOpMemRefCastFolder
op.getReassociationIndices());
if (newResultType == op.getResultType()) {
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
} else {
Value newOp = rewriter.create<CollapseShapeOp>(
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 101e099d2b644c..8047c60187b2fd 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -797,7 +797,7 @@ class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
if (!viewLikeOp)
return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source");
- rewriter.updateRootInPlace(extractOp, [&]() {
+ rewriter.modifyOpInPlace(extractOp, [&]() {
extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
});
return success();
diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
index 03765e95b01e7a..10ba508265e7b9 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
@@ -154,7 +154,7 @@ static void replaceAndPropagateMemRefType(RewriterBase &rewriter,
for (OpOperand &operand : user->getOpOperands()) {
if ([[maybe_unused]] auto castOp =
operand.get().getDefiningOp<UnrealizedConversionCastOp>()) {
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
user, [&]() { operand.set(conversion->getOperand(0)); });
}
}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index 397bd5856bcb07..bc0dd034f63851 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -79,9 +79,9 @@ static void replaceUsesAndPropagateType(RewriterBase &rewriter,
// TODO: can we use an early_inc iterator?
for (OpOperand *operand : operandsToReplace) {
Operation *op = operand->getOwner();
- rewriter.startRootUpdate(op);
+ rewriter.startOpModification(op);
operand->set(val);
- rewriter.finalizeRootUpdate(op);
+ rewriter.finalizeOpModification(op);
}
// Perform late op erasure.
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
index 8bfb4be5225f4a..8163f428683d8d 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
@@ -54,7 +54,7 @@ struct MmaSyncF32ToTF32Pattern : public OpRewritePattern<nvgpu::MmaSyncOp> {
"for nvgpu.mma.sync on f32 datatype");
if (precision == MmaSyncF32Lowering::TF32) {
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
op, [&]() { op.setTf32EnabledAttr(rewriter.getUnitAttr()); });
}
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index bf3264b5da9802..8698c00d1cb728 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -359,7 +359,7 @@ struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
if (!matchPattern(ifCond, m_Constant(&constAttr)))
return failure();
if (constAttr.getInt())
- rewriter.updateRootInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
+ rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
else
rewriter.eraseOp(op);
@@ -398,7 +398,7 @@ struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
if (!matchPattern(ifCond, m_Constant(&constAttr)))
return failure();
if (constAttr.getInt())
- rewriter.updateRootInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
+ rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
else
replaceOpWithRegion(rewriter, op, op.getRegion());
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index cdc0b6f1696ae9..45cc7479f209b5 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -552,7 +552,7 @@ ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
newYieldValuesFn(rewriter, getLoc(), newIterArgs);
assert(newInitOperands.size() == newYieldedValues.size() &&
"expected as many new yield values as new iter operands");
- rewriter.updateRootInPlace(yieldOp, [&]() {
+ rewriter.modifyOpInPlace(yieldOp, [&]() {
yieldOp.getResultsMutable().append(newYieldedValues);
});
}
@@ -1444,7 +1444,7 @@ struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
Value sharedOut =
forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
->get();
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
return success();
}
@@ -1464,7 +1464,7 @@ class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
failed(foldDynamicIndexList(mixedStep)))
return failure();
- rewriter.updateRootInPlace(op, [&]() {
+ rewriter.modifyOpInPlace(op, [&]() {
SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
@@ -1556,7 +1556,7 @@ struct ForallOpSingleOrZeroIterationDimsFolder
for (const auto &namedAttr : op->getAttrs()) {
if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
continue;
- rewriter.updateRootInPlace(newOp, [&]() {
+ rewriter.modifyOpInPlace(newOp, [&]() {
newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
});
}
@@ -2023,8 +2023,8 @@ struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
[&](OpResult result) {
return yieldOp.getOperand(result.getResultNumber());
});
- rewriter.updateRootInPlace(yieldOp,
- [&]() { yieldOp->setOperands(usedOperands); });
+ rewriter.modifyOpInPlace(yieldOp,
+ [&]() { yieldOp->setOperands(usedOperands); });
}
LogicalResult matchAndRewrite(IfOp op,
@@ -2189,8 +2189,8 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
constantTrue = rewriter.create<arith::ConstantOp>(
op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
- rewriter.updateRootInPlace(use.getOwner(),
- [&]() { use.set(constantTrue); });
+ rewriter.modifyOpInPlace(use.getOwner(),
+ [&]() { use.set(constantTrue); });
} else if (op.getElseRegion().isAncestor(
use.getOwner()->getParentRegion())) {
changed = true;
@@ -2199,8 +2199,8 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
constantFalse = rewriter.create<arith::ConstantOp>(
op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
- rewriter.updateRootInPlace(use.getOwner(),
- [&]() { use.set(constantFalse); });
+ rewriter.modifyOpInPlace(use.getOwner(),
+ [&]() { use.set(constantFalse); });
}
}
@@ -2383,14 +2383,14 @@ struct CombineIfs : public OpRewritePattern<IfOp> {
llvm::make_early_inc_range(std::get<0>(it).getUses())) {
if (nextThen && nextThen->getParent()->isAncestor(
use.getOwner()->getParentRegion())) {
- rewriter.startRootUpdate(use.getOwner());
+ rewriter.startOpModification(use.getOwner());
use.set(std::get<1>(it));
- rewriter.finalizeRootUpdate(use.getOwner());
+ rewriter.finalizeOpModification(use.getOwner());
} else if (nextElse && nextElse->getParent()->isAncestor(
use.getOwner()->getParentRegion())) {
- rewriter.startRootUpdate(use.getOwner());
+ rewriter.startOpModification(use.getOwner());
use.set(std::get<2>(it));
- rewriter.finalizeRootUpdate(use.getOwner());
+ rewriter.finalizeOpModification(use.getOwner());
}
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index dc3c46bf896a9c..90f935d71c2fe9 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -688,7 +688,7 @@ struct ForOpInterface
yieldValues.push_back(*alloc);
}
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
return success();
}
@@ -928,7 +928,7 @@ struct WhileOpInterface
return failure();
beforeYieldValues.push_back(*alloc);
}
- rewriter.updateRootInPlace(conditionOp, [&]() {
+ rewriter.modifyOpInPlace(conditionOp, [&]() {
conditionOp.getArgsMutable().assign(beforeYieldValues);
});
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
index 7b6b07eabf6c48..cda561b1d1054d 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -89,8 +89,8 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
SmallVector<Value> yieldOperands = yieldOp.getOperands();
yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult());
- rewriter.updateRootInPlace(
- yieldOp, [&]() { yieldOp->setOperands(yieldOperands); });
+ rewriter.modifyOpInPlace(yieldOp,
+ [&]() { yieldOp->setOperands(yieldOperands); });
}
// We cannot do a direct replacement of the forOp since the while op returns
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
index eee0791b397ae6..c6d024c462e837 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -99,7 +99,7 @@ struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
return failure();
Value initArg = forOp.getTiedLoopInit(blockArg)->get();
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
return success();
@@ -141,7 +141,7 @@ struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
unsigned resultNumber = opResult.getResultNumber();
if (!isShapePreserving(forOp, resultNumber))
return failure();
- rewriter.updateRootInPlace(dimOp, [&]() {
+ rewriter.modifyOpInPlace(dimOp, [&]() {
dimOp.getSourceMutable().assign(forOp.getInitArgs()[resultNumber]);
});
return success();
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index 342213507486af..a5bff0a892c3df 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -160,8 +160,8 @@ static LogicalResult peelForLoop(RewriterBase &b, ForOp forOp,
partialIteration.getInitArgsMutable().assign(forOp->getResults());
// Set new upper loop bound.
- b.updateRootInPlace(
- forOp, [&]() { forOp.getUpperBoundMutable().assign(splitBound); });
+ b.modifyOpInPlace(forOp,
+ [&]() { forOp.getUpperBoundMutable().assign(splitBound); });
return success();
}
@@ -239,7 +239,7 @@ LogicalResult mlir::scf::peelForLoopFirstIteration(RewriterBase &b, ForOp forOp,
firstIteration = cast<ForOp>(b.clone(*forOp.getOperation(), map));
// Update main loop with new lower bound.
- b.updateRootInPlace(forOp, [&]() {
+ b.modifyOpInPlace(forOp, [&]() {
forOp.getInitArgsMutable().assign(firstIteration->getResults());
forOp.getLowerBoundMutable().assign(splitBound);
});
@@ -286,11 +286,11 @@ struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
}
// Apply label, so that the same loop is not rewritten a second time.
- rewriter.updateRootInPlace(partialIteration, [&]() {
+ rewriter.modifyOpInPlace(partialIteration, [&]() {
partialIteration->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
partialIteration->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
});
- rewriter.updateRootInPlace(forOp, [&]() {
+ rewriter.modifyOpInPlace(forOp, [&]() {
forOp->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
});
return success();
diff --git a/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp b/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
index 8c2c544a89f7de..5aa35e79babfce 100644
--- a/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
@@ -111,7 +111,7 @@ class ConvertTypesInSCFYieldOp : public OneToNOpConversionPattern<YieldOp> {
return failure();
// Convert operands.
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
op, [&] { op->setOperands(adaptor.getFlatOperands()); });
return success();
@@ -131,7 +131,7 @@ class ConvertTypesInSCFConditionOp
return failure();
// Convert operands.
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
op, [&] { op->setOperands(adaptor.getFlatOperands()); });
return success();
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 7932c38a3e8d8b..e2cc5b4c5ff49b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -241,7 +241,7 @@ class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
for (Value operand : adaptor.getOperands())
unpackUnrealizedConversionCast(operand, unpackedYield);
- rewriter.updateRootInPlace(op, [&]() { op->setOperands(unpackedYield); });
+ rewriter.modifyOpInPlace(op, [&]() { op->setOperands(unpackedYield); });
return success();
}
};
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 38e0625d7ce093..5c9b5281468fc7 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -692,7 +692,7 @@ void mlir::scf::yieldReplacementForFusedProducer(
sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
unsigned resultNumber = fusableProducer.getResultNumber();
- rewriter.updateRootInPlace(tiledDestStyleOp, [&]() {
+ rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
});
}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
index c22cb6710a7e5d..354db6467a582b 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
@@ -91,8 +91,8 @@ class SPIRVPassThroughConversion : public OpConversionPattern<OpT> {
LogicalResult
matchAndRewrite(OpT op, typename OpT::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.updateRootInPlace(op,
- [&] { op->setOperands(adaptor.getOperands()); });
+ rewriter.modifyOpInPlace(op,
+ [&] { op->setOperands(adaptor.getOperands()); });
return success();
}
};
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 9f2755da092293..6150b5ee17851d 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -261,7 +261,7 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
return failure();
// Creates a new function with the update signature.
- rewriter.updateRootInPlace(funcOp, [&] {
+ rewriter.modifyOpInPlace(funcOp, [&] {
funcOp.setType(rewriter.getFunctionType(
signatureConverter.getConvertedTypes(), std::nullopt));
});
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
index c8e77f7de48300..d33eb9d2877ae3 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
@@ -29,7 +29,7 @@ sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
// Clones the original operation but changing the output to an unordered COO.
Operation *cloned = rewriter.clone(*op.getOperation());
- rewriter.updateRootInPlace(cloned, [cloned, srcCOOTp]() {
+ rewriter.modifyOpInPlace(cloned, [cloned, srcCOOTp]() {
cloned->getOpResult(0).setType(srcCOOTp);
});
Value srcCOO = cloned->getOpResult(0);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 50713be8296fa8..a0f7b55ce4446f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -389,14 +389,14 @@ struct GenericOpReinterpretMap
auto stt = tryGetSparseTensorType(res);
auto [idxMap, itTp] = *transMap;
- rewriter.startRootUpdate(linalgOp);
+ rewriter.startOpModification(linalgOp);
linalgOp.setIndexingMapsAttr(idxMap);
linalgOp.setIteratorTypesAttr(itTp);
// Use demapped arguments.
linalgOp.getInputsMutable().assign(adaptor.getInputs());
linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
res.setType(adaptor.getOutputs()[0].getType());
- rewriter.finalizeRootUpdate(linalgOp);
+ rewriter.finalizeOpModification(linalgOp);
rewriter.setInsertionPointAfter(linalgOp);
if (stt && stt->hasEncoding()) {
@@ -458,7 +458,7 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
}
// Marks the GenericOp to avoid recursive matching.
- rewriter.updateRootInPlace(linalgOp, [&]() {
+ rewriter.modifyOpInPlace(linalgOp, [&]() {
linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
});
@@ -482,10 +482,10 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
for (AffineMap &idxMap : idxMaps)
idxMap = idxMap.compose(order); // sorted loop -> lvl map
- rewriter.startRootUpdate(linalgOp);
+ rewriter.startOpModification(linalgOp);
linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps));
linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
- rewriter.finalizeRootUpdate(linalgOp);
+ rewriter.finalizeOpModification(linalgOp);
return success();
}
@@ -570,7 +570,7 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
rewriter.setInsertionPoint(linalgOp);
RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
Value dst = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
- rewriter.updateRootInPlace(linalgOp, [&]() {
+ rewriter.modifyOpInPlace(linalgOp, [&]() {
linalgOp->setOperand(t->getOperandNumber(), dst);
});
return success();
@@ -623,10 +623,10 @@ struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
}
assert(dynSz.empty()); // should have consumed all.
- rewriter.startRootUpdate(op);
+ rewriter.startOpModification(op);
op->setOperands(dynLvlSzs);
op.getResult().setType(stt.getDemappedType());
- rewriter.finalizeRootUpdate(op);
+ rewriter.finalizeOpModification(op);
rewriter.setInsertionPointAfter(op);
Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
@@ -676,7 +676,7 @@ struct ForeachOpDemapper
auto srcStt = getSparseTensorType(op.getTensor());
SmallVector<Type> prevRetTps(op.getResultTypes());
- rewriter.startRootUpdate(op);
+ rewriter.startOpModification(op);
op.getTensorMutable().assign(adaptor.getTensor());
op.getInitArgsMutable().assign(adaptor.getInitArgs());
// Update results' types.
@@ -731,7 +731,7 @@ struct ForeachOpDemapper
rewriter.eraseOp(yield);
}
}
- rewriter.finalizeRootUpdate(op);
+ rewriter.finalizeOpModification(op);
rewriter.setInsertionPointAfter(op);
SmallVector<Value> outs =
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index fa97e405584791..b1b8b762d164d5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -329,7 +329,7 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
.getCopy();
AllocTensorOp a =
op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
- rewriter.updateRootInPlace(a, [&]() { a.getCopyMutable().assign(init); });
+ rewriter.modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); });
}
// Replace consumer with fused operation. Old producer
// and consumer ops will be removed by DCE.
@@ -366,7 +366,7 @@ struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) {
if (Operation *def = op.getSource().getDefiningOp()) {
if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
- rewriter.updateRootInPlace(def, [&]() {
+ rewriter.modifyOpInPlace(def, [&]() {
def->getResult(0).setType(op->getResultTypes()[0]);
});
rewriter.replaceOp(op, def->getResult(0));
@@ -804,7 +804,7 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
auto denseTp =
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
- rewriter.updateRootInPlace(op, [&]() { op->setOperand(0, convert); });
+ rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); });
return success();
}
if (encDst) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index 7710a44a7ca052..3a487a3bd6a069 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -545,7 +545,7 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
rewriter.setInsertionPointToStart(forOpNew.getBody());
} else {
- rewriter.updateRootInPlace(forOp, [&]() { forOp.setStep(step); });
+ rewriter.modifyOpInPlace(forOp, [&]() { forOp.setStep(step); });
rewriter.setInsertionPoint(yield);
}
vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 5834426cae2f41..fec23d2a72347f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -583,7 +583,7 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
if (def->getBlock() == block) {
rewriter.setInsertionPoint(def);
for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
- rewriter.updateRootInPlace(def, [&]() {
+ rewriter.modifyOpInPlace(def, [&]() {
def->setOperand(
i, relinkBranch(env, rewriter, block, def->getOperand(i)));
});
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 80dad064676220..3d8cc5222b828b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -1416,7 +1416,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
Operation *newRed = rewriter.clone(*redExp);
// Replaces arguments of the reduction expression by using the block
// arguments from scf.reduce.
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
newRed, [&]() { newRed->setOperands(redBlock->getArguments()); });
// Erases the out-dated reduction expression.
rewriter.eraseOp(redExp);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 816e6ba8fed94e..b2fe58099b2fb3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -819,7 +819,7 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
auto resultIndex = source.cast<OpResult>().getResultNumber();
auto initOperand = destOp.getDpsInitOperand(resultIndex);
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
return success();
}
@@ -1752,7 +1752,7 @@ struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
srcType, collapseShapeOp.getReassociationMaps());
if (newResultType == collapseShapeOp.getResultType()) {
- rewriter.updateRootInPlace(collapseShapeOp, [&]() {
+ rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
collapseShapeOp.getSrcMutable().assign(castOp.getSource());
});
} else {
@@ -2930,7 +2930,7 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
padTensorOp.getResultType().getShape());
if (newResultType == padTensorOp.getResultType()) {
- rewriter.updateRootInPlace(padTensorOp, [&]() {
+ rewriter.modifyOpInPlace(padTensorOp, [&]() {
padTensorOp.getSourceMutable().assign(castOp.getSource());
});
} else {
@@ -3994,9 +3994,9 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
// Fold optional PaddingValue operand away if padding is not needed.
if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
- rewriter.startRootUpdate(packOp);
+ rewriter.startOpModification(packOp);
packOp.getPaddingValueMutable().clear();
- rewriter.finalizeRootUpdate(packOp);
+ rewriter.finalizeOpModification(packOp);
return success();
}
return failure();
@@ -4166,8 +4166,8 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
auto destValue = unPackOp.getDest().cast<OpResult>();
Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
- rewriter.updateRootInPlace(
- unPackOp, [&]() { unPackOp.setDpsInitOperand(0, newDest); });
+ rewriter.modifyOpInPlace(unPackOp,
+ [&]() { unPackOp.setDpsInitOperand(0, newDest); });
return success();
}
return failure();
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 26c39ff3523434..744ab4154fe8a9 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -66,7 +66,7 @@ LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
if (!notOp)
return failure();
- rewriter.updateRootInPlace(op, [&]() {
+ rewriter.modifyOpInPlace(op, [&]() {
op.getOperation()->setOperands(
{notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
});
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f257728a7b947c..749eb56b3d3bec 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4416,7 +4416,7 @@ class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
while (defWrite) {
if (checkSameValueWAW(writeOp, defWrite)) {
- rewriter.updateRootInPlace(writeToModify, [&]() {
+ rewriter.modifyOpInPlace(writeToModify, [&]() {
writeToModify.getSourceMutable().assign(defWrite.getSource());
});
return success();
@@ -4533,7 +4533,7 @@ struct SwapExtractSliceOfTransferWrite
transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
transferOp.getIndices(), transferOp.getPermutationMapAttr(),
rewriter.getBoolArrayAttr(newInBounds));
- rewriter.updateRootInPlace(insertOp, [&]() {
+ rewriter.modifyOpInPlace(insertOp, [&]() {
insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
});
return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index 5782ee1d58cf53..1caec5bb8644f3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -225,7 +225,7 @@ struct MaskOpInterface
newReturnValues[it.index()] = it.value();
}
}
- rewriter.updateRootInPlace(yieldOp, [&]() {
+ rewriter.modifyOpInPlace(yieldOp, [&]() {
yieldOp.getOperandsMutable().assign(newYieldedValues);
});
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 9d5ad20d4715b1..620ceee48b196d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -182,7 +182,7 @@ static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
auto yield =
cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
return newWarpOp;
}
@@ -724,7 +724,7 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
return failure();
// Notify the rewriter that the warp op is changing (see the comment on
// the WarpOpTransferRead pattern).
- rewriter.startRootUpdate(warpOp);
+ rewriter.startOpModification(warpOp);
unsigned operandIndex = yieldOperand->getOperandNumber();
Attribute scalarAttr = dense.getSplatValue<Attribute>();
auto newAttr = DenseElementsAttr::get(
@@ -733,7 +733,7 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
rewriter.setInsertionPointAfter(warpOp);
Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
- rewriter.finalizeRootUpdate(warpOp);
+ rewriter.finalizeOpModification(warpOp);
return success();
}
};
@@ -1017,9 +1017,9 @@ struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
return failure();
// Notify the rewriter that the warp op is changing (see the comment on
// the WarpOpTransferRead pattern).
- rewriter.startRootUpdate(warpOp);
+ rewriter.startOpModification(warpOp);
rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
- rewriter.finalizeRootUpdate(warpOp);
+ rewriter.finalizeOpModification(warpOp);
return success();
}
};
@@ -1159,7 +1159,7 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
// Notify the rewriter that the warp op is changing (see the comment on
// the WarpOpTransferRead pattern).
- rewriter.startRootUpdate(warpOp);
+ rewriter.startOpModification(warpOp);
AffineExpr s0, s1;
bindSymbols(rewriter.getContext(), s0, s1);
@@ -1179,7 +1179,7 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
auto newMask =
rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
- rewriter.finalizeRootUpdate(warpOp);
+ rewriter.finalizeOpModification(warpOp);
return success();
}
};
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index ea33453e7215e3..f1a27168bd4e54 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -525,7 +525,7 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
auto inBoundsAttr = b.getBoolArrayAttr(bools);
if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
- b.updateRootInPlace(xferOp, [&]() {
+ b.modifyOpInPlace(xferOp, [&]() {
xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
});
return success();
@@ -598,7 +598,7 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
- b.updateRootInPlace(xferOp, [&]() {
+ b.modifyOpInPlace(xferOp, [&]() {
xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
});
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 661674dd74c0cd..bd02c07981466d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1050,7 +1050,7 @@ struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
}
- rewriter.updateRootInPlace(xferOp, [&]() {
+ rewriter.modifyOpInPlace(xferOp, [&]() {
xferOp.getMaskMutable().assign(mask);
xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
});
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 5e788cdb4897d3..73f232fd0de01a 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -263,7 +263,7 @@ void RewriterBase::eraseBlock(Block *block) {
block->erase();
}
-void RewriterBase::finalizeRootUpdate(Operation *op) {
+void RewriterBase::finalizeOpModification(Operation *op) {
// Notify the listener that the operation was modified.
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
rewriteListener->notifyOperationModified(op);
@@ -276,7 +276,7 @@ void RewriterBase::replaceUsesWithIf(Value from, Value to,
function_ref<bool(OpOperand &)> functor) {
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
if (functor(operand))
- updateRootInPlace(operand.getOwner(), [&]() { operand.set(to); });
+ modifyOpInPlace(operand.getOwner(), [&]() { operand.set(to); });
}
}
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 26a7ea5d5e219e..f3a973d9994083 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -506,7 +506,7 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
if (info.mergePoints.contains(blockOperand.get())) {
if (!job.reachingDef)
job.reachingDef = getLazyDefaultValue();
- rewriter.updateRootInPlace(terminator, [&]() {
+ rewriter.modifyOpInPlace(terminator, [&]() {
terminator.getSuccessorOperands(blockOperand.getOperandNumber())
.append(job.reachingDef);
});
@@ -596,7 +596,7 @@ void MemorySlotPromoter::promoteSlot() {
assert(succOperands.size() == mergePoint->getNumArguments() ||
succOperands.size() + 1 == mergePoint->getNumArguments());
if (succOperands.size() + 1 == mergePoint->getNumArguments())
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
user, [&]() { succOperands.append(getLazyDefaultValue()); });
}
}
diff --git a/mlir/lib/Transforms/Utils/CommutativityUtils.cpp b/mlir/lib/Transforms/Utils/CommutativityUtils.cpp
index 6034366631d10f..5ba6e4747cb57f 100644
--- a/mlir/lib/Transforms/Utils/CommutativityUtils.cpp
+++ b/mlir/lib/Transforms/Utils/CommutativityUtils.cpp
@@ -304,7 +304,7 @@ class SortCommutativeOperands : public RewritePattern {
sortedOperands.push_back(commOperand->operand);
if (sortedOperands == operands)
return failure();
- rewriter.updateRootInPlace(op, [&] { op->setOperands(sortedOperands); });
+ rewriter.modifyOpInPlace(op, [&] { op->setOperands(sortedOperands); });
return success();
}
};
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 85433d088dcbf0..ef6a49455d1860 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1614,15 +1614,15 @@ void ConversionPatternRewriter::notifyOperationInserted(Operation *op) {
impl->createdOps.push_back(op);
}
-void ConversionPatternRewriter::startRootUpdate(Operation *op) {
+void ConversionPatternRewriter::startOpModification(Operation *op) {
#ifndef NDEBUG
impl->pendingRootUpdates.insert(op);
#endif
impl->rootUpdates.emplace_back(op);
}
-void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) {
- PatternRewriter::finalizeRootUpdate(op);
+void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
+ PatternRewriter::finalizeOpModification(op);
// There is nothing to do here, we only need to track the operation at the
// start of the update.
#ifndef NDEBUG
@@ -1631,7 +1631,7 @@ void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) {
#endif
}
-void ConversionPatternRewriter::cancelRootUpdate(Operation *op) {
+void ConversionPatternRewriter::cancelOpModification(Operation *op) {
#ifndef NDEBUG
assert(impl->pendingRootUpdates.erase(op) &&
"operation did not have a pending in-place update");
@@ -3115,7 +3115,7 @@ static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
auto newType = FunctionType::get(rewriter.getContext(),
result.getConvertedTypes(), newResults);
- rewriter.updateRootInPlace(funcOp, [&] { funcOp.setType(newType); });
+ rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
return success();
}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 2e3bc76009ca20..d1ac5e81e75a69 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -160,7 +160,7 @@ struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> {
int64_t val = intAttr.getInt();
if (val >= MaxVal)
return failure();
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); });
return success();
}
@@ -175,7 +175,7 @@ struct MakeOpEligible : public RewritePattern {
PatternRewriter &rewriter) const override {
if (op->hasAttr("eligible"))
return failure();
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); });
return success();
}
@@ -195,7 +195,7 @@ struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
return failure();
// Hoisting means removing an op from the enclosing op. I.e., the enclosing
// op is modified.
- rewriter.updateRootInPlace(op, [&]() { toBeHoisted->moveBefore(op); });
+ rewriter.modifyOpInPlace(op, [&]() { toBeHoisted->moveBefore(op); });
return success();
}
};
@@ -327,7 +327,7 @@ struct TestStrictPatternDriver
Operation *newOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(),
op->getOperands(), op->getResultTypes());
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
op, [&]() { op->setAttr("skip", rewriter.getBoolAttr(true)); });
newOp->setAttr("skip", rewriter.getBoolAttr(true));
@@ -415,8 +415,8 @@ struct TestStrictPatternDriver
PatternRewriter &rewriter) const override {
if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock())
return failure();
- rewriter.updateRootInPlace(
- op, [&]() { op->setSuccessor(op->getBlock(), 0); });
+ rewriter.modifyOpInPlace(op,
+ [&]() { op->setSuccessor(op->getBlock(), 0); });
return success();
}
};
@@ -650,7 +650,7 @@ struct TestUndoBlockArgReplace : public ConversionPattern {
rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
illegalOp->getResult(0));
- rewriter.updateRootInPlace(op, [] {});
+ rewriter.modifyOpInPlace(op, [] {});
return success();
}
};
@@ -667,7 +667,7 @@ struct TestUndoBlockErase : public ConversionPattern {
rewriter.setInsertionPointToStart(secondBlock);
rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
rewriter.eraseBlock(secondBlock);
- rewriter.updateRootInPlace(op, [] {});
+ rewriter.modifyOpInPlace(op, [] {});
return success();
}
};
@@ -827,7 +827,7 @@ struct TestBoundedRecursiveRewrite
LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
PatternRewriter &rewriter) const final {
// Decrement the depth of the op in-place.
- rewriter.updateRootInPlace(op, [&] {
+ rewriter.modifyOpInPlace(op, [&] {
op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1));
});
return success();
@@ -1333,7 +1333,7 @@ struct TestTestSignatureConversionNoConverter
if (failed(
converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
return failure();
- rewriter.updateRootInPlace(
+ rewriter.modifyOpInPlace(
op, [&] { rewriter.applySignatureConversion(&region, result); });
return success();
}
@@ -1350,8 +1350,8 @@ struct TestTypeConsumerForward
LogicalResult
matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
- rewriter.updateRootInPlace(op,
- [&] { op->setOperands(adaptor.getOperands()); });
+ rewriter.modifyOpInPlace(op,
+ [&] { op->setOperands(adaptor.getOperands()); });
return success();
}
};
@@ -1567,7 +1567,7 @@ struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
SmallVector<Value, 2> replacements(succOperands);
rewriter.eraseOp(branchOp);
rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
- rewriter.updateRootInPlace(op, [] {});
+ rewriter.modifyOpInPlace(op, [] {});
return success();
}
};
@@ -1588,7 +1588,7 @@ struct TestUndoBlocksMerge : public ConversionPattern {
SmallVector<Value, 2> replacements(succOperands);
rewriter.eraseOp(branchOp);
rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
- rewriter.updateRootInPlace(op, [] {});
+ rewriter.modifyOpInPlace(op, [] {});
return success();
}
};
@@ -1613,7 +1613,7 @@ struct TestMergeSingleBlockOps
rewriter.inlineBlockBefore(&innerBlock, op);
rewriter.eraseOp(innerTerminator);
rewriter.eraseOp(op);
- rewriter.updateRootInPlace(op, [] {});
+ rewriter.modifyOpInPlace(op, [] {});
return success();
}
};
More information about the Mlir-commits
mailing list