[Mlir-commits] [mlir] 359164f - [mlir][OpBuilder] Remove the vtable from OpBuilder in favor of using the listener pattern

River Riddle llvmlistbot at llvm.org
Thu Apr 30 21:32:37 PDT 2020


Author: River Riddle
Date: 2020-04-30T21:29:25-07:00
New Revision: 359164f810282035b55e7b8bb7bbecb2ed0175d0

URL: https://github.com/llvm/llvm-project/commit/359164f810282035b55e7b8bb7bbecb2ed0175d0
DIFF: https://github.com/llvm/llvm-project/commit/359164f810282035b55e7b8bb7bbecb2ed0175d0.diff

LOG: [mlir][OpBuilder] Remove the vtable from OpBuilder in favor of using the listener pattern

The current OpBuilder has a set of virtual functions required by the fact that the PatternRewriter inherits from it for convenience. The PatternRewriter is required to know about IR mutations for correctness. This revision changes the relationship to be explicit by having users register a listener with the builder instead of using inheritance/vtables. This still requires that users properly transfer the listener when creating new builders, but has several benefits:

* More than one builder can be created during pattern rewrites (assuming that the listener is properly forwarded).
* OpBuilder no longer requires a vtable, and thus does not incur the cost of virtual dispatch when no listener is present.

Differential Revision: https://reviews.llvm.org/D79206

Added: 
    

