[Mlir-commits] [mlir] ec10f06 - [mlir][Pattern] Create a new IRRewriter class to enable sharing code with pattern rewrites

River Riddle llvmlistbot at llvm.org
Tue Feb 2 12:07:50 PST 2021


Author: River Riddle
Date: 2021-02-02T12:04:51-08:00
New Revision: ec10f0660963c77413f31a9b232b453f09425387

URL: https://github.com/llvm/llvm-project/commit/ec10f0660963c77413f31a9b232b453f09425387
DIFF: https://github.com/llvm/llvm-project/commit/ec10f0660963c77413f31a9b232b453f09425387.diff

LOG: [mlir][Pattern] Create a new IRRewriter class to enable sharing code with pattern rewrites

This revision adds two new classes, RewriterBase and IRRewriter. RewriterBase is a new base class shared by IRRewriter and PatternRewriter. PatternRewriter will continue to be the class used to perform rewrites within a rewrite pattern. IRRewriter, on the other hand, is a new class that allows tracking IR rewrites from outside of a rewrite pattern. In this revision, all of the existing PatternRewriter API is moved to RewriterBase, but IRRewriter and PatternRewriter are kept as distinct classes in case their APIs need to diverge in the future.

Currently, a utility that transforms a piece of IR has to be duplicated if it is needed in both pattern and non-pattern code. This revision enables the creation of utilities that can be invoked from both rewrite patterns and normal transformation code:

```c++
void someSharedUtility(RewriterBase &rewriter, ...) {
  // Some interesting IR mutation here.
}

// Some RewritePattern
LogicalResult MyPattern::matchAndRewrite(Operation *op, PatternRewriter &rewriter) {
  ...
  someSharedUtility(rewriter, ...);
  ...
}

// Some Pass
void MyPass::runOnOperation() {
  ...
  IRRewriter rewriter(...);
  someSharedUtility(rewriter, ...);
}
```

Differential Revision: https://reviews.llvm.org/D94638

Added: 
    

Modified: 
    mlir/include/mlir/IR/PatternMatch.h
    mlir/lib/IR/PatternMatch.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 04904dad32ef..8e1a5b98c318 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -414,20 +414,15 @@ class PDLPatternModule {
 };
 
 //===----------------------------------------------------------------------===//
-// PatternRewriter
+// RewriterBase
 //===----------------------------------------------------------------------===//
 
