[Mlir-commits] [mlir] c653283 - This change makes `RewriterBase` symmetric to `OpBuilder`.

Matthias Springer llvmlistbot at llvm.org
Wed Feb 22 00:18:55 PST 2023


Author: Matthias Springer
Date: 2023-02-22T09:18:27+01:00
New Revision: c65328305e98a806ae0eb811c7f17e3c5b0c0158

URL: https://github.com/llvm/llvm-project/commit/c65328305e98a806ae0eb811c7f17e3c5b0c0158
DIFF: https://github.com/llvm/llvm-project/commit/c65328305e98a806ae0eb811c7f17e3c5b0c0158.diff

LOG: This change makes `RewriterBase` symmetric to `OpBuilder`.

```
  OpBuilder           OpBuilder::Listener
      ^                        ^
      |                        |
RewriterBase        RewriterBase::Listener
```

* Clients can listen to IR modifications with `RewriterBase::Listener`.
* `RewriterBase` no longer inherits from `OpBuilder::Listener`.
* Only a single listener can be registered at the moment (same as `OpBuilder`).

RFC: https://discourse.llvm.org/t/rfc-listeners-for-rewriterbase/68198

Differential Revision: https://reviews.llvm.org/D143339

Added: 
    

Modified: 
    mlir/include/mlir/IR/Builders.h
    mlir/include/mlir/IR/PatternMatch.h
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
    mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
    mlir/lib/IR/Builders.cpp
    mlir/lib/IR/PatternMatch.cpp
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 14df7b09032a1..f970d89dd410f 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -253,10 +253,32 @@ class OpBuilder : public Builder {
   // Listeners
   //===--------------------------------------------------------------------===//
 
+  /// Base class for listeners.
+  struct ListenerBase {
+    /// The kind of listener.
+    enum class Kind {
+      /// OpBuilder::Listener or user-derived class.
+      OpBuilderListener = 0,
+
+      /// RewriterBase::Listener or user-derived class.
+      RewriterBaseListener = 1
+    };
+
+    Kind getKind() const { return kind; }
+
+  protected:
+    ListenerBase(Kind kind) : kind(kind) {}
+
+  private:
+    const Kind kind;
+  };
+
   /// This class represents a listener that may be used to hook into various
   /// actions within an OpBuilder.
-  struct Listener {
-    virtual ~Listener();
+  struct Listener : public ListenerBase {
+    Listener() : ListenerBase(ListenerBase::Kind::OpBuilderListener) {}
+
+    virtual ~Listener() = default;
 
     /// Notification handler for when an operation is inserted into the builder.
     /// `op` is the operation that was inserted.
@@ -265,6 +287,9 @@ class OpBuilder : public Builder {
     /// Notification handler for when a block is created using the builder.
     /// `block` is the block that was created.
     virtual void notifyBlockCreated(Block *block) {}
+
+  protected:
+    Listener(Kind kind) : ListenerBase(kind) {}
   };
 
   /// Sets the listener of this builder to the one provided.
@@ -537,14 +562,16 @@ class OpBuilder : public Builder {
     return cast<OpT>(cloneWithoutRegions(*op.getOperation()));
   }
 
+protected:
+  /// The optional listener for events of this builder.
+  Listener *listener;
+
 private:
   /// The current block this builder is inserting into.
   Block *block = nullptr;
   /// The insertion point within the block that this builder is inserting
   /// before.
   Block::iterator insertPoint;
-  /// The optional listener for events of this builder.
-  Listener *listener;
 };
 
 } // namespace mlir

diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 187ce060f7ebb..7845cb19f6125 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -396,8 +396,36 @@ class OpTraitRewritePattern : public RewritePattern {
 /// This class serves as a common API for IR mutation between pattern rewrites
 /// and non-pattern rewrites, and facilitates the development of shared
 /// IR transformation utilities.
-class RewriterBase : public OpBuilder, public OpBuilder::Listener {
+class RewriterBase : public OpBuilder {
 public:
+  struct Listener : public OpBuilder::Listener {
+    Listener()
+        : OpBuilder::Listener(ListenerBase::Kind::RewriterBaseListener) {}
+
+    /// Notify the listener that the specified operation is about to be replaced
+    /// with the set of values potentially produced by new operations. This is
+    /// called before the uses of the operation have been changed.
+    virtual void notifyOperationReplaced(Operation *op,
+                                         ValueRange replacement) {}
+
+    /// This is called on an operation that a rewrite is removing, right before
+    /// the operation is deleted. At this point, the operation has zero uses.
+    virtual void notifyOperationRemoved(Operation *op) {}
+
+    /// Notify the listener that the pattern failed to match the given
+    /// operation, and provide a callback to populate a diagnostic with the
+    /// reason why the failure occurred. This method allows for derived
+    /// listeners to optionally hook into the reason why a rewrite failed, and
+    /// display it to users.
+    virtual LogicalResult
+    notifyMatchFailure(Location loc,
+                       function_ref<void(Diagnostic &)> reasonCallback) {
+      return failure();
+    }
+
+    static bool classof(const OpBuilder::Listener *base);
+  };
+
   /// Move the blocks that belong to "region" before the given position in
   /// another region "parent". The two regions must be different. The caller
   /// is responsible for creating or updating the operation transferring flow
@@ -541,8 +569,10 @@ class RewriterBase : public OpBuilder, public OpBuilder::Listener {
   std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
   notifyMatchFailure(Location loc, CallbackT &&reasonCallback) {
 #ifndef NDEBUG
-    return notifyMatchFailure(loc,
-                              function_ref<void(Diagnostic &)>(reasonCallback));
+    if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+      return rewriteListener->notifyMatchFailure(
+          loc, function_ref<void(Diagnostic &)>(reasonCallback));
+    return failure();
 #else
     return failure();
 #endif
@@ -550,8 +580,10 @@ class RewriterBase : public OpBuilder, public OpBuilder::Listener {
   template <typename CallbackT>
   std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
   notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
-    return notifyMatchFailure(op->getLoc(),
-                              function_ref<void(Diagnostic &)>(reasonCallback));
+    if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+      return rewriteListener->notifyMatchFailure(
+          op->getLoc(), function_ref<void(Diagnostic &)>(reasonCallback));
+    return failure();
   }
   template <typename ArgT>
   LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg) {
@@ -564,35 +596,11 @@ class RewriterBase : public OpBuilder, public OpBuilder::Listener {
   }
 
 protected:
-  /// Initialize the builder with this rewriter as the listener.
-  explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx, /*listener=*/this) {}
+  /// Initialize the builder.
+  explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx) {}
   explicit RewriterBase(const OpBuilder &otherBuilder)
-      : OpBuilder(otherBuilder) {
-    setListener(this);
-  }
-  ~RewriterBase() override;
-
-  /// These are the callback methods that subclasses can choose to implement if
-  /// they would like to be notified about certain types of mutations.
-
-  /// Notify the rewriter that the specified operation is about to be replaced
-  /// with the set of values potentially produced by new operations. This is
-  /// called before the uses of the operation have been changed.
-  virtual void notifyRootReplaced(Operation *op, ValueRange replacement) {}
-
-  /// This is called on an operation that a rewrite is removing, right before
-  /// the operation is deleted. At this point, the operation has zero uses.
-  virtual void notifyOperationRemoved(Operation *op) {}
-
-  /// Notify the rewriter that the pattern failed to match the given operation,
-  /// and provide a callback to populate a diagnostic with the reason why the
-  /// failure occurred. This method allows for derived rewriters to optionally
-  /// hook into the reason why a rewrite failed, and display it to users.
-  virtual LogicalResult
-  notifyMatchFailure(Location loc,
-                     function_ref<void(Diagnostic &)> reasonCallback) {
-    return failure();
-  }
+      : OpBuilder(otherBuilder) {}
+  virtual ~RewriterBase();
 
 private:
   void operator=(const RewriterBase &) = delete;

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index c592f2db999a2..229dc016957c6 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -618,7 +618,8 @@ struct ConversionPatternRewriterImpl;
 /// This class implements a pattern rewriter for use with ConversionPatterns. It
 /// extends the base PatternRewriter and provides special conversion specific
 /// hooks.
-class ConversionPatternRewriter final : public PatternRewriter {
+class ConversionPatternRewriter final : public PatternRewriter,
+                                        public RewriterBase::Listener {
 public:
   explicit ConversionPatternRewriter(MLIRContext *ctx);
   ~ConversionPatternRewriter() override;
@@ -742,6 +743,9 @@ class ConversionPatternRewriter final : public PatternRewriter {
   detail::ConversionPatternRewriterImpl &getImpl();
 
 private:
+  using OpBuilder::getListener;
+  using OpBuilder::setListener;
+
   std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
 };
 

diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index d07204d475b32..25cb61857a10e 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -555,7 +555,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
     // Inside regular functions we use the blocking wait operation to wait for
     // the async object (token, value or group) to become available.
     if (!isInCoroutine) {
-      ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
+      ImplicitLocOpBuilder builder(loc, op, &rewriter);
       builder.create<RuntimeAwaitOp>(loc, operand);
 
       // Assert that the awaited operands is not in the error state.
@@ -574,7 +574,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
       CoroMachinery &coro = funcCoro->getSecond();
       Block *suspended = op->getBlock();
 
-      ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
+      ImplicitLocOpBuilder builder(loc, op, &rewriter);
       MLIRContext *ctx = op->getContext();
 
       // Save the coroutine state and resume on a runtime managed thread when

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 3ec037069c2c2..0b10bafb9f163 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -342,7 +342,7 @@ static bool hasTensorSemantics(Operation *op) {
 
 namespace {
 /// A rewriter that keeps track of extra information during bufferization.
-class BufferizationRewriter : public IRRewriter {
+class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
 public:
   BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
                         DenseSet<Operation *> &toMemrefOps,
@@ -352,18 +352,18 @@ class BufferizationRewriter : public IRRewriter {
                         BufferizationStatistics *statistics)
       : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
         worklist(worklist), analysisState(options), opFilter(opFilter),
-        statistics(statistics) {}
+        statistics(statistics) {
+    setListener(this);
+  }
 
 protected:
   void notifyOperationRemoved(Operation *op) override {
-    IRRewriter::notifyOperationRemoved(op);
     erasedOps.insert(op);
     // Erase if present.
     toMemrefOps.erase(op);
   }
 
   void notifyOperationInserted(Operation *op) override {
-    IRRewriter::notifyOperationInserted(op);
     erasedOps.erase(op);
 
     // Gather statistics about allocs and deallocs.

diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index d36791fef23d1..8eab32b201a04 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -388,8 +388,6 @@ AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
 // OpBuilder
 //===----------------------------------------------------------------------===//
 
-OpBuilder::Listener::~Listener() = default;
-
 /// Insert the given operation at the current insertion point and return it.
 Operation *OpBuilder::insert(Operation *op) {
   if (block)

diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 1ca86cdcba1cc..10baea61d9a4f 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -217,6 +217,10 @@ void PDLPatternModule::registerRewriteFunction(StringRef name,
 // RewriterBase
 //===----------------------------------------------------------------------===//
 
+bool RewriterBase::Listener::classof(const OpBuilder::Listener *base) {
+  return base->getKind() == OpBuilder::ListenerBase::Kind::RewriterBaseListener;
+}
+
 RewriterBase::~RewriterBase() {
   // Out of line to provide a vtable anchor for the class.
 }
@@ -232,7 +236,8 @@ void RewriterBase::replaceOpWithIf(
          "incorrect number of values to replace operation");
 
   // Notify the rewriter subclass that we're about to replace this root.
-  notifyRootReplaced(op, newValues);
+  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+    rewriteListener->notifyOperationReplaced(op, newValues);
 
   // Replace each use of the results when the functor is true.
   bool replacedAllUses = true;
@@ -260,13 +265,15 @@ void RewriterBase::replaceOpWithinBlock(Operation *op, ValueRange newValues,
 /// the operation.
 void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
   // Notify the rewriter subclass that we're about to replace this root.
-  notifyRootReplaced(op, newValues);
+  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+    rewriteListener->notifyOperationReplaced(op, newValues);
 
   assert(op->getNumResults() == newValues.size() &&
          "incorrect # of replacement values");
   op->replaceAllUsesWith(newValues);
 
-  notifyOperationRemoved(op);
+  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+    rewriteListener->notifyOperationRemoved(op);
   op->erase();
 }
 
@@ -274,7 +281,8 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
 /// the given operation *must* be known to be dead.
 void RewriterBase::eraseOp(Operation *op) {
   assert(op->use_empty() && "expected 'op' to have no uses");
-  notifyOperationRemoved(op);
+  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+    rewriteListener->notifyOperationRemoved(op);
   op->erase();
 }
 

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index b82fc580d2ebd..0d78362da7f09 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1495,7 +1495,10 @@ LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
 
 ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
     : PatternRewriter(ctx),
-      impl(new detail::ConversionPatternRewriterImpl(*this)) {}
+      impl(new detail::ConversionPatternRewriterImpl(*this)) {
+  setListener(this);
+}
+
 ConversionPatternRewriter::~ConversionPatternRewriter() = default;
 
 void ConversionPatternRewriter::replaceOpWithIf(

diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 997bdc6a1c49f..adf8b5121ab9e 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -39,7 +39,8 @@ namespace {
 /// This abstract class manages the worklist and contains helper methods for
 /// rewriting ops on the worklist. Derived classes specify how ops are added
 /// to the worklist in the beginning.
-class GreedyPatternRewriteDriver : public PatternRewriter {
+class GreedyPatternRewriteDriver : public PatternRewriter,
+                                   public RewriterBase::Listener {
 protected:
   explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
                                       const FrozenRewritePatternSet &patterns,
@@ -67,7 +68,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
 
   /// Notify the driver that the specified operation was replaced. Update the
   /// worklist as needed: New users are added enqueued.
-  void notifyRootReplaced(Operation *op, ValueRange replacement) override;
+  void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
 
   /// Process ops until the worklist is empty or `config.maxNumRewrites` is
   /// reached. Return `true` if any IR was changed.
@@ -128,6 +129,9 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
 
   // Apply a simple cost model based solely on pattern benefit.
   matcher.applyDefaultCostModel();
+
+  // Set up listener.
+  setListener(this);
 }
 
 bool GreedyPatternRewriteDriver::processWorklist() {
@@ -359,8 +363,8 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
     strictModeFilteredOps.erase(op);
 }
 
-void GreedyPatternRewriteDriver::notifyRootReplaced(Operation *op,
-                                                    ValueRange replacement) {
+void GreedyPatternRewriteDriver::notifyOperationReplaced(
+    Operation *op, ValueRange replacement) {
   LLVM_DEBUG({
     logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
                        << ")\n";


        


More information about the Mlir-commits mailing list