Modified: 
    mlir/include/mlir/IR/Builders.h
    mlir/include/mlir/IR/PatternMatch.h
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/IR/Builders.cpp
    mlir/lib/Transforms/DialectConversion.cpp
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 23c5ea2edb2a..a11da75f3ebf 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -171,49 +171,84 @@ class Builder {
 /// automatically inserted at an insertion point. The builder is copyable.
 class OpBuilder : public Builder {
 public:
+  struct Listener;
+
   /// Create a builder with the given context.
-  explicit OpBuilder(MLIRContext *ctx) : Builder(ctx) {}
+  explicit OpBuilder(MLIRContext *ctx, Listener *listener = nullptr)
+      : Builder(ctx), listener(listener) {}
 
   /// Create a builder and set the insertion point to the start of the region.
-  explicit OpBuilder(Region *region) : Builder(region->getContext()) {
+  explicit OpBuilder(Region *region, Listener *listener = nullptr)
+      : OpBuilder(region->getContext(), listener) {
     if (!region->empty())
       setInsertionPoint(&region->front(), region->front().begin());
   }
-  explicit OpBuilder(Region &region) : OpBuilder(&region) {}
-
-  virtual ~OpBuilder();
+  explicit OpBuilder(Region &region, Listener *listener = nullptr)
+      : OpBuilder(&region, listener) {}
 
   /// Create a builder and set insertion point to the given operation, which
   /// will cause subsequent insertions to go right before it.
-  explicit OpBuilder(Operation *op) : Builder(op->getContext()) {
+  explicit OpBuilder(Operation *op, Listener *listener = nullptr)
+      : OpBuilder(op->getContext(), listener) {
     setInsertionPoint(op);
   }
 
-  OpBuilder(Block *block, Block::iterator insertPoint)
-      : OpBuilder(block->getParent()) {
+  OpBuilder(Block *block, Block::iterator insertPoint,
+            Listener *listener = nullptr)
+      : OpBuilder(block->getParent()->getContext(), listener) {
     setInsertionPoint(block, insertPoint);
   }
 
   /// Create a builder and set the insertion point to before the first operation
   /// in the block but still inside the block.
-  static OpBuilder atBlockBegin(Block *block) {
-    return OpBuilder(block, block->begin());
+  static OpBuilder atBlockBegin(Block *block, Listener *listener = nullptr) {
+    return OpBuilder(block, block->begin(), listener);
   }
 
   /// Create a builder and set the insertion point to after the last operation
   /// in the block but still inside the block.
-  static OpBuilder atBlockEnd(Block *block) {
-    return OpBuilder(block, block->end());
+  static OpBuilder atBlockEnd(Block *block, Listener *listener = nullptr) {
+    return OpBuilder(block, block->end(), listener);
   }
 
   /// Create a builder and set the insertion point to before the block
   /// terminator.
-  static OpBuilder atBlockTerminator(Block *block) {
+  static OpBuilder atBlockTerminator(Block *block,
+                                     Listener *listener = nullptr) {
     auto *terminator = block->getTerminator();
     assert(terminator != nullptr && "the block has no terminator");
-    return OpBuilder(block, terminator->getIterator());
+    return OpBuilder(block, Block::iterator(terminator), listener);
   }
 
+  //===--------------------------------------------------------------------===//
+  // Listeners
+  //===--------------------------------------------------------------------===//
+
+  /// This class represents a listener that may be used to hook into various
+  /// actions within an OpBuilder.
+  struct Listener {
+    virtual ~Listener();
+
+    /// Notification handler for when an operation is inserted into the builder.
+    /// `op` is the operation that was inserted.
+    virtual void notifyOperationInserted(Operation *op) {}
+
+    /// Notification handler for when a block is created using the builder.
+    /// `block` is the block that was created.
+    virtual void notifyBlockCreated(Block *block) {}
+  };
+
+  /// Sets the listener of this builder to the one provided.
+  void setListener(Listener *newListener) { listener = newListener; }
+
+  /// Returns the current listener of this builder, or nullptr if this builder
+  /// doesn't have a listener.
+  Listener *getListener() const { return listener; }
+
+  //===--------------------------------------------------------------------===//
+  // Insertion Point Management
+  //===--------------------------------------------------------------------===//
+
   /// This class represents a saved insertion point.
   class InsertPoint {
   public:
@@ -304,21 +339,29 @@ class OpBuilder : public Builder {
   /// Returns the current insertion point of the builder.
   Block::iterator getInsertionPoint() const { return insertPoint; }
 
-  /// Insert the given operation at the current insertion point and return it.
-  virtual Operation *insert(Operation *op);
+  /// Returns the current block of the builder.
+  Block *getBlock() const { return block; }
+
+  //===--------------------------------------------------------------------===//
+  // Block Creation
+  //===--------------------------------------------------------------------===//
 
   /// Add new block with 'argTypes' arguments and set the insertion point to the
   /// end of it. The block is inserted at the provided insertion point of
   /// 'parent'.
-  virtual Block *createBlock(Region *parent, Region::iterator insertPt = {},
-                             TypeRange argTypes = llvm::None);
+  Block *createBlock(Region *parent, Region::iterator insertPt = {},
+                     TypeRange argTypes = llvm::None);
 
   /// Add new block with 'argTypes' arguments and set the insertion point to the
   /// end of it. The block is placed before 'insertBefore'.
   Block *createBlock(Block *insertBefore, TypeRange argTypes = llvm::None);
 
-  /// Returns the current block of the builder.
-  Block *getBlock() const { return block; }
+  //===--------------------------------------------------------------------===//
+  // Operation Creation
+  //===--------------------------------------------------------------------===//
+
+  /// Insert the given operation at the current insertion point and return it.
+  Operation *insert(Operation *op);
 
   /// Creates an operation given the fields represented as an OperationState.
   Operation *createOperation(const OperationState &state);
@@ -406,8 +449,13 @@ class OpBuilder : public Builder {
   }
 
 private:
+  /// The current block this builder is inserting into.
   Block *block = nullptr;
+  /// The insertion point within the block that this builder is inserting
+  /// before.
   Block::iterator insertPoint;
+  /// The optional listener for events of this builder.
+  Listener *listener;
 };
 
 } // namespace mlir

diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 9476bcf3bf84..0d125b3d8148 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -211,7 +211,7 @@ template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
 ///     to apply patterns and observe their effects (e.g. to keep worklists or
 ///     other data structures up to date).
 ///
-class PatternRewriter : public OpBuilder {
+class PatternRewriter : public OpBuilder, public OpBuilder::Listener {
 public:
   /// Create operation of specific op type at the current insertion point
   /// without verifying to see if it is valid.
@@ -247,10 +247,6 @@ class PatternRewriter : public OpBuilder {
     return OpTy();
   }
 
-  /// This is implemented to insert the specified operation and serves as a
-  /// notification hook for rewriters that want to know about new operations.
-  virtual Operation *insert(Operation *op) = 0;
-
   /// Move the blocks that belong to "region" before the given position in
   /// another region "parent". The two regions must be different. The caller
   /// is responsible for creating or updating the operation transferring flow
@@ -349,11 +345,13 @@ class PatternRewriter : public OpBuilder {
   }
 
 protected:
-  explicit PatternRewriter(MLIRContext *ctx) : OpBuilder(ctx) {}
-  virtual ~PatternRewriter();
+  /// Initialize the builder with this rewriter as the listener.
+  explicit PatternRewriter(MLIRContext *ctx)
+      : OpBuilder(ctx, /*listener=*/this) {}
+  ~PatternRewriter() override;
 
-  // These are the callback methods that subclasses can choose to implement if
-  // they would like to be notified about certain types of mutations.
+  /// These are the callback methods that subclasses can choose to implement if
+  /// they would like to be notified about certain types of mutations.
 
   /// Notify the pattern rewriter that the specified operation is about to be
   /// replaced with another set of operations.  This is called before the uses

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 4f1fafb191ac..31b5e04c9dbd 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -348,9 +348,8 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// implemented for dialect conversion.
   void eraseBlock(Block *block) override;
 
-  /// PatternRewriter hook for creating a new block with the given arguments.
-  Block *createBlock(Region *parent, Region::iterator insertPt = {},
-                     TypeRange argTypes = llvm::None) override;
+  /// PatternRewriter hook creating a new block.
+  void notifyBlockCreated(Block *block) override;
 
   /// PatternRewriter hook for splitting a block into two parts.
   Block *splitBlock(Block *block, Block::iterator before) override;
@@ -373,7 +372,7 @@ class ConversionPatternRewriter final : public PatternRewriter {
   using PatternRewriter::cloneRegionBefore;
 
   /// PatternRewriter hook for inserting a new operation.
-  Operation *insert(Operation *op) override;
+  void notifyOperationInserted(Operation *op) override;
 
   /// PatternRewriter hook for updating the root operation in-place.
   /// Note: These methods only track updates to the top-level operation itself,

diff  --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 22abeb5a364f..fcaf33aa98eb 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -330,15 +330,18 @@ AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
 }
 
 //===----------------------------------------------------------------------===//
-// OpBuilder.
+// OpBuilder
 //===----------------------------------------------------------------------===//
 
-OpBuilder::~OpBuilder() {}
+OpBuilder::Listener::~Listener() {}
 
 /// Insert the given operation at the current insertion point and return it.
 Operation *OpBuilder::insert(Operation *op) {
   if (block)
     block->getOperations().insert(insertPoint, op);
+
+  if (listener)
+    listener->notifyOperationInserted(op);
   return op;
 }
 
@@ -355,6 +358,9 @@ Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt,
   b->addArguments(argTypes);
   parent->getBlocks().insert(insertPt, b);
   setInsertionPointToEnd(b);
+
+  if (listener)
+    listener->notifyBlockCreated(b);
   return b;
 }
 

diff  --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index 14a7084c50bf..00afd1c8ad92 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -954,12 +954,8 @@ Value ConversionPatternRewriter::getRemappedValue(Value key) {
 }
 
 /// PatternRewriter hook for creating a new block with the given arguments.
-Block *ConversionPatternRewriter::createBlock(Region *parent,
-                                              Region::iterator insertPtr,
-                                              TypeRange argTypes) {
-  Block *block = PatternRewriter::createBlock(parent, insertPtr, argTypes);
+void ConversionPatternRewriter::notifyBlockCreated(Block *block) {
   impl->notifyCreatedBlock(block);
-  return block;
 }
 
 /// PatternRewriter hook for splitting a block into two parts.
@@ -1001,13 +997,12 @@ void ConversionPatternRewriter::cloneRegionBefore(
 }
 
 /// PatternRewriter hook for creating a new operation.
-Operation *ConversionPatternRewriter::insert(Operation *op) {
+void ConversionPatternRewriter::notifyOperationInserted(Operation *op) {
   LLVM_DEBUG({
     impl->logger.startLine()
         << "** Insert  : '" << op->getName() << "'(" << op << ")\n";
   });
   impl->createdOps.push_back(op);
-  return OpBuilder::insert(op);
 }
 
 /// PatternRewriter hook for updating the root operation in-place.

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 2ebf1d6a47d7..f4022a4e5bde 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -77,10 +77,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
 protected:
   // Implement the hook for inserting operations, and make sure that newly
   // inserted ops are added to the worklist for processing.
-  Operation *insert(Operation *op) override {
-    addToWorklist(op);
-    return OpBuilder::insert(op);
-  }
+  void notifyOperationInserted(Operation *op) override { addToWorklist(op); }
 
   // If an operation is about to be removed, make sure it is not in our
   // worklist anymore because we'd get dangling references to it.
@@ -266,9 +263,6 @@ class OpPatternRewriteDriver : public PatternRewriter {
 
   bool simplifyLocally(Operation *op, int maxIterations, bool &erased);
 
-  /// No additional action needed other than inserting the op.
-  Operation *insert(Operation *op) override { return OpBuilder::insert(op); }
-
   // These are hooks implemented for PatternRewriter.
 protected:
   /// If an operation is about to be removed, mark it so that we can let clients


        


More information about the Mlir-commits mailing list