-/// This class coordinates the application of a pattern to the current function,
-/// providing a way to create operations and keep track of what gets deleted.
-///
-/// These class serves two purposes:
-///  1) it is the interface that patterns interact with to make mutations to the
-///     IR they are being applied to.
-///  2) It is a base class that clients of the PatternMatcher use when they want
-///     to apply patterns and observe their effects (e.g. to keep worklists or
-///     other data structures up to date).
-///
-class PatternRewriter : public OpBuilder, public OpBuilder::Listener {
+/// This class coordinates the application of a rewrite on a set of IR,
+/// providing a way for clients to track mutations and create new operations.
+/// This class serves as a common API for IR mutation between pattern rewrites
+/// and non-pattern rewrites, and facilitates the development of shared
+/// IR transformation utilities.
+class RewriterBase : public OpBuilder, public OpBuilder::Listener {
 public:
   /// Move the blocks that belong to "region" before the given position in
   /// another region "parent". The two regions must be different. The caller
@@ -452,10 +447,10 @@ class PatternRewriter : public OpBuilder, public OpBuilder::Listener {
   /// `newValues` when the provided `functor` returns true for a specific use.
   /// The number of values in `newValues` is required to match the number of
   /// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
-  /// the uses of `op` were replaced. Note that in some pattern rewriters, the
-  /// given 'functor' may be stored beyond the lifetime of the pattern being
-  /// applied. As such, the function should not capture by reference and instead
-  /// use value capture as necessary.
+  /// the uses of `op` were replaced. Note that in some rewriters, the given
+  /// 'functor' may be stored beyond the lifetime of the rewrite being applied.
+  /// As such, the function should not capture by reference and instead use
+  /// value capture as necessary.
   virtual void
   replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
                   llvm::unique_function<bool(OpOperand &) const> functor);
@@ -472,9 +467,9 @@ class PatternRewriter : public OpBuilder, public OpBuilder::Listener {
   void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
                             bool *allUsesReplaced = nullptr);
 
-  /// This method performs the final replacement for a pattern, where the
-  /// results of the operation are updated to use the specified list of SSA
-  /// values.
+  /// This method replaces the results of the operation with the specified list
+  /// of values. The number of provided values must match the number of results
+  /// of the operation.
   virtual void replaceOp(Operation *op, ValueRange newValues);
 
   /// Replaces the result op with a new op that is created without verification.
@@ -534,10 +529,10 @@ class PatternRewriter : public OpBuilder, public OpBuilder::Listener {
     finalizeRootUpdate(root);
   }
 
-  /// Notify the pattern rewriter that the pattern is failing to match the given
-  /// operation, and provide a callback to populate a diagnostic with the reason
-  /// why the failure occurred. This method allows for derived rewriters to
-  /// optionally hook into the reason why a pattern failed, and display it to
+  /// Used to notify the rewriter that the IR failed to be rewritten because of
+  /// a match failure, and provide a callback to populate a diagnostic with the
+  /// reason why the failure occurred. This method allows for derived rewriters
+  /// to optionally hook into the reason why a rewrite failed, and display it to
   /// users.
   template <typename CallbackT>
   std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
@@ -558,28 +553,29 @@ class PatternRewriter : public OpBuilder, public OpBuilder::Listener {
 
 protected:
   /// Initialize the builder with this rewriter as the listener.
-  explicit PatternRewriter(MLIRContext *ctx)
-      : OpBuilder(ctx, /*listener=*/this) {}
-  ~PatternRewriter() override;
+  explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx, /*listener=*/this) {}
+  explicit RewriterBase(const OpBuilder &otherBuilder)
+      : OpBuilder(otherBuilder) {
+    setListener(this);
+  }
+  ~RewriterBase() override;
 
   /// These are the callback methods that subclasses can choose to implement if
   /// they would like to be notified about certain types of mutations.
 
-  /// Notify the pattern rewriter that the specified operation is about to be
-  /// replaced with another set of operations.  This is called before the uses
-  /// of the operation have been changed.
+  /// Notify the rewriter that the specified operation is about to be replaced
+  /// with another set of operations. This is called before the uses of the
+  /// operation have been changed.
   virtual void notifyRootReplaced(Operation *op) {}
 
-  /// This is called on an operation that a pattern match is removing, right
-  /// before the operation is deleted.  At this point, the operation has zero
-  /// uses.
+  /// This is called on an operation that a rewrite is removing, right before
+  /// the operation is deleted. At this point, the operation has zero uses.
   virtual void notifyOperationRemoved(Operation *op) {}
 
-  /// Notify the pattern rewriter that the pattern is failing to match the given
-  /// operation, and provide a callback to populate a diagnostic with the reason
-  /// why the failure occurred. This method allows for derived rewriters to
-  /// optionally hook into the reason why a pattern failed, and display it to
-  /// users.
+  /// Notify the rewriter that the pattern failed to match the given operation,
+  /// and provide a callback to populate a diagnostic with the reason why the
+  /// failure occurred. This method allows for derived rewriters to optionally
+  /// hook into the reason why a rewrite failed, and display it to users.
   virtual LogicalResult
   notifyMatchFailure(Operation *op,
                      function_ref<void(Diagnostic &)> reasonCallback) {
@@ -592,6 +588,35 @@ class PatternRewriter : public OpBuilder, public OpBuilder::Listener {
   void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp);
 };
 
+//===----------------------------------------------------------------------===//
+// IRRewriter
+//===----------------------------------------------------------------------===//
+
+/// This class coordinates rewriting a piece of IR outside of a pattern rewrite,
+/// providing a way to keep track of the mutations made to the IR. This class
+/// should only be used in situations where another `RewriterBase` instance,
+/// such as a `PatternRewriter`, is not available.
+class IRRewriter : public RewriterBase {
+public:
+  explicit IRRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
+  explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {}
+};
+
+//===----------------------------------------------------------------------===//
+// PatternRewriter
+//===----------------------------------------------------------------------===//
+
+/// A special type of `RewriterBase` that coordinates the application of a
+/// rewrite pattern on the current IR being matched, providing a way to keep
+/// track of any mutations made. This class should be used to perform all
+/// necessary IR mutations within a rewrite pattern, as the pattern driver may
+/// be tracking various state that would be invalidated when a mutation takes
+/// place.
+class PatternRewriter : public RewriterBase {
+public:
+  using RewriterBase::RewriterBase;
+};
+
 //===----------------------------------------------------------------------===//
 // OwningRewritePatternList
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 44f22ceeb3cf..90e89a536405 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -148,10 +148,10 @@ void PDLPatternModule::registerRewriteFunction(StringRef name,
 }
 
 //===----------------------------------------------------------------------===//
-// PatternRewriter
+// RewriterBase
 //===----------------------------------------------------------------------===//
 
-PatternRewriter::~PatternRewriter() {
+RewriterBase::~RewriterBase() {
   // Out of line to provide a vtable anchor for the class.
 }
 
@@ -159,7 +159,7 @@ PatternRewriter::~PatternRewriter() {
 /// `newValues` when the provided `functor` returns true for a specific use.
 /// The number of values in `newValues` is required to match the number of
 /// results of `op`.
-void PatternRewriter::replaceOpWithIf(
+void RewriterBase::replaceOpWithIf(
     Operation *op, ValueRange newValues, bool *allUsesReplaced,
     llvm::unique_function<bool(OpOperand &) const> functor) {
   assert(op->getNumResults() == newValues.size() &&
@@ -182,18 +182,17 @@ void PatternRewriter::replaceOpWithIf(
 /// `newValues` when a use is nested within the given `block`. The number of
 /// values in `newValues` is required to match the number of results of `op`.
 /// If all uses of this operation are replaced, the operation is erased.
-void PatternRewriter::replaceOpWithinBlock(Operation *op, ValueRange newValues,
-                                           Block *block,
-                                           bool *allUsesReplaced) {
+void RewriterBase::replaceOpWithinBlock(Operation *op, ValueRange newValues,
+                                        Block *block, bool *allUsesReplaced) {
   replaceOpWithIf(op, newValues, allUsesReplaced, [block](OpOperand &use) {
     return block->getParentOp()->isProperAncestor(use.getOwner());
   });
 }
 
-/// This method performs the final replacement for a pattern, where the
-/// results of the operation are updated to use the specified list of SSA
-/// values.
-void PatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
+/// This method replaces the results of the operation with the specified list of
+/// values. The number of provided values must match the number of results of
+/// the operation.
+void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
   // Notify the rewriter subclass that we're about to replace this root.
   notifyRootReplaced(op);
 
@@ -207,13 +206,13 @@ void PatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
 
 /// This method erases an operation that is known to have no uses. The uses of
 /// the given operation *must* be known to be dead.
-void PatternRewriter::eraseOp(Operation *op) {
+void RewriterBase::eraseOp(Operation *op) {
   assert(op->use_empty() && "expected 'op' to have no uses");
   notifyOperationRemoved(op);
   op->erase();
 }
 
-void PatternRewriter::eraseBlock(Block *block) {
+void RewriterBase::eraseBlock(Block *block) {
   for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
     assert(op.use_empty() && "expected 'op' to have no uses");
     eraseOp(&op);
@@ -225,8 +224,8 @@ void PatternRewriter::eraseBlock(Block *block) {
 /// 'source's predecessors must be empty or only contain 'dest`.
 /// 'argValues' is used to replace the block arguments of 'source' after
 /// merging.
-void PatternRewriter::mergeBlocks(Block *source, Block *dest,
-                                  ValueRange argValues) {
+void RewriterBase::mergeBlocks(Block *source, Block *dest,
+                               ValueRange argValues) {
   assert(llvm::all_of(source->getPredecessors(),
                       [dest](Block *succ) { return succ == dest; }) &&
          "expected 'source' to have no predecessors or only 'dest'");
@@ -246,8 +245,8 @@ void PatternRewriter::mergeBlocks(Block *source, Block *dest,
 
 // Merge the operations of block 'source' before the operation 'op'. Source
 // block should not have existing predecessors or successors.
-void PatternRewriter::mergeBlockBefore(Block *source, Operation *op,
-                                       ValueRange argValues) {
+void RewriterBase::mergeBlockBefore(Block *source, Operation *op,
+                                    ValueRange argValues) {
   assert(source->hasNoPredecessors() &&
          "expected 'source' to have no predecessors");
   assert(source->hasNoSuccessors() &&
@@ -268,14 +267,14 @@ void PatternRewriter::mergeBlockBefore(Block *source, Operation *op,
 
 /// Split the operations starting at "before" (inclusive) out of the given
 /// block into a new block, and return it.
-Block *PatternRewriter::splitBlock(Block *block, Block::iterator before) {
+Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
   return block->splitBlock(before);
 }
 
 /// 'op' and 'newOp' are known to have the same number of results, replace the
 /// uses of op with uses of newOp
-void PatternRewriter::replaceOpWithResultsOfAnotherOp(Operation *op,
-                                                      Operation *newOp) {
+void RewriterBase::replaceOpWithResultsOfAnotherOp(Operation *op,
+                                                   Operation *newOp) {
   assert(op->getNumResults() == newOp->getNumResults() &&
          "replacement op doesn't match results of original op");
   if (op->getNumResults() == 1)
@@ -287,11 +286,11 @@ void PatternRewriter::replaceOpWithResultsOfAnotherOp(Operation *op,
 /// another region.  The two regions must be different.  The caller is in
 /// charge to update create the operation transferring the control flow to the
 /// region and pass it the correct block arguments.
-void PatternRewriter::inlineRegionBefore(Region &region, Region &parent,
-                                         Region::iterator before) {
+void RewriterBase::inlineRegionBefore(Region &region, Region &parent,
+                                      Region::iterator before) {
   parent.getBlocks().splice(before, region.getBlocks());
 }
-void PatternRewriter::inlineRegionBefore(Region &region, Block *before) {
+void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
   inlineRegionBefore(region, *before->getParent(), before->getIterator());
 }
 
@@ -299,17 +298,16 @@ void PatternRewriter::inlineRegionBefore(Region &region, Block *before) {
 /// another region "parent". The two regions must be different. The caller is
 /// responsible for creating or updating the operation transferring flow of
 /// control to the region and passing it the correct block arguments.
-void PatternRewriter::cloneRegionBefore(Region &region, Region &parent,
-                                        Region::iterator before,
-                                        BlockAndValueMapping &mapping) {
+void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
+                                     Region::iterator before,
+                                     BlockAndValueMapping &mapping) {
   region.cloneInto(&parent, before, mapping);
 }
-void PatternRewriter::cloneRegionBefore(Region &region, Region &parent,
-                                        Region::iterator before) {
+void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
+                                     Region::iterator before) {
   BlockAndValueMapping mapping;
   cloneRegionBefore(region, parent, before, mapping);
 }
-void PatternRewriter::cloneRegionBefore(Region &region, Block *before) {
+void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
   cloneRegionBefore(region, *before->getParent(), before->getIterator());
 }
-


        

