[Mlir-commits] [mlir] c653283 - This change makes `RewriterBase` symmetric to `OpBuilder`.
Matthias Springer
llvmlistbot at llvm.org
Wed Feb 22 00:18:55 PST 2023
Author: Matthias Springer
Date: 2023-02-22T09:18:27+01:00
New Revision: c65328305e98a806ae0eb811c7f17e3c5b0c0158
URL: https://github.com/llvm/llvm-project/commit/c65328305e98a806ae0eb811c7f17e3c5b0c0158
DIFF: https://github.com/llvm/llvm-project/commit/c65328305e98a806ae0eb811c7f17e3c5b0c0158.diff
LOG: This change makes `RewriterBase` symmetric to `OpBuilder`.
```
OpBuilder            OpBuilder::Listener
    ^                         ^
    |                         |
RewriterBase         RewriterBase::Listener
```
* Clients can listen to IR modifications with `RewriterBase::Listener`.
* `RewriterBase` no longer inherits from `OpBuilder::Listener`.
* Only a single listener can be registered at the moment (same as `OpBuilder`).
RFC: https://discourse.llvm.org/t/rfc-listeners-for-rewriterbase/68198
Differential Revision: https://reviews.llvm.org/D143339
Added:
Modified:
mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/PatternMatch.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/IR/PatternMatch.cpp
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 14df7b09032a1..f970d89dd410f 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -253,10 +253,32 @@ class OpBuilder : public Builder {
// Listeners
//===--------------------------------------------------------------------===//
+ /// Base class for listeners.
+ struct ListenerBase {
+ /// The kind of listener.
+ enum class Kind {
+ /// OpBuilder::Listener or user-derived class.
+ OpBuilderListener = 0,
+
+ /// RewriterBase::Listener or user-derived class.
+ RewriterBaseListener = 1
+ };
+
+ Kind getKind() const { return kind; }
+
+ protected:
+ ListenerBase(Kind kind) : kind(kind) {}
+
+ private:
+ const Kind kind;
+ };
+
/// This class represents a listener that may be used to hook into various
/// actions within an OpBuilder.
- struct Listener {
- virtual ~Listener();
+ struct Listener : public ListenerBase {
+ Listener() : ListenerBase(ListenerBase::Kind::OpBuilderListener) {}
+
+ virtual ~Listener() = default;
/// Notification handler for when an operation is inserted into the builder.
/// `op` is the operation that was inserted.
@@ -265,6 +287,9 @@ class OpBuilder : public Builder {
/// Notification handler for when a block is created using the builder.
/// `block` is the block that was created.
virtual void notifyBlockCreated(Block *block) {}
+
+ protected:
+ Listener(Kind kind) : ListenerBase(kind) {}
};
/// Sets the listener of this builder to the one provided.
@@ -537,14 +562,16 @@ class OpBuilder : public Builder {
return cast<OpT>(cloneWithoutRegions(*op.getOperation()));
}
+protected:
+ /// The optional listener for events of this builder.
+ Listener *listener;
+
private:
/// The current block this builder is inserting into.
Block *block = nullptr;
/// The insertion point within the block that this builder is inserting
/// before.
Block::iterator insertPoint;
- /// The optional listener for events of this builder.
- Listener *listener;
};
} // namespace mlir
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 187ce060f7ebb..7845cb19f6125 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -396,8 +396,36 @@ class OpTraitRewritePattern : public RewritePattern {
/// This class serves as a common API for IR mutation between pattern rewrites
/// and non-pattern rewrites, and facilitates the development of shared
/// IR transformation utilities.
-class RewriterBase : public OpBuilder, public OpBuilder::Listener {
+class RewriterBase : public OpBuilder {
public:
+ struct Listener : public OpBuilder::Listener {
+ Listener()
+ : OpBuilder::Listener(ListenerBase::Kind::RewriterBaseListener) {}
+
+ /// Notify the listener that the specified operation is about to be replaced
+ /// with the set of values potentially produced by new operations. This is
+ /// called before the uses of the operation have been changed.
+ virtual void notifyOperationReplaced(Operation *op,
+ ValueRange replacement) {}
+
+ /// This is called on an operation that a rewrite is removing, right before
+ /// the operation is deleted. At this point, the operation has zero uses.
+ virtual void notifyOperationRemoved(Operation *op) {}
+
+ /// Notify the listener that the pattern failed to match the given
+ /// operation, and provide a callback to populate a diagnostic with the
+ /// reason why the failure occurred. This method allows for derived
+ /// listeners to optionally hook into the reason why a rewrite failed, and
+ /// display it to users.
+ virtual LogicalResult
+ notifyMatchFailure(Location loc,
+ function_ref<void(Diagnostic &)> reasonCallback) {
+ return failure();
+ }
+
+ static bool classof(const OpBuilder::Listener *base);
+ };
+
/// Move the blocks that belong to "region" before the given position in
/// another region "parent". The two regions must be different. The caller
/// is responsible for creating or updating the operation transferring flow
@@ -541,8 +569,10 @@ class RewriterBase : public OpBuilder, public OpBuilder::Listener {
std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
notifyMatchFailure(Location loc, CallbackT &&reasonCallback) {
#ifndef NDEBUG
- return notifyMatchFailure(loc,
- function_ref<void(Diagnostic &)>(reasonCallback));
+ if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+ return rewriteListener->notifyMatchFailure(
+ loc, function_ref<void(Diagnostic &)>(reasonCallback));
+ return failure();
#else
return failure();
#endif
@@ -550,8 +580,10 @@ class RewriterBase : public OpBuilder, public OpBuilder::Listener {
template <typename CallbackT>
std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
- return notifyMatchFailure(op->getLoc(),
- function_ref<void(Diagnostic &)>(reasonCallback));
+ if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+ return rewriteListener->notifyMatchFailure(
+ op->getLoc(), function_ref<void(Diagnostic &)>(reasonCallback));
+ return failure();
}
template <typename ArgT>
LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg) {
@@ -564,35 +596,11 @@ class RewriterBase : public OpBuilder, public OpBuilder::Listener {
}
protected:
- /// Initialize the builder with this rewriter as the listener.
- explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx, /*listener=*/this) {}
+ /// Initialize the builder.
+ explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx) {}
explicit RewriterBase(const OpBuilder &otherBuilder)
- : OpBuilder(otherBuilder) {
- setListener(this);
- }
- ~RewriterBase() override;
-
- /// These are the callback methods that subclasses can choose to implement if
- /// they would like to be notified about certain types of mutations.
-
- /// Notify the rewriter that the specified operation is about to be replaced
- /// with the set of values potentially produced by new operations. This is
- /// called before the uses of the operation have been changed.
- virtual void notifyRootReplaced(Operation *op, ValueRange replacement) {}
-
- /// This is called on an operation that a rewrite is removing, right before
- /// the operation is deleted. At this point, the operation has zero uses.
- virtual void notifyOperationRemoved(Operation *op) {}
-
- /// Notify the rewriter that the pattern failed to match the given operation,
- /// and provide a callback to populate a diagnostic with the reason why the
- /// failure occurred. This method allows for derived rewriters to optionally
- /// hook into the reason why a rewrite failed, and display it to users.
- virtual LogicalResult
- notifyMatchFailure(Location loc,
- function_ref<void(Diagnostic &)> reasonCallback) {
- return failure();
- }
+ : OpBuilder(otherBuilder) {}
+ virtual ~RewriterBase();
private:
void operator=(const RewriterBase &) = delete;
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index c592f2db999a2..229dc016957c6 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -618,7 +618,8 @@ struct ConversionPatternRewriterImpl;
/// This class implements a pattern rewriter for use with ConversionPatterns. It
/// extends the base PatternRewriter and provides special conversion specific
/// hooks.
-class ConversionPatternRewriter final : public PatternRewriter {
+class ConversionPatternRewriter final : public PatternRewriter,
+ public RewriterBase::Listener {
public:
explicit ConversionPatternRewriter(MLIRContext *ctx);
~ConversionPatternRewriter() override;
@@ -742,6 +743,9 @@ class ConversionPatternRewriter final : public PatternRewriter {
detail::ConversionPatternRewriterImpl &getImpl();
private:
+ using OpBuilder::getListener;
+ using OpBuilder::setListener;
+
std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
};
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index d07204d475b32..25cb61857a10e 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -555,7 +555,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
// Inside regular functions we use the blocking wait operation to wait for
// the async object (token, value or group) to become available.
if (!isInCoroutine) {
- ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
+ ImplicitLocOpBuilder builder(loc, op, &rewriter);
builder.create<RuntimeAwaitOp>(loc, operand);
// Assert that the awaited operands is not in the error state.
@@ -574,7 +574,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
CoroMachinery &coro = funcCoro->getSecond();
Block *suspended = op->getBlock();
- ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
+ ImplicitLocOpBuilder builder(loc, op, &rewriter);
MLIRContext *ctx = op->getContext();
// Save the coroutine state and resume on a runtime managed thread when
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 3ec037069c2c2..0b10bafb9f163 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -342,7 +342,7 @@ static bool hasTensorSemantics(Operation *op) {
namespace {
/// A rewriter that keeps track of extra information during bufferization.
-class BufferizationRewriter : public IRRewriter {
+class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
public:
BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
DenseSet<Operation *> &toMemrefOps,
@@ -352,18 +352,18 @@ class BufferizationRewriter : public IRRewriter {
BufferizationStatistics *statistics)
: IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
worklist(worklist), analysisState(options), opFilter(opFilter),
- statistics(statistics) {}
+ statistics(statistics) {
+ setListener(this);
+ }
protected:
void notifyOperationRemoved(Operation *op) override {
- IRRewriter::notifyOperationRemoved(op);
erasedOps.insert(op);
// Erase if present.
toMemrefOps.erase(op);
}
void notifyOperationInserted(Operation *op) override {
- IRRewriter::notifyOperationInserted(op);
erasedOps.erase(op);
// Gather statistics about allocs and deallocs.
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index d36791fef23d1..8eab32b201a04 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -388,8 +388,6 @@ AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
// OpBuilder
//===----------------------------------------------------------------------===//
-OpBuilder::Listener::~Listener() = default;
-
/// Insert the given operation at the current insertion point and return it.
Operation *OpBuilder::insert(Operation *op) {
if (block)
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 1ca86cdcba1cc..10baea61d9a4f 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -217,6 +217,10 @@ void PDLPatternModule::registerRewriteFunction(StringRef name,
// RewriterBase
//===----------------------------------------------------------------------===//
+bool RewriterBase::Listener::classof(const OpBuilder::Listener *base) {
+ return base->getKind() == OpBuilder::ListenerBase::Kind::RewriterBaseListener;
+}
+
RewriterBase::~RewriterBase() {
// Out of line to provide a vtable anchor for the class.
}
@@ -232,7 +236,8 @@ void RewriterBase::replaceOpWithIf(
"incorrect number of values to replace operation");
// Notify the rewriter subclass that we're about to replace this root.
- notifyRootReplaced(op, newValues);
+ if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+ rewriteListener->notifyOperationReplaced(op, newValues);
// Replace each use of the results when the functor is true.
bool replacedAllUses = true;
@@ -260,13 +265,15 @@ void RewriterBase::replaceOpWithinBlock(Operation *op, ValueRange newValues,
/// the operation.
void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
// Notify the rewriter subclass that we're about to replace this root.
- notifyRootReplaced(op, newValues);
+ if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+ rewriteListener->notifyOperationReplaced(op, newValues);
assert(op->getNumResults() == newValues.size() &&
"incorrect # of replacement values");
op->replaceAllUsesWith(newValues);
- notifyOperationRemoved(op);
+ if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+ rewriteListener->notifyOperationRemoved(op);
op->erase();
}
@@ -274,7 +281,8 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
/// the given operation *must* be known to be dead.
void RewriterBase::eraseOp(Operation *op) {
assert(op->use_empty() && "expected 'op' to have no uses");
- notifyOperationRemoved(op);
+ if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+ rewriteListener->notifyOperationRemoved(op);
op->erase();
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index b82fc580d2ebd..0d78362da7f09 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1495,7 +1495,10 @@ LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
: PatternRewriter(ctx),
- impl(new detail::ConversionPatternRewriterImpl(*this)) {}
+ impl(new detail::ConversionPatternRewriterImpl(*this)) {
+ setListener(this);
+}
+
ConversionPatternRewriter::~ConversionPatternRewriter() = default;
void ConversionPatternRewriter::replaceOpWithIf(
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 997bdc6a1c49f..adf8b5121ab9e 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -39,7 +39,8 @@ namespace {
/// This abstract class manages the worklist and contains helper methods for
/// rewriting ops on the worklist. Derived classes specify how ops are added
/// to the worklist in the beginning.
-class GreedyPatternRewriteDriver : public PatternRewriter {
+class GreedyPatternRewriteDriver : public PatternRewriter,
+ public RewriterBase::Listener {
protected:
explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
const FrozenRewritePatternSet &patterns,
@@ -67,7 +68,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
/// Notify the driver that the specified operation was replaced. Update the
/// worklist as needed: New users are added enqueued.
- void notifyRootReplaced(Operation *op, ValueRange replacement) override;
+ void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
/// Process ops until the worklist is empty or `config.maxNumRewrites` is
/// reached. Return `true` if any IR was changed.
@@ -128,6 +129,9 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
// Apply a simple cost model based solely on pattern benefit.
matcher.applyDefaultCostModel();
+
+ // Set up listener.
+ setListener(this);
}
bool GreedyPatternRewriteDriver::processWorklist() {
@@ -359,8 +363,8 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
strictModeFilteredOps.erase(op);
}
-void GreedyPatternRewriteDriver::notifyRootReplaced(Operation *op,
- ValueRange replacement) {
+void GreedyPatternRewriteDriver::notifyOperationReplaced(
+ Operation *op, ValueRange replacement) {
LLVM_DEBUG({
logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
<< ")\n";
More information about the Mlir-commits mailing list