[Mlir-commits] [mlir] ec10f06 - [mlir][Pattern] Create a new IRRewriter class to enable sharing code with pattern rewrites
River Riddle
llvmlistbot at llvm.org
Tue Feb 2 12:07:50 PST 2021
Author: River Riddle
Date: 2021-02-02T12:04:51-08:00
New Revision: ec10f0660963c77413f31a9b232b453f09425387
URL: https://github.com/llvm/llvm-project/commit/ec10f0660963c77413f31a9b232b453f09425387
DIFF: https://github.com/llvm/llvm-project/commit/ec10f0660963c77413f31a9b232b453f09425387.diff
LOG: [mlir][Pattern] Create a new IRRewriter class to enable sharing code with pattern rewrites
This revision adds two new classes, RewriterBase and IRRewriter. RewriterBase is a new shared base class between IRRewriter and PatternRewriter. PatternRewriter will continue to be the base class used to perform rewrites within a rewrite pattern. IRRewriter, on the other hand, is a new class that allows for tracking IR rewrites from outside of a rewrite pattern. In this revision, all of the old API from PatternRewriter is moved to RewriterBase, but the distinction between IRRewriter and PatternRewriter is kept in case a necessary API divergence arises in the future.
Currently, if you want to have a utility that transforms a piece of IR and share it between pattern and non-pattern code, you have to duplicate it. This revision enables the creation of utilities that can be invoked from both rewrite patterns and normal transformation code:
```c++
void someSharedUtility(RewriterBase &rewriter, ...) {
  // Some interesting IR mutation here.
}

// Some RewritePattern
LogicalResult MyPattern::matchAndRewrite(Operation *op,
                                         PatternRewriter &rewriter) const {
  ...
  someSharedUtility(rewriter, ...);
  ...
}

// Some Pass
void MyPass::runOnOperation() {
  ...
  IRRewriter rewriter(...);
  someSharedUtility(rewriter, ...);
}
```
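For readers without an MLIR build at hand, the following is a minimal, self-contained C++ model of the hierarchy described above. `ToyIR`, `renameAll`, `numReplacements`, and the index-based `replaceOp` are hypothetical simplifications, not MLIR API. The only point it illustrates is that a utility written against `RewriterBase &` accepts both an `IRRewriter` and a `PatternRewriter`.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy "IR": just a list of operation names. Not an MLIR type.
struct ToyIR {
  std::vector<std::string> ops;
};

// Shared base class: owns the whole mutation API, mirroring how this commit
// moves the PatternRewriter API into RewriterBase.
class RewriterBase {
public:
  virtual ~RewriterBase() = default;

  // Hypothetical simplified replaceOp: swap the op at `index` for `newName`.
  void replaceOp(ToyIR &ir, std::size_t index, std::string newName) {
    assert(index < ir.ops.size() && "index out of range");
    ir.ops[index] = std::move(newName);
    ++replacements;
  }

  int numReplacements() const { return replacements; }

private:
  int replacements = 0;
};

// For rewrites performed outside of a pattern, e.g. from a pass.
class IRRewriter : public RewriterBase {};

// For rewrites performed inside a rewrite pattern; same API for now.
class PatternRewriter : public RewriterBase {};

// A utility written once against the base class, callable with either rewriter.
void renameAll(RewriterBase &rewriter, ToyIR &ir, const std::string &from,
               const std::string &to) {
  for (std::size_t i = 0; i < ir.ops.size(); ++i)
    if (ir.ops[i] == from)
      rewriter.replaceOp(ir, i, to);
}
```

Calling `renameAll` with an `IRRewriter` from "pass" code and with a `PatternRewriter` from "pattern" code runs the same logic, which is the duplication this revision removes.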
Differential Revision: https://reviews.llvm.org/D94638
Added:
Modified:
mlir/include/mlir/IR/PatternMatch.h
mlir/lib/IR/PatternMatch.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 04904dad32ef..8e1a5b98c318 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -414,20 +414,15 @@ class PDLPatternModule {
};
//===----------------------------------------------------------------------===//
-// PatternRewriter
+// RewriterBase
//===----------------------------------------------------------------------===//
-/// This class coordinates the application of a pattern to the current function,
-/// providing a way to create operations and keep track of what gets deleted.
-///
-/// These class serves two purposes:
-/// 1) it is the interface that patterns interact with to make mutations to the
-/// IR they are being applied to.
-/// 2) It is a base class that clients of the PatternMatcher use when they want
-/// to apply patterns and observe their effects (e.g. to keep worklists or
-/// other data structures up to date).
-///
-class PatternRewriter : public OpBuilder, public OpBuilder::Listener {
+/// This class coordinates the application of a rewrite on a set of IR,
+/// providing a way for clients to track mutations and create new operations.
+/// This class serves as a common API for IR mutation between pattern rewrites
+/// and non-pattern rewrites, and facilitates the development of shared
+/// IR transformation utilities.
+class RewriterBase : public OpBuilder, public OpBuilder::Listener {
public:
/// Move the blocks that belong to "region" before the given position in
/// another region "parent". The two regions must be different. The caller
@@ -452,10 +447,10 @@ class PatternRewriter : public OpBuilder, public OpBuilder::Listener {
/// `newValues` when the provided `functor` returns true for a specific use.
/// The number of values in `newValues` is required to match the number of
/// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
- /// the uses of `op` were replaced. Note that in some pattern rewriters, the
- /// given 'functor' may be stored beyond the lifetime of the pattern being
- /// applied. As such, the function should not capture by reference and instead
- /// use value capture as necessary.
+ /// the uses of `op` were replaced. Note that in some rewriters, the given
+ /// 'functor' may be stored beyond the lifetime of the rewrite being applied.
+ /// As such, the function should not capture by reference and instead use
+ /// value capture as necessary.
virtual void
replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
llvm::unique_function<bool(OpOperand &) const> functor);
@@ -472,9 +467,9 @@ class PatternRewriter : public OpBuilder, public OpBuilder::Listener {
void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
bool *allUsesReplaced = nullptr);
- /// This method performs the final replacement for a pattern, where the
- /// results of the operation are updated to use the specified list of SSA
- /// values.
+ /// This method replaces the results of the operation with the specified list
+ /// of values. The number of provided values must match the number of results
+ /// of the operation.
virtual void replaceOp(Operation *op, ValueRange newValues);
/// Replaces the result op with a new op that is created without verification.
@@ -534,10 +529,10 @@ class PatternRewriter : public OpBuilder, public OpBuilder::Listener {
finalizeRootUpdate(root);
}
- /// Notify the pattern rewriter that the pattern is failing to match the given
- /// operation, and provide a callback to populate a diagnostic with the reason
- /// why the failure occurred. This method allows for derived rewriters to
- /// optionally hook into the reason why a pattern failed, and display it to
+ /// Used to notify the rewriter that the IR failed to be rewritten because of
+ /// a match failure, and provide a callback to populate a diagnostic with the
+ /// reason why the failure occurred. This method allows for derived rewriters
+ /// to optionally hook into the reason why a rewrite failed, and display it to
/// users.
template <typename CallbackT>
std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
@@ -558,28 +553,29 @@ class PatternRewriter : public OpBuilder, public OpBuilder::Listener {
protected:
/// Initialize the builder with this rewriter as the listener.
- explicit PatternRewriter(MLIRContext *ctx)
- : OpBuilder(ctx, /*listener=*/this) {}
- ~PatternRewriter() override;
+ explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx, /*listener=*/this) {}
+ explicit RewriterBase(const OpBuilder &otherBuilder)
+ : OpBuilder(otherBuilder) {
+ setListener(this);
+ }
+ ~RewriterBase() override;
/// These are the callback methods that subclasses can choose to implement if
/// they would like to be notified about certain types of mutations.
- /// Notify the pattern rewriter that the specified operation is about to be
- /// replaced with another set of operations. This is called before the uses
- /// of the operation have been changed.
+ /// Notify the rewriter that the specified operation is about to be replaced
+ /// with another set of operations. This is called before the uses of the
+ /// operation have been changed.
virtual void notifyRootReplaced(Operation *op) {}
- /// This is called on an operation that a pattern match is removing, right
- /// before the operation is deleted. At this point, the operation has zero
- /// uses.
+ /// This is called on an operation that a rewrite is removing, right before
+ /// the operation is deleted. At this point, the operation has zero uses.
virtual void notifyOperationRemoved(Operation *op) {}
- /// Notify the pattern rewriter that the pattern is failing to match the given
- /// operation, and provide a callback to populate a diagnostic with the reason
- /// why the failure occurred. This method allows for derived rewriters to
- /// optionally hook into the reason why a pattern failed, and display it to
- /// users.
+ /// Notify the rewriter that the pattern failed to match the given operation,
+ /// and provide a callback to populate a diagnostic with the reason why the
+ /// failure occurred. This method allows for derived rewriters to optionally
+ /// hook into the reason why a rewrite failed, and display it to users.
virtual LogicalResult
notifyMatchFailure(Operation *op,
function_ref<void(Diagnostic &)> reasonCallback) {
@@ -592,6 +588,35 @@ class PatternRewriter : public OpBuilder, public OpBuilder::Listener {
void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp);
};
+//===----------------------------------------------------------------------===//
+// IRRewriter
+//===----------------------------------------------------------------------===//
+
+/// This class coordinates rewriting a piece of IR outside of a pattern rewrite,
+/// providing a way to keep track of the mutations made to the IR. This class
+/// should only be used in situations where another `RewriterBase` instance,
+/// such as a `PatternRewriter`, is not available.
+class IRRewriter : public RewriterBase {
+public:
+ explicit IRRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
+ explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {}
+};
+
+//===----------------------------------------------------------------------===//
+// PatternRewriter
+//===----------------------------------------------------------------------===//
+
+/// A special type of `RewriterBase` that coordinates the application of a
+/// rewrite pattern on the current IR being matched, providing a way to keep
+/// track of any mutations made. This class should be used to perform all
+/// necessary IR mutations within a rewrite pattern, as the pattern driver may
+/// be tracking various state that would be invalidated when a mutation takes
+/// place.
+class PatternRewriter : public RewriterBase {
+public:
+ using RewriterBase::RewriterBase;
+};
+
//===----------------------------------------------------------------------===//
// OwningRewritePatternList
//===----------------------------------------------------------------------===//
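The notification hooks declared on `RewriterBase` are what let a derived rewriter keep its own bookkeeping in sync, e.g. the worklists mentioned in the old `PatternRewriter` documentation. The following is a minimal sketch, assuming toy `Op`/`Block` types and a hypothetical `WorklistRewriter` (none of which are MLIR classes), of overriding `notifyOperationRemoved` so that an erased operation is dropped from a worklist before it is destroyed:

```cpp
#include <algorithm>
#include <cassert>
#include <list>
#include <string>
#include <vector>

// Hypothetical toy operation and block; not the MLIR types.
struct Op {
  std::string name;
  int numUses = 0;
};
using Block = std::list<Op>;

class RewriterBase {
public:
  virtual ~RewriterBase() = default;

  // Mirrors the shape of RewriterBase::eraseOp: the op must be dead, the
  // subclass is notified, and only then is the op destroyed.
  void eraseOp(Block &block, Block::iterator it) {
    assert(it->numUses == 0 && "expected 'op' to have no uses");
    notifyOperationRemoved(&*it);
    block.erase(it);
  }

protected:
  // Default hook does nothing, as in the header above.
  virtual void notifyOperationRemoved(Op *) {}
};

// Hypothetical driver-style rewriter: keeps a worklist of pending ops and
// drops any op that gets erased, so it never visits a dangling pointer.
class WorklistRewriter : public RewriterBase {
public:
  std::vector<Op *> worklist;

protected:
  void notifyOperationRemoved(Op *op) override {
    worklist.erase(std::remove(worklist.begin(), worklist.end(), op),
                   worklist.end());
  }
};
```

Because the hook fires before `block.erase`, the subclass still sees a valid pointer while it updates its state.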
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 44f22ceeb3cf..90e89a536405 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -148,10 +148,10 @@ void PDLPatternModule::registerRewriteFunction(StringRef name,
}
//===----------------------------------------------------------------------===//
-// PatternRewriter
+// RewriterBase
//===----------------------------------------------------------------------===//
-PatternRewriter::~PatternRewriter() {
+RewriterBase::~RewriterBase() {
// Out of line to provide a vtable anchor for the class.
}
@@ -159,7 +159,7 @@ PatternRewriter::~PatternRewriter() {
/// `newValues` when the provided `functor` returns true for a specific use.
/// The number of values in `newValues` is required to match the number of
/// results of `op`.
-void PatternRewriter::replaceOpWithIf(
+void RewriterBase::replaceOpWithIf(
Operation *op, ValueRange newValues, bool *allUsesReplaced,
llvm::unique_function<bool(OpOperand &) const> functor) {
assert(op->getNumResults() == newValues.size() &&
@@ -182,18 +182,17 @@ void PatternRewriter::replaceOpWithIf(
/// `newValues` when a use is nested within the given `block`. The number of
/// values in `newValues` is required to match the number of results of `op`.
/// If all uses of this operation are replaced, the operation is erased.
-void PatternRewriter::replaceOpWithinBlock(Operation *op, ValueRange newValues,
- Block *block,
- bool *allUsesReplaced) {
+void RewriterBase::replaceOpWithinBlock(Operation *op, ValueRange newValues,
+ Block *block, bool *allUsesReplaced) {
replaceOpWithIf(op, newValues, allUsesReplaced, [block](OpOperand &use) {
return block->getParentOp()->isProperAncestor(use.getOwner());
});
}
-/// This method performs the final replacement for a pattern, where the
-/// results of the operation are updated to use the specified list of SSA
-/// values.
-void PatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
+/// This method replaces the results of the operation with the specified list of
+/// values. The number of provided values must match the number of results of
+/// the operation.
+void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
// Notify the rewriter subclass that we're about to replace this root.
notifyRootReplaced(op);
@@ -207,13 +206,13 @@ void PatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
/// This method erases an operation that is known to have no uses. The uses of
/// the given operation *must* be known to be dead.
-void PatternRewriter::eraseOp(Operation *op) {
+void RewriterBase::eraseOp(Operation *op) {
assert(op->use_empty() && "expected 'op' to have no uses");
notifyOperationRemoved(op);
op->erase();
}
-void PatternRewriter::eraseBlock(Block *block) {
+void RewriterBase::eraseBlock(Block *block) {
for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
assert(op.use_empty() && "expected 'op' to have no uses");
eraseOp(&op);
@@ -225,8 +224,8 @@ void PatternRewriter::eraseBlock(Block *block) {
/// 'source's predecessors must be empty or only contain 'dest`.
/// 'argValues' is used to replace the block arguments of 'source' after
/// merging.
-void PatternRewriter::mergeBlocks(Block *source, Block *dest,
- ValueRange argValues) {
+void RewriterBase::mergeBlocks(Block *source, Block *dest,
+ ValueRange argValues) {
assert(llvm::all_of(source->getPredecessors(),
[dest](Block *succ) { return succ == dest; }) &&
"expected 'source' to have no predecessors or only 'dest'");
@@ -246,8 +245,8 @@ void PatternRewriter::mergeBlocks(Block *source, Block *dest,
// Merge the operations of block 'source' before the operation 'op'. Source
// block should not have existing predecessors or successors.
-void PatternRewriter::mergeBlockBefore(Block *source, Operation *op,
- ValueRange argValues) {
+void RewriterBase::mergeBlockBefore(Block *source, Operation *op,
+ ValueRange argValues) {
assert(source->hasNoPredecessors() &&
"expected 'source' to have no predecessors");
assert(source->hasNoSuccessors() &&
@@ -268,14 +267,14 @@ void PatternRewriter::mergeBlockBefore(Block *source, Operation *op,
/// Split the operations starting at "before" (inclusive) out of the given
/// block into a new block, and return it.
-Block *PatternRewriter::splitBlock(Block *block, Block::iterator before) {
+Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
return block->splitBlock(before);
}
/// 'op' and 'newOp' are known to have the same number of results, replace the
/// uses of op with uses of newOp
-void PatternRewriter::replaceOpWithResultsOfAnotherOp(Operation *op,
- Operation *newOp) {
+void RewriterBase::replaceOpWithResultsOfAnotherOp(Operation *op,
+ Operation *newOp) {
assert(op->getNumResults() == newOp->getNumResults() &&
"replacement op doesn't match results of original op");
if (op->getNumResults() == 1)
@@ -287,11 +286,11 @@ void PatternRewriter::replaceOpWithResultsOfAnotherOp(Operation *op,
/// another region. The two regions must be different. The caller is in
/// charge to update create the operation transferring the control flow to the
/// region and pass it the correct block arguments.
-void PatternRewriter::inlineRegionBefore(Region &region, Region &parent,
- Region::iterator before) {
+void RewriterBase::inlineRegionBefore(Region &region, Region &parent,
+ Region::iterator before) {
parent.getBlocks().splice(before, region.getBlocks());
}
-void PatternRewriter::inlineRegionBefore(Region &region, Block *before) {
+void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
inlineRegionBefore(region, *before->getParent(), before->getIterator());
}
@@ -299,17 +298,16 @@ void PatternRewriter::inlineRegionBefore(Region &region, Block *before) {
/// another region "parent". The two regions must be different. The caller is
/// responsible for creating or updating the operation transferring flow of
/// control to the region and passing it the correct block arguments.
-void PatternRewriter::cloneRegionBefore(Region &region, Region &parent,
- Region::iterator before,
- BlockAndValueMapping &mapping) {
+void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
+ Region::iterator before,
+ BlockAndValueMapping &mapping) {
region.cloneInto(&parent, before, mapping);
}
-void PatternRewriter::cloneRegionBefore(Region &region, Region &parent,
- Region::iterator before) {
+void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
+ Region::iterator before) {
BlockAndValueMapping mapping;
cloneRegionBefore(region, parent, before, mapping);
}
-void PatternRewriter::cloneRegionBefore(Region &region, Block *before) {
+void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
cloneRegionBefore(region, *before->getParent(), before->getIterator());
}
-
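`replaceOpWithinBlock` is implemented on top of `replaceOpWithIf` by passing a lambda that captures `block` by value. This matters because, as the header notes, some rewriters may store the functor beyond the call. The sketch below models that with a hypothetical `DeferredRewriter` that queues predicates and applies them in a later `commit`. `Use`, integer block ids, and `std::function` (in place of `llvm::unique_function`) are simplifying assumptions, and block-id equality stands in for the real `isProperAncestor` check.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy use: the value it reads and the block its owner lives in.
struct Use {
  std::string value;
  int ownerBlock;
};

// Hypothetical rewriter that stores the predicate and applies it later,
// like the "some rewriters" mentioned in the replaceOpWithIf documentation.
class DeferredRewriter {
public:
  void replaceOpWithIf(std::string oldValue, std::string newValue,
                       std::function<bool(const Use &)> functor) {
    pending.push_back(
        {std::move(oldValue), std::move(newValue), std::move(functor)});
  }

  // Like replaceOpWithinBlock: `block` is captured by value, so the stored
  // lambda stays valid after this function returns.
  void replaceOpWithinBlock(std::string oldValue, std::string newValue,
                            int block) {
    replaceOpWithIf(std::move(oldValue), std::move(newValue),
                    [block](const Use &use) { return use.ownerBlock == block; });
  }

  // Apply every queued replacement to the uses it selects, then clear the queue.
  void commit(std::vector<Use> &uses) {
    for (const Pending &p : pending)
      for (Use &use : uses)
        if (use.value == p.oldValue && p.functor(use))
          use.value = p.newValue;
    pending.clear();
  }

private:
  struct Pending {
    std::string oldValue, newValue;
    std::function<bool(const Use &)> functor;
  };
  std::vector<Pending> pending;
};
```

Had the lambda captured `block` by reference, it would refer to a parameter that no longer exists by the time `commit` runs.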