[Mlir-commits] [mlir] 782c534 - [ODS] Implement a new 'hasCanonicalizeMethod' bit for cann patterns.
Chris Lattner
llvmlistbot at llvm.org
Tue Mar 23 13:45:58 PDT 2021
Author: Chris Lattner
Date: 2021-03-23T13:45:45-07:00
New Revision: 782c534117d1a600b054475c804ba2766e6e154c
URL: https://github.com/llvm/llvm-project/commit/782c534117d1a600b054475c804ba2766e6e154c
DIFF: https://github.com/llvm/llvm-project/commit/782c534117d1a600b054475c804ba2766e6e154c.diff
LOG: [ODS] Implement a new 'hasCanonicalizeMethod' bit for canonicalization patterns.
This provides a simplified way to implement 'matchAndRewrite' style
canonicalization patterns for ops that don't need the full power of
RewritePatterns. Using this style, you can implement a static method
with a signature like:
```
LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
  // Match `op` and rewrite it with `rewriter`; return failure() if nothing changed.
  return failure();
}
```
instead of dealing with defining RewritePattern subclasses. This also
adopts this for a few canonicalization patterns in the std dialect to
show how it works.
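To make the difference concrete, here is a toy C++ model (not MLIR code) of what this enables: a plain function is wrapped by an adapter so it can sit in a pattern list next to class-based patterns, which is the role `RewritePatternSet::add` plays when it is handed a function pointer such as `canonicalize`. `Pattern`, `FunctionPattern`, `PatternSet`, `canonicalizeAssert`, and the stand-in `LogicalResult`, `AssertOp`, and `PatternRewriter` types are hypothetical simplifications.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical stand-ins for mlir::LogicalResult and its helpers.
struct LogicalResult { bool ok; };
inline LogicalResult success() { return {true}; }
inline LogicalResult failure() { return {false}; }

// Toy op: real MLIR ops are cheap handles passed by value; this one holds
// state, so it is passed by reference instead.
struct AssertOp {
  bool argIsConstantTrue;
  bool erased = false;
};

struct PatternRewriter {
  void eraseOp(AssertOp &op) { op.erased = true; }
};

// Class-style pattern: the subclassing boilerplate simple ops can now skip.
struct Pattern {
  virtual ~Pattern() = default;
  virtual LogicalResult matchAndRewrite(AssertOp &op,
                                        PatternRewriter &rewriter) const = 0;
};

// Adapter that lets a plain function live in the same pattern list.
struct FunctionPattern : Pattern {
  using Fn = LogicalResult (*)(AssertOp &, PatternRewriter &);
  explicit FunctionPattern(Fn f) : fn(f) {}
  LogicalResult matchAndRewrite(AssertOp &op,
                                PatternRewriter &rewriter) const override {
    return fn(op, rewriter);
  }
  Fn fn;
};

struct PatternSet {
  std::vector<std::unique_ptr<Pattern>> patterns;
  // Analogous to `results.add(canonicalize)` in the generated code.
  void add(FunctionPattern::Fn f) {
    patterns.push_back(std::make_unique<FunctionPattern>(f));
  }
};

// Function-style pattern with the same logic as AssertOp::canonicalize.
LogicalResult canonicalizeAssert(AssertOp &op, PatternRewriter &rewriter) {
  if (!op.argIsConstantTrue)
    return failure(); // Nothing to do: report that the IR is unchanged.
  rewriter.eraseOp(op);
  return success();
}
```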
Differential Revision: https://reviews.llvm.org/D99143
Added:
Modified:
mlir/docs/Canonicalization.md
mlir/docs/OpDefinitions.md
mlir/docs/Tutorials/QuickstartRewrites.md
mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/IR/OpBase.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Canonicalization.md b/mlir/docs/Canonicalization.md
index 4549369a4ccb..143c9edfc976 100644
--- a/mlir/docs/Canonicalization.md
+++ b/mlir/docs/Canonicalization.md
@@ -56,9 +56,9 @@ These transformations are applied to all levels of IR:
## Defining Canonicalizations
Two mechanisms are available with which to define canonicalizations:
-`getCanonicalizationPatterns` and `fold`.
+general `RewritePattern`s and the `fold` method.
-### Canonicalizing with `getCanonicalizationPatterns`
+### Canonicalizing with `RewritePattern`s
This mechanism allows for providing canonicalizations as a set of
`RewritePattern`s, either imperatively defined in C++ or declaratively as
@@ -67,13 +67,21 @@ infrastructure allows for expressing many
different types of canonicalizations.
These transformations may be as simple as replacing a multiplication with a
shift, or even replacing a conditional branch with an unconditional one.
-In [ODS](OpDefinitions.md), an operation can set the `hasCanonicalizer` bit to
-generate a declaration for the `getCanonicalizationPatterns` method.
+In [ODS](OpDefinitions.md), an operation can set the `hasCanonicalizer` bit or
+the `hasCanonicalizeMethod` bit to generate a declaration for the
+`getCanonicalizationPatterns` method:
```tablegen
def MyOp : ... {
+ // I want to define a fully general set of patterns for this op.
let hasCanonicalizer = 1;
}
+
+def OtherOp : ... {
+ // A single "matchAndRewrite" style RewritePattern implemented as a method
+ // is good enough for me.
+ let hasCanonicalizeMethod = 1;
+}
```
Canonicalization patterns can then be provided in the source file:
@@ -83,12 +91,17 @@ void MyOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<...>(...);
}
+
+LogicalResult OtherOp::canonicalize(OtherOp op, PatternRewriter &rewriter) {
+ // patterns and rewrites go here.
+ return failure();
+}
```
See the [quickstart guide](Tutorials/QuickstartRewrites.md) for information on
defining operation rewrites.
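Why does the `OtherOp::canonicalize` stub return `failure()`? The canonicalizer applies patterns greedily and keeps reapplying them until none reports a change, so a method that returns `success()` without modifying the IR can keep the driver from converging. The sketch below is a hypothetical, heavily simplified model of such a driver; `applyGreedily`, `eraseTrueAssert`, `alwaysSucceeds`, and the `Block` of booleans are illustrative names, not MLIR APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

struct LogicalResult { bool ok; };
inline LogicalResult success() { return {true}; }
inline LogicalResult failure() { return {false}; }

// Toy IR: a block of assert ops; `true` means the argument is constant true.
using Block = std::vector<bool>;
using Pattern = std::function<LogicalResult(Block &, std::size_t)>;

// Hypothetical greedy driver: rescan after every successful rewrite and stop
// once a full pass makes no change. Returns false if no fixpoint is reached
// within maxIterations.
bool applyGreedily(Block &block, const std::vector<Pattern> &patterns,
                   int maxIterations = 10) {
  for (int iter = 0; iter < maxIterations; ++iter) {
    bool changed = false;
    for (std::size_t i = 0; i < block.size() && !changed; ++i)
      for (const Pattern &pattern : patterns)
        if (pattern(block, i).ok) {
          changed = true;
          break;
        }
    if (!changed)
      return true; // Fixpoint: every pattern reported failure().
  }
  return false;
}

// Well-behaved pattern: erase a constant-true assert, otherwise report failure.
LogicalResult eraseTrueAssert(Block &block, std::size_t i) {
  if (!block[i])
    return failure();
  block.erase(block.begin() + static_cast<std::ptrdiff_t>(i));
  return success();
}

// Buggy pattern: claims success without touching the IR.
LogicalResult alwaysSucceeds(Block &, std::size_t) { return success(); }
```

With `eraseTrueAssert`, the block `{true, false, true}` converges to `{false}`; with `alwaysSucceeds`, the driver never sees a quiet pass and gives up after `maxIterations`.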
-### Canonicalizing with `fold`
+### Canonicalizing with the `fold` method
The `fold` mechanism is an intentionally limited, but powerful mechanism that
allows for applying canonicalizations in many places throughout the compiler.
diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 5f413582c698..e2203ff467fd 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -919,6 +919,13 @@ This boolean field indicates whether canonicalization patterns have been defined
for this operation. If it is `1`, then `::getCanonicalizationPatterns()` should
be defined.
+### `hasCanonicalizeMethod`
+
+When this boolean field is set to `true`, it indicates that the op implements a
+`canonicalize` method for simple "matchAndRewrite" style canonicalization
+patterns. If `hasCanonicalizer` is 0, then an implementation of
+`::getCanonicalizationPatterns()` is generated that calls this method.
+
### `hasFolder`
This boolean field indicates whether general folding rules have been defined for
diff --git a/mlir/docs/Tutorials/QuickstartRewrites.md b/mlir/docs/Tutorials/QuickstartRewrites.md
index d537050f1a32..54e67214a473 100644
--- a/mlir/docs/Tutorials/QuickstartRewrites.md
+++ b/mlir/docs/Tutorials/QuickstartRewrites.md
@@ -159,10 +159,61 @@ RewritePatternSet &patterns)` function that you can
use to collect all the generated patterns inside `patterns` and then use
`patterns` in any pass you would like.
-### C++ rewrite specification
+### Simple C++ `matchAndRewrite` style specifications
-In case patterns are not sufficient there is also the fully C++ way of
-expressing a rewrite:
+Many simple rewrites can be expressed with a `matchAndRewrite` style of
+pattern, e.g. when converting a multiply by a power of two into a shift. For
+these cases, you can define the pattern as a simple function:
+
+```c++
+static LogicalResult
+convertTFLeakyRelu(TFLeakyReluOp op, PatternRewriter &rewriter) {
+ rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
+ op, op->getResult(0).getType(), op->getOperand(0),
+ /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
+ return success();
+}
+
+void populateRewrites(RewritePatternSet &patternSet) {
+ // Add it to a pattern set.
+ patternSet.add(convertTFLeakyRelu);
+}
+```
+
+ODS provides a simple way to define a function-style canonicalization for your
+operation. In the TableGen definition of the op, specify
+`let hasCanonicalizeMethod = 1;` and then implement the `canonicalize` method in
+your .cpp file:
+
+```c++
+// Example from the CIRCT project which has a variadic integer multiply.
+LogicalResult circt::MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
+ auto inputs = op.inputs();
+ APInt value;
+
+ // mul(x, c) -> shl(x, log2(c)), where c is a power of two.
+ if (inputs.size() == 2 && matchPattern(inputs.back(), m_RConstant(value)) &&
+ value.isPowerOf2()) {
+ auto shift = rewriter.create<rtl::ConstantOp>(op.getLoc(), op.getType(),
+ value.exactLogBase2());
+ auto shlOp =
+ rewriter.create<comb::ShlOp>(op.getLoc(), inputs[0], shift);
+ rewriter.replaceOpWithNewOp<MulOp>(op, op.getType(),
+ ArrayRef<Value>(shlOp));
+ return success();
+ }
+
+ return failure();
+}
+```
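The rewrite above is valid because, for fixed-width integers with wraparound, multiplying by 2^k produces the same bits as shifting left by k. A minimal plain-C++ sketch of that check, using `uint64_t` in place of `APInt`; `exactLog2` and `mulOrShift` are hypothetical helpers, not CIRCT or LLVM functions:

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Returns log2(c) when c is a nonzero power of two, analogous to the
// isPowerOf2()/exactLogBase2() pair used in the example above.
std::optional<unsigned> exactLog2(std::uint64_t c) {
  if (c == 0 || (c & (c - 1)) != 0)
    return std::nullopt; // Zero, or more than one bit set.
  unsigned shift = 0;
  while (c >>= 1)
    ++shift;
  return shift;
}

// mul(x, c) -> shl(x, log2(c)) when c is a power of two; otherwise multiply.
std::uint64_t mulOrShift(std::uint64_t x, std::uint64_t c) {
  if (std::optional<unsigned> shift = exactLog2(c))
    return x << *shift;
  return x * c;
}
```

For example, `mulOrShift(5, 8)` takes the shift path (`5 << 3`), while `mulOrShift(5, 6)` falls back to a plain multiply because 6 is not a power of two.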
+
+However, if you need the full generality of canonicalization patterns, you
+can specify an arbitrary list of `RewritePattern`s instead.
+
+### Fully general C++ `RewritePattern` specifications
+
+In case ODS patterns and `matchAndRewrite`-style functions are not sufficient,
+you can also specify rewrites as a general set of `RewritePattern`s:
```c++
/// Multi-step rewrite using "match" and "rewrite". This allows for separating
@@ -202,19 +253,6 @@ In the C++ rewrite the static benefit of the rewrite pattern is specified at
construction, whereas the pattern generator currently employs a simple heuristic
based on the number of ops matched and replaced.
-In the case where you have a registered op and want to use a benefit of 1, you
-can even define the pattern as a C function:
-
-```c++
-static LogicalResult
-convertTFLeakyRelu(TFLeakyReluOp op, PatternRewriter &rewriter) {
- rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
- op, op->getResult(0).getType(), op->getOperand(0),
- /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
- return success();
-}
-```
-
The above rule did not capture the matching operands/attributes, but in general
the `match` function in a multi-step rewrite may populate and return a
`PatternState` (or class derived from one) to pass information extracted during
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index 48d1834f899c..ce3087a176b3 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -33,6 +33,7 @@ class AffineMap;
class Builder;
class FuncOp;
class OpBuilder;
+class PatternRewriter;
/// Return the list of Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index d551c74da8f9..84c152b351a0 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -339,7 +339,7 @@ def AssertOp : Std_Op<"assert"> {
// AssertOp is fully verified by its traits.
let verifier = ?;
- let hasCanonicalizer = 1;
+ let hasCanonicalizeMethod = 1;
}
//===----------------------------------------------------------------------===//
@@ -500,7 +500,7 @@ def BranchOp : Std_Op<"br",
void eraseOperand(unsigned index);
}];
- let hasCanonicalizer = 1;
+ let hasCanonicalizeMethod = 1;
let assemblyFormat = [{
$dest (`(` $destOperands^ `:` type($destOperands) `)`)? attr-dict
}];
@@ -629,7 +629,7 @@ def CallIndirectOp : Std_Op<"call_indirect", [
}];
let verifier = ?;
- let hasCanonicalizer = 1;
+ let hasCanonicalizeMethod = 1;
let assemblyFormat = "$callee `(` $operands `)` attr-dict `:` type($callee)";
}
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index bdae05f7eea8..88f1427a1922 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2141,11 +2141,12 @@ class Op<Dialect dialect, string mnemonic, list<OpTrait> props = []> {
code verifier = ?;
// Whether this op has associated canonicalization patterns.
- // TODO: figure out a better way to write canonicalization patterns in
- // TableGen rules directly instead of using this marker and C++
- // implementations.
bit hasCanonicalizer = 0;
+ // Whether this op has a static "canonicalize" method to perform "match and
+ // rewrite" patterns.
+ bit hasCanonicalizeMethod = 0;
+
// Whether this op has a folder.
bit hasFolder = 0;
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 5f331eef241c..2f2f36e502d6 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -308,25 +308,13 @@ OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
// AssertOp
//===----------------------------------------------------------------------===//
-namespace {
-struct EraseRedundantAssertions : public OpRewritePattern<AssertOp> {
- using OpRewritePattern<AssertOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(AssertOp op,
- PatternRewriter &rewriter) const override {
- // Erase assertion if argument is constant true.
- if (matchPattern(op.arg(), m_One())) {
- rewriter.eraseOp(op);
- return success();
- }
- return failure();
+LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
+ // Erase assertion if argument is constant true.
+ if (matchPattern(op.arg(), m_One())) {
+ rewriter.eraseOp(op);
+ return success();
}
-};
-} // namespace
-
-void AssertOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
- MLIRContext *context) {
- patterns.add<EraseRedundantAssertions>(context);
+ return failure();
}
//===----------------------------------------------------------------------===//
@@ -498,26 +486,21 @@ static LogicalResult collapseBranch(Block *&successor,
return success();
}
-namespace {
/// Simplify a branch to a block that has a single predecessor. This effectively
/// merges the two blocks.
-struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern<BranchOp> {
- using OpRewritePattern<BranchOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(BranchOp op,
- PatternRewriter &rewriter) const override {
- // Check that the successor block has a single predecessor.
- Block *succ = op.getDest();
- Block *opParent = op->getBlock();
- if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
- return failure();
+static LogicalResult
+simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
+ // Check that the successor block has a single predecessor.
+ Block *succ = op.getDest();
+ Block *opParent = op->getBlock();
+ if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
+ return failure();
- // Merge the successor into the current block and erase the branch.
- rewriter.mergeBlocks(succ, opParent, op.getOperands());
- rewriter.eraseOp(op);
- return success();
- }
-};
+ // Merge the successor into the current block and erase the branch.
+ rewriter.mergeBlocks(succ, opParent, op.getOperands());
+ rewriter.eraseOp(op);
+ return success();
+}
/// br ^bb1
/// ^bb1
@@ -525,27 +508,27 @@ struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern<BranchOp> {
///
/// -> br ^bbN(...)
///
-struct SimplifyPassThroughBr : public OpRewritePattern<BranchOp> {
- using OpRewritePattern<BranchOp>::OpRewritePattern;
+static LogicalResult simplifyPassThroughBr(BranchOp op,
+ PatternRewriter &rewriter) {
+ Block *dest = op.getDest();
+ ValueRange destOperands = op.getOperands();
+ SmallVector<Value, 4> destOperandStorage;
+
+ // Try to collapse the successor if it points somewhere other than this
+ // block.
+ if (dest == op->getBlock() ||
+ failed(collapseBranch(dest, destOperands, destOperandStorage)))
+ return failure();
- LogicalResult matchAndRewrite(BranchOp op,
- PatternRewriter &rewriter) const override {
- Block *dest = op.getDest();
- ValueRange destOperands = op.getOperands();
- SmallVector<Value, 4> destOperandStorage;
-
- // Try to collapse the successor if it points somewhere other than this
- // block.
- if (dest == op->getBlock() ||
- failed(collapseBranch(dest, destOperands, destOperandStorage)))
- return failure();
+ // Create a new branch with the collapsed successor.
+ rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
+ return success();
+}
- // Create a new branch with the collapsed successor.
- rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
- return success();
- }
-};
-} // end anonymous namespace.
+LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
+ return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
+ succeeded(simplifyPassThroughBr(op, rewriter)));
+}
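`BranchOp::canonicalize` merges two former pattern classes into one entry point with `||`. Because `||` short-circuits, the pass-through simplification is only attempted when the single-predecessor merge does not apply in that call; the driver may revisit the op later. A small self-contained model of that control flow, with hypothetical `firstPattern`/`secondPattern` stand-ins and a local `success(bool)` that mirrors the MLIR helper of the same name:

```cpp
#include <cassert>

struct LogicalResult { bool ok; };
// Mirrors mlir::success(bool), which maps true/false to success/failure.
inline LogicalResult success(bool isSuccess = true) { return {isSuccess}; }
inline bool succeeded(LogicalResult result) { return result.ok; }

// Records how many times each hypothetical sub-pattern was invoked.
struct Trace {
  int first = 0;
  int second = 0;
};

LogicalResult firstPattern(bool matches, Trace &trace) {
  ++trace.first;
  return success(matches);
}

LogicalResult secondPattern(bool matches, Trace &trace) {
  ++trace.second;
  return success(matches);
}

// Same shape as BranchOp::canonicalize: `||` short-circuits, so the second
// pattern only runs when the first one fails.
LogicalResult combined(bool firstMatches, bool secondMatches, Trace &trace) {
  return success(succeeded(firstPattern(firstMatches, trace)) ||
                 succeeded(secondPattern(secondMatches, trace)));
}
```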
Block *BranchOp::getDest() { return getSuccessor(); }
@@ -553,11 +536,6 @@ void BranchOp::setDest(Block *block) { return setSuccessor(block); }
void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
-void BranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<SimplifyBrToBlockWithSinglePred, SimplifyPassThroughBr>(context);
-}
-
Optional<MutableOperandRange>
BranchOp::getMutableSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
@@ -608,31 +586,20 @@ FunctionType CallOp::getCalleeType() {
//===----------------------------------------------------------------------===//
// CallIndirectOp
//===----------------------------------------------------------------------===//
-namespace {
-/// Fold indirect calls that have a constant function as the callee operand.
-struct SimplifyIndirectCallWithKnownCallee
- : public OpRewritePattern<CallIndirectOp> {
- using OpRewritePattern<CallIndirectOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(CallIndirectOp indirectCall,
- PatternRewriter &rewriter) const override {
- // Check that the callee is a constant callee.
- SymbolRefAttr calledFn;
- if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
- return failure();
- // Replace with a direct call.
- rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
- indirectCall.getResultTypes(),
- indirectCall.getArgOperands());
- return success();
- }
-};
-} // end anonymous namespace.
+/// Fold indirect calls that have a constant function as the callee operand.
+LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
+ PatternRewriter &rewriter) {
+ // Check that the callee is a constant callee.
+ SymbolRefAttr calledFn;
+ if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
+ return failure();
-void CallIndirectOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<SimplifyIndirectCallWithKnownCallee>(context);
+ // Replace with a direct call.
+ rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
+ indirectCall.getResultTypes(),
+ indirectCall.getArgOperands());
+ return success();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index a1853362dce2..ac3bd168f9ca 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1674,15 +1674,40 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
}
void OpEmitter::genCanonicalizerDecls() {
- if (!def.getValueAsBit("hasCanonicalizer"))
+ bool hasCanonicalizeMethod = def.getValueAsBit("hasCanonicalizeMethod");
+ if (hasCanonicalizeMethod) {
+ // static LogicalResult FooOp::
+ // canonicalize(FooOp op, PatternRewriter &rewriter);
+ SmallVector<OpMethodParameter, 2> paramList;
+ paramList.emplace_back(op.getCppClassName(), "op");
+ paramList.emplace_back("::mlir::PatternRewriter &", "rewriter");
+ opClass.addMethodAndPrune("::mlir::LogicalResult", "canonicalize",
+ OpMethod::MP_StaticDeclaration,
+ std::move(paramList));
+ }
+
+ // We get a prototype for 'getCanonicalizationPatterns' if requested directly
+ // or if using a 'canonicalize' method.
+ bool hasCanonicalizer = def.getValueAsBit("hasCanonicalizer");
+ if (!hasCanonicalizeMethod && !hasCanonicalizer)
return;
+ // We get a body for 'getCanonicalizationPatterns' when using a 'canonicalize'
+ // method, but not implementing 'getCanonicalizationPatterns' manually.
+ bool hasBody = hasCanonicalizeMethod && !hasCanonicalizer;
+
+ // Add a signature for getCanonicalizationPatterns if implemented by the
+ // dialect or if synthesized to call 'canonicalize'.
SmallVector<OpMethodParameter, 2> paramList;
paramList.emplace_back("::mlir::RewritePatternSet &", "results");
paramList.emplace_back("::mlir::MLIRContext *", "context");
- opClass.addMethodAndPrune("void", "getCanonicalizationPatterns",
- OpMethod::MP_StaticDeclaration,
- std::move(paramList));
+ auto kind = hasBody ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
+ auto *method = opClass.addMethodAndPrune(
+ "void", "getCanonicalizationPatterns", kind, std::move(paramList));
+
+ // If synthesizing the method, fill it in.
+ if (hasBody)
+ method->body() << " results.add(canonicalize);\n";
}
void OpEmitter::genFolderDecls() {
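The `genCanonicalizerDecls` changes boil down to a small truth table over the two ODS bits. The following sketch restates that boolean logic as plain C++ so each combination can be checked; `CanonicalizerDecls` and `planCanonicalizerDecls` are hypothetical names, not part of mlir-tblgen.

```cpp
#include <cassert>

// What genCanonicalizerDecls emits for one op, modeled as plain flags.
struct CanonicalizerDecls {
  bool declaresCanonicalize; // static LogicalResult canonicalize(Op, PatternRewriter &)
  bool declaresGetPatterns;  // static void getCanonicalizationPatterns(...)
  bool synthesizesBody;      // body is `results.add(canonicalize);`
};

// Same boolean logic as the emitter, without the tblgen plumbing.
CanonicalizerDecls planCanonicalizerDecls(bool hasCanonicalizer,
                                          bool hasCanonicalizeMethod) {
  CanonicalizerDecls decls{};
  decls.declaresCanonicalize = hasCanonicalizeMethod;
  decls.declaresGetPatterns = hasCanonicalizer || hasCanonicalizeMethod;
  decls.synthesizesBody = hasCanonicalizeMethod && !hasCanonicalizer;
  return decls;
}
```

Reading off the last case: when both bits are set, a `canonicalize` declaration is emitted but no `getCanonicalizationPatterns` body is synthesized, so the hand-written `getCanonicalizationPatterns` is presumably responsible for registering `canonicalize` if it should run.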