[Mlir-commits] [flang] [mlir] [mlir][IR] Add rewriter API for moving operations (PR #78988)
Matthias Springer
llvmlistbot at llvm.org
Thu Jan 25 01:16:12 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/78988
>From 6161d337181396db48740d411a7c28c95482d232 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 22 Jan 2024 15:42:02 +0000
Subject: [PATCH] [mlir][IR] Add rewriter API for moving operations
The pattern rewriter documentation states that "*all* IR mutations [...] are required to be performed via the `PatternRewriter`." This commit adds two functions that were missing from the rewriter API: `moveOpBefore` and `moveOpAfter`.
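After an operation has been moved, the `notifyOperationInserted` callback is triggered; the callback now receives an additional `previous` argument holding the op's previous insertion point (empty for newly inserted ops). This may cause listeners such as the greedy pattern rewrite driver to put the op back on the worklist.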
---
.../flang/Optimizer/Builder/FIRBuilder.h | 6 +++-
.../HLFIR/Transforms/BufferizeHLFIR.cpp | 7 +++--
.../toy/Ch5/mlir/LowerToAffineLoops.cpp | 4 +--
.../toy/Ch6/mlir/LowerToAffineLoops.cpp | 4 +--
.../toy/Ch7/mlir/LowerToAffineLoops.cpp | 4 +--
mlir/include/mlir/IR/Builders.h | 20 ++++++++-----
mlir/include/mlir/IR/PatternMatch.h | 26 +++++++++++++++--
.../mlir/Transforms/DialectConversion.h | 8 +++++-
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 4 +--
.../Async/Transforms/AsyncParallelFor.cpp | 4 +--
.../Bufferization/Transforms/Bufferize.cpp | 6 +++-
.../TransformOps/LinalgTransformOps.cpp | 8 ++++--
.../lib/Dialect/SCF/Transforms/ForToWhile.cpp | 2 +-
.../BufferizableOpInterfaceImpl.cpp | 2 +-
mlir/lib/IR/Builders.cpp | 4 +--
mlir/lib/IR/PatternMatch.cpp | 28 +++++++++++++++++++
.../Transforms/Utils/DialectConversion.cpp | 18 ++++++++++--
.../Utils/GreedyPatternRewriteDriver.cpp | 18 ++++++++----
.../Utils/LoopInvariantCodeMotionUtils.cpp | 4 +--
mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp | 4 +--
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 4 +--
mlir/test/lib/IR/TestClone.cpp | 3 +-
mlir/test/lib/Transforms/TestConstantFold.cpp | 3 +-
23 files changed, 144 insertions(+), 47 deletions(-)
diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index b5b2c99810b15bb..c9dcfde56e3829d 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -490,7 +490,11 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
LLVM_DUMP_METHOD void dumpFunc();
/// FirOpBuilder hook for creating new operation.
- void notifyOperationInserted(mlir::Operation *op) override {
+ void notifyOperationInserted(mlir::Operation *op,
+ mlir::OpBuilder::InsertPoint previous) override {
+ // We only care about newly created operations (moved ops have `previous`).
+ if (previous.isSet())
+ return;
setCommonAttributes(op);
}
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
index 641854bd201f0b0..5fe78b7408026f8 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
@@ -730,9 +730,10 @@ struct HLFIRListener : public mlir::OpBuilder::Listener {
HLFIRListener(fir::FirOpBuilder &builder,
mlir::ConversionPatternRewriter &rewriter)
: builder{builder}, rewriter{rewriter} {}
- void notifyOperationInserted(mlir::Operation *op) override {
- builder.notifyOperationInserted(op);
- rewriter.notifyOperationInserted(op);
+ void notifyOperationInserted(mlir::Operation *op,
+ mlir::OpBuilder::InsertPoint previous) override {
+ builder.notifyOperationInserted(op, previous);
+ rewriter.notifyOperationInserted(op, previous);
}
virtual void notifyBlockCreated(mlir::Block *block) override {
builder.notifyBlockCreated(block);
diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
index ae4bd980c34b53b..948c8a045e341e4 100644
--- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
@@ -60,12 +60,12 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
// Make sure to allocate at the beginning of the block.
auto *parentBlock = alloc->getBlock();
- alloc->moveBefore(&parentBlock->front());
+ rewriter.moveOpBefore(alloc, &parentBlock->front());
// Make sure to deallocate this alloc at the end of the block. This is fine
// as toy functions have no control flow.
auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
- dealloc->moveBefore(&parentBlock->back());
+ rewriter.moveOpBefore(dealloc, &parentBlock->back());
return alloc;
}
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
index ae4bd980c34b53b..948c8a045e341e4 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
@@ -60,12 +60,12 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
// Make sure to allocate at the beginning of the block.
auto *parentBlock = alloc->getBlock();
- alloc->moveBefore(&parentBlock->front());
+ rewriter.moveOpBefore(alloc, &parentBlock->front());
// Make sure to deallocate this alloc at the end of the block. This is fine
// as toy functions have no control flow.
auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
- dealloc->moveBefore(&parentBlock->back());
+ rewriter.moveOpBefore(dealloc, &parentBlock->back());
return alloc;
}
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
index ae4bd980c34b53b..948c8a045e341e4 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -60,12 +60,12 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
// Make sure to allocate at the beginning of the block.
auto *parentBlock = alloc->getBlock();
- alloc->moveBefore(&parentBlock->front());
+ rewriter.moveOpBefore(alloc, &parentBlock->front());
// Make sure to deallocate this alloc at the end of the block. This is fine
// as toy functions have no control flow.
auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
- dealloc->moveBefore(&parentBlock->back());
+ rewriter.moveOpBefore(dealloc, &parentBlock->back());
return alloc;
}
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 13fbc3fb928c399..6b95be7c6d372f8 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -205,6 +205,7 @@ class Builder {
/// automatically inserted at an insertion point. The builder is copyable.
class OpBuilder : public Builder {
public:
+ class InsertPoint;
struct Listener;
/// Create a builder with the given context.
@@ -285,12 +286,17 @@ class OpBuilder : public Builder {
virtual ~Listener() = default;
- /// Notification handler for when an operation is inserted into the builder.
- /// `op` is the operation that was inserted.
- virtual void notifyOperationInserted(Operation *op) {}
-
- /// Notification handler for when a block is created using the builder.
- /// `block` is the block that was created.
+ /// Notify the listener that the specified operation was inserted.
+ ///
+ /// * If the operation was moved, then `previous` is the previous location
+ /// of the op.
+ /// * If the operation was unlinked before it was inserted, then `previous`
+ /// is empty.
+ ///
+ /// Note: Creating an (unlinked) op does not trigger this notification.
+ virtual void notifyOperationInserted(Operation *op, InsertPoint previous) {}
+
+ /// Notify the listener that the specified block was created using the builder.
virtual void notifyBlockCreated(Block *block) {}
protected:
@@ -517,7 +523,7 @@ class OpBuilder : public Builder {
if (succeeded(tryFold(op, results)))
op->erase();
else if (listener)
- listener->notifyOperationInserted(op);
+ listener->notifyOperationInserted(op, /*previous=*/{});
}
/// Overload to create or fold a single result operation.
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 815340c91850935..7f233cd3f4d4b3c 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -428,6 +428,8 @@ class RewriterBase : public OpBuilder {
/// Notify the listener that the specified operation is about to be erased.
/// At this point, the operation has zero uses.
+ ///
+ /// Note: This notification is not triggered when unlinking an operation.
virtual void notifyOperationRemoved(Operation *op) {}
/// Notify the listener that the pattern failed to match the given
@@ -450,8 +452,8 @@ class RewriterBase : public OpBuilder {
struct ForwardingListener : public RewriterBase::Listener {
ForwardingListener(OpBuilder::Listener *listener) : listener(listener) {}
- void notifyOperationInserted(Operation *op) override {
- listener->notifyOperationInserted(op);
+ void notifyOperationInserted(Operation *op, InsertPoint previous) override {
+ listener->notifyOperationInserted(op, previous);
}
void notifyBlockCreated(Block *block) override {
listener->notifyBlockCreated(block);
@@ -591,6 +593,26 @@ class RewriterBase : public OpBuilder {
/// block into a new block, and return it.
virtual Block *splitBlock(Block *block, Block::iterator before);
+ /// Unlink `op` from its current block and insert it right before
+ /// `existingOp` which may be in the same or another block in the same
+ /// function.
+ void moveOpBefore(Operation *op, Operation *existingOp);
+
+ /// Unlink `op` from its current block and insert it right before
+ /// `iterator` in the specified block.
+ virtual void moveOpBefore(Operation *op, Block *block,
+ Block::iterator iterator);
+
+ /// Unlink `op` from its current block and insert it right after
+ /// `existingOp` which may be in the same or another block in the same
+ /// function.
+ void moveOpAfter(Operation *op, Operation *existingOp);
+
+ /// Unlink `op` from its current block and insert it right after
+ /// `iterator` in the specified block.
+ virtual void moveOpAfter(Operation *op, Block *block,
+ Block::iterator iterator);
+
/// This method is used to notify the rewriter that an in-place operation
/// modification is about to happen. A call to this function *must* be
/// followed by a call to either `finalizeOpModification` or
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 9568540789df3f6..32c5937d014e9ef 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -737,7 +737,7 @@ class ConversionPatternRewriter final : public PatternRewriter,
using PatternRewriter::cloneRegionBefore;
/// PatternRewriter hook for inserting a new operation.
- void notifyOperationInserted(Operation *op) override;
+ void notifyOperationInserted(Operation *op, InsertPoint previous) override;
/// PatternRewriter hook for updating the given operation in-place.
/// Note: These methods only track updates to the given operation itself,
@@ -761,9 +761,15 @@ class ConversionPatternRewriter final : public PatternRewriter,
detail::ConversionPatternRewriterImpl &getImpl();
private:
+ // Hide unsupported pattern rewriter API.
using OpBuilder::getListener;
using OpBuilder::setListener;
+ void moveOpBefore(Operation *op, Block *block,
+ Block::iterator iterator) override;
+ void moveOpAfter(Operation *op, Block *block,
+ Block::iterator iterator) override;
+
std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
};
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index c260e68d509e983..b802ae33edaccee 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1206,7 +1206,7 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
if (failed(applyOp->fold(constOperands, foldResults)) ||
foldResults.empty()) {
if (OpBuilder::Listener *listener = b.getListener())
- listener->notifyOperationInserted(applyOp);
+ listener->notifyOperationInserted(applyOp, /*previous=*/{});
return applyOp.getResult();
}
@@ -1274,7 +1274,7 @@ static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
if (failed(minMaxOp->fold(constOperands, foldResults)) ||
foldResults.empty()) {
if (OpBuilder::Listener *listener = b.getListener())
- listener->notifyOperationInserted(minMaxOp);
+ listener->notifyOperationInserted(minMaxOp, /*previous=*/{});
return minMaxOp.getResult();
}
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 428a3c945581b48..8c3e25355f6087e 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -273,7 +273,7 @@ static ParallelComputeFunction createParallelComputeFunction(
// Insert function into the module symbol table and assign it unique name.
SymbolTable symbolTable(module);
symbolTable.insert(func);
- rewriter.getListener()->notifyOperationInserted(func);
+ rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});
// Create function entry block.
Block *block =
@@ -489,7 +489,7 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
// Insert function into the module symbol table and assign it unique name.
SymbolTable symbolTable(module);
symbolTable.insert(func);
- rewriter.getListener()->notifyOperationInserted(func);
+ rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});
// Create function entry block.
Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs(),
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 3f1626a6af34d49..2758d554712b9f4 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -371,7 +371,11 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
toMemrefOps.erase(op);
}
- void notifyOperationInserted(Operation *op) override {
+ void notifyOperationInserted(Operation *op, InsertPoint previous) override {
+ // We only care about newly created ops.
+ if (previous.isSet())
+ return;
+
erasedOps.erase(op);
// Gather statistics about allocs.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 140bdd1f2db3618..803c5691a0403b0 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -214,8 +214,12 @@ class NewOpsListener : public RewriterBase::ForwardingListener {
}
private:
- void notifyOperationInserted(Operation *op) override {
- ForwardingListener::notifyOperationInserted(op);
+ void notifyOperationInserted(Operation *op,
+ OpBuilder::InsertPoint previous) override {
+ ForwardingListener::notifyOperationInserted(op, previous);
+ // We only care about newly created ops.
+ if (previous.isSet())
+ return;
auto inserted = newOps.insert(op);
(void)inserted;
assert(inserted.second && "expected newly created op");
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
index cda561b1d1054d9..9f8189ae15e6de2 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -83,7 +83,7 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
// Inline for-loop body operations into 'after' region.
for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
- arg.moveBefore(afterBlock, afterBlock->end());
+ rewriter.moveOpBefore(&arg, afterBlock, afterBlock->end());
// Add incremented IV to yield operations
for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 2cd57e7324b4dc5..678b7c099fa3692 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -983,7 +983,7 @@ struct ParallelInsertSliceOpInterface
for (Operation *user : srcBuffer->getUsers()) {
if (hasEffect<MemoryEffects::Free>(user)) {
if (user->getBlock() == parallelCombiningParent->getBlock())
- user->moveBefore(user->getBlock()->getTerminator());
+ rewriter.moveOpBefore(user, user->getBlock()->getTerminator());
break;
}
}
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index d1565047658776d..a319afcdc6a9a23 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -412,7 +412,7 @@ Operation *OpBuilder::insert(Operation *op) {
block->getOperations().insert(insertPoint, op);
if (listener)
- listener->notifyOperationInserted(op);
+ listener->notifyOperationInserted(op, /*previous=*/{});
return op;
}
@@ -530,7 +530,7 @@ Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
// about any ops that got inserted inside those regions as part of cloning.
if (listener) {
auto walkFn = [&](Operation *walkedOp) {
- listener->notifyOperationInserted(walkedOp);
+ listener->notifyOperationInserted(walkedOp, /*previous=*/{});
};
for (Region ®ion : newOp->getRegions())
region.walk<WalkOrder::PreOrder>(walkFn);
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index ba0516e0539b6ca..affb8898fa07544 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -366,3 +366,31 @@ void RewriterBase::cloneRegionBefore(Region ®ion, Region &parent,
void RewriterBase::cloneRegionBefore(Region ®ion, Block *before) {
cloneRegionBefore(region, *before->getParent(), before->getIterator());
}
+
+void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
+ moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
+}
+
+void RewriterBase::moveOpBefore(Operation *op, Block *block,
+ Block::iterator iterator) {
+ Block *currentBlock = op->getBlock();
+ Block::iterator currentIterator = op->getIterator();
+ op->moveBefore(block, iterator);
+ if (listener)
+ listener->notifyOperationInserted(
+ op, /*previous=*/InsertPoint(currentBlock, currentIterator));
+}
+
+void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
+ moveOpAfter(op, existingOp->getBlock(), existingOp->getIterator());
+}
+
+void RewriterBase::moveOpAfter(Operation *op, Block *block,
+ Block::iterator iterator) {
+ Block *currentBlock = op->getBlock();
+ Block::iterator currentIterator = op->getIterator();
+ op->moveAfter(block, iterator);
+ if (listener)
+ listener->notifyOperationInserted(
+ op, /*previous=*/InsertPoint(currentBlock, currentIterator));
+}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index ef6a49455d18605..f5bede2b94f9cb2 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1602,11 +1602,13 @@ void ConversionPatternRewriter::cloneRegionBefore(Region ®ion,
Block *cloned = mapping.lookup(&b);
impl->notifyCreatedBlock(cloned);
cloned->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
- [&](Operation *op) { notifyOperationInserted(op); });
+ [&](Operation *op) { notifyOperationInserted(op, /*previous=*/{}); });
}
}
-void ConversionPatternRewriter::notifyOperationInserted(Operation *op) {
+void ConversionPatternRewriter::notifyOperationInserted(Operation *op,
+ InsertPoint previous) {
+ assert(!previous.isSet() && "expected newly created op");
LLVM_DEBUG({
impl->logger.startLine()
<< "** Insert : '" << op->getName() << "'(" << op << ")\n";
@@ -1651,6 +1653,18 @@ LogicalResult ConversionPatternRewriter::notifyMatchFailure(
return impl->notifyMatchFailure(loc, reasonCallback);
}
+void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
+ Block::iterator iterator) {
+ llvm_unreachable(
+ "moving single ops is not supported in a dialect conversion");
+}
+
+void ConversionPatternRewriter::moveOpAfter(Operation *op, Block *block,
+ Block::iterator iterator) {
+ llvm_unreachable(
+ "moving single ops is not supported in a dialect conversion");
+}
+
detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
return *impl;
}
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index ac73e82bfe92a69..c27fee7a738eba0 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -133,9 +133,16 @@ struct ExpensiveChecks : public RewriterBase::ForwardingListener {
}
}
- void notifyOperationInserted(Operation *op) override {
- RewriterBase::ForwardingListener::notifyOperationInserted(op);
+ void notifyOperationInserted(Operation *op, InsertPoint previous) override {
+ RewriterBase::ForwardingListener::notifyOperationInserted(op, previous);
+ // Invalidate the finger print of the op that owns the block into which the
+ // op was inserted.
invalidateFingerPrint(op->getParentOp());
+
+ // Also invalidate the finger print of the op that owns the block from which
+ // the op was moved. (Only applicable if the op was moved.)
+ if (previous.isSet())
+ invalidateFingerPrint(previous.getBlock()->getParentOp());
}
void notifyOperationModified(Operation *op) override {
@@ -331,7 +338,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
/// Notify the driver that the specified operation was inserted. Update the
/// worklist as needed: The operation is enqueued depending on scope and
/// strict mode.
- void notifyOperationInserted(Operation *op) override;
+ void notifyOperationInserted(Operation *op, InsertPoint previous) override;
/// Notify the driver that the specified operation was removed. Update the
/// worklist as needed: The operation and its children are removed from the
@@ -641,13 +648,14 @@ void GreedyPatternRewriteDriver::notifyBlockRemoved(Block *block) {
config.listener->notifyBlockRemoved(block);
}
-void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
+void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op,
+ InsertPoint previous) {
LLVM_DEBUG({
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
<< ")\n";
});
if (config.listener)
- config.listener->notifyOperationInserted(op);
+ config.listener->notifyOperationInserted(op, previous);
if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
strictModeFilteredOps.insert(op);
addToWorklist(op);
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index 8f97fd3d9ddf84e..66ce6067963f838 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -365,8 +365,8 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
OpResult newLoopResult = loopLike.getLoopResults()->back();
- extractionOp->moveBefore(loopLike);
- insertionOp->moveAfter(loopLike);
+ rewriter.moveOpBefore(extractionOp, loopLike);
+ rewriter.moveOpAfter(insertionOp, loopLike);
rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
insertionOp.getDestinationOperand().get());
extractionOp.getSourceOperand().set(
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
index a8a808424b690f6..8a92d840ad13026 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -159,8 +159,8 @@ struct TestSCFPipeliningPass
auto ifOp =
rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true);
// True branch.
- op->moveBefore(&ifOp.getThenRegion().front(),
- ifOp.getThenRegion().front().begin());
+ rewriter.moveOpBefore(op, &ifOp.getThenRegion().front(),
+ ifOp.getThenRegion().front().begin());
rewriter.setInsertionPointAfter(op);
if (op->getNumResults() > 0)
rewriter.create<scf::YieldOp>(loc, op->getResults());
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index d1ac5e81e75a695..89b9d1ce78a52b6 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -193,9 +193,7 @@ struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
return failure();
if (!toBeHoisted->hasAttr("eligible"))
return failure();
- // Hoisting means removing an op from the enclosing op. I.e., the enclosing
- // op is modified.
- rewriter.modifyOpInPlace(op, [&]() { toBeHoisted->moveBefore(op); });
+ rewriter.moveOpBefore(toBeHoisted, op);
return success();
}
};
diff --git a/mlir/test/lib/IR/TestClone.cpp b/mlir/test/lib/IR/TestClone.cpp
index 13a0cfeb402a9cd..7b18f219b915f46 100644
--- a/mlir/test/lib/IR/TestClone.cpp
+++ b/mlir/test/lib/IR/TestClone.cpp
@@ -15,7 +15,8 @@ using namespace mlir;
namespace {
struct DumpNotifications : public OpBuilder::Listener {
- void notifyOperationInserted(Operation *op) override {
+ void notifyOperationInserted(Operation *op,
+ OpBuilder::InsertPoint previous) override {
llvm::outs() << "notifyOperationInserted: " << op->getName() << "\n";
}
};
diff --git a/mlir/test/lib/Transforms/TestConstantFold.cpp b/mlir/test/lib/Transforms/TestConstantFold.cpp
index aa67e0a78d43b7a..81f634d8d3ef78f 100644
--- a/mlir/test/lib/Transforms/TestConstantFold.cpp
+++ b/mlir/test/lib/Transforms/TestConstantFold.cpp
@@ -27,7 +27,8 @@ struct TestConstantFold : public PassWrapper<TestConstantFold, OperationPass<>>,
void foldOperation(Operation *op, OperationFolder &helper);
void runOnOperation() override;
- void notifyOperationInserted(Operation *op) override {
+ void notifyOperationInserted(Operation *op,
+ OpBuilder::InsertPoint previous) override {
existingConstants.push_back(op);
}
void notifyOperationRemoved(Operation *op) override {
More information about the Mlir-commits
mailing list