[Mlir-commits] [mlir] 782c534 - [ODS] Implement a new 'hasCanonicalizeMethod' bit for cann patterns.

Chris Lattner llvmlistbot at llvm.org
Tue Mar 23 13:45:58 PDT 2021


Author: Chris Lattner
Date: 2021-03-23T13:45:45-07:00
New Revision: 782c534117d1a600b054475c804ba2766e6e154c

URL: https://github.com/llvm/llvm-project/commit/782c534117d1a600b054475c804ba2766e6e154c
DIFF: https://github.com/llvm/llvm-project/commit/782c534117d1a600b054475c804ba2766e6e154c.diff

LOG: [ODS] Implement a new 'hasCanonicalizeMethod' bit for cann patterns.

This provides a simplified way to implement 'matchAndRewrite' style
canonicalization patterns for ops that don't need the full power of
RewritePatterns.  Using this style, you can implement a static method
with a signature like:

```
LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
  return success();
}
```

instead of dealing with defining RewritePattern subclasses.  This also
adopts this for a few canonicalization patterns in the std dialect to
show how it works.
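
As a rough illustration of the idea, here is a self-contained sketch of a
function-style pattern being registered and applied. Everything in it
(`ToyAssertOp`, `ToyRewriter`, `ToyPatternSet`, `canonicalizeAssert`) is a
hypothetical stand-in written for this example, not MLIR's real API; in
particular, real ops are passed by value as lightweight handles, whereas the
toy op is passed by reference.

```cpp
#include <cassert>
#include <functional>
#include <utility>
#include <vector>

// Toy stand-ins for illustration only; these are NOT MLIR's classes.
struct LogicalResult {
  bool ok;
};
inline LogicalResult success() { return {true}; }
inline LogicalResult failure() { return {false}; }
inline bool succeeded(LogicalResult r) { return r.ok; }

// Hypothetical op: an assertion whose argument may be a known-true constant.
struct ToyAssertOp {
  explicit ToyAssertOp(bool constantTrue) : argIsConstantTrue(constantTrue) {}
  bool argIsConstantTrue;
  bool erased = false;
};

// Hypothetical rewriter that only knows how to erase an op.
struct ToyRewriter {
  int numErased = 0;
  void eraseOp(ToyAssertOp &op) {
    op.erased = true;
    ++numErased;
  }
};

// Function-style pattern, shaped like the 'canonicalize' hook described above:
// return success() only if the IR was changed.
LogicalResult canonicalizeAssert(ToyAssertOp &op, ToyRewriter &rewriter) {
  if (!op.argIsConstantTrue)
    return failure();
  rewriter.eraseOp(op);
  return success();
}

// Toy pattern set that accepts plain functions instead of pattern classes.
struct ToyPatternSet {
  using Fn = std::function<LogicalResult(ToyAssertOp &, ToyRewriter &)>;
  std::vector<Fn> patterns;

  void add(Fn fn) { patterns.push_back(std::move(fn)); }

  // Try each pattern in order; report whether any of them fired.
  bool applyOnce(ToyAssertOp &op, ToyRewriter &rewriter) const {
    for (const Fn &fn : patterns)
      if (succeeded(fn(op, rewriter)))
        return true;
    return false;
  }
};
```

The point of the sketch is only that a plain function with the right signature
can stand in for a pattern class; the real registration and driver logic in
MLIR is considerably more involved.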

Differential Revision: https://reviews.llvm.org/D99143

Added: 
    

Modified: 
    mlir/docs/Canonicalization.md
    mlir/docs/OpDefinitions.md
    mlir/docs/Tutorials/QuickstartRewrites.md
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/IR/OpBase.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Canonicalization.md b/mlir/docs/Canonicalization.md
index 4549369a4ccb..143c9edfc976 100644
--- a/mlir/docs/Canonicalization.md
+++ b/mlir/docs/Canonicalization.md
@@ -56,9 +56,9 @@ These transformations are applied to all levels of IR:
 ## Defining Canonicalizations
 
 Two mechanisms are available with which to define canonicalizations;
-`getCanonicalizationPatterns` and `fold`.
+general `RewritePattern`s and the `fold` method.
 
-### Canonicalizing with `getCanonicalizationPatterns`
+### Canonicalizing with `RewritePattern`s
 
 This mechanism allows for providing canonicalizations as a set of
 `RewritePattern`s, either imperatively defined in C++ or declaratively as
@@ -67,13 +67,21 @@ infrastructure allows for expressing many different types of canonicalizations.
 These transformations may be as simple as replacing a multiplication with a
 shift, or even replacing a conditional branch with an unconditional one.
 
-In [ODS](OpDefinitions.md), an operation can set the `hasCanonicalizer` bit to
-generate a declaration for the `getCanonicalizationPatterns` method.
+In [ODS](OpDefinitions.md), an operation can set the `hasCanonicalizer` bit or
+the `hasCanonicalizeMethod` bit to generate a declaration for the
+`getCanonicalizationPatterns` method:
 
 ```tablegen
 def MyOp : ... {
+  // I want to define a fully general set of patterns for this op.
   let hasCanonicalizer = 1;
 }
+
+def OtherOp : ... {
+  // A single "matchAndRewrite" style RewritePattern implemented as a method
+  // is good enough for me.
+  let hasCanonicalizeMethod = 1;
+}
 ```
 
 Canonicalization patterns can then be provided in the source file:
@@ -83,12 +91,17 @@ void MyOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                        MLIRContext *context) {
   patterns.add<...>(...);
 }
+
+LogicalResult OtherOp::canonicalize(OtherOp op, PatternRewriter &rewriter) {
+  // patterns and rewrites go here.
+  return failure();
+}
 ```
 
 See the [quickstart guide](Tutorials/QuickstartRewrites.md) for information on
 defining operation rewrites.
 
-### Canonicalizing with `fold`
+### Canonicalizing with the `fold` method
 
 The `fold` mechanism is an intentionally limited, but powerful mechanism that
 allows for applying canonicalizations in many places throughout the compiler.

diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 5f413582c698..e2203ff467fd 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -919,6 +919,13 @@ This boolean field indicate whether canonicalization patterns have been defined
 for this operation. If it is `1`, then `::getCanonicalizationPatterns()` should
 be defined.
 
+### `hasCanonicalizeMethod`
+
+When this boolean field is set to `1`, it indicates that the op implements a
+`canonicalize` method for simple "matchAndRewrite" style canonicalization
+patterns.  If `hasCanonicalizer` is 0, then an implementation of
+`::getCanonicalizationPatterns()` is generated to call this function.
+
 ### `hasFolder`
 
 This boolean field indicate whether general folding rules have been defined for

diff  --git a/mlir/docs/Tutorials/QuickstartRewrites.md b/mlir/docs/Tutorials/QuickstartRewrites.md
index d537050f1a32..54e67214a473 100644
--- a/mlir/docs/Tutorials/QuickstartRewrites.md
+++ b/mlir/docs/Tutorials/QuickstartRewrites.md
@@ -159,10 +159,61 @@ RewritePatternSet &patterns)` function that you can
 use to collect all the generated patterns inside `patterns` and then use
 `patterns` in any pass you would like.
 
-### C++ rewrite specification
+### Simple C++ `matchAndRewrite` style specifications
 
-In case patterns are not sufficient there is also the fully C++ way of
-expressing a rewrite:
+Many simple rewrites can be expressed with a `matchAndRewrite` style of
+pattern, e.g. when converting a multiply by a power of two into a shift.  For
+these cases, you can define the pattern as a simple function:
+
+```c++
+static LogicalResult
+convertTFLeakyRelu(TFLeakyReluOp op, PatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
+      op, op->getResult(0).getType(), op->getOperand(0),
+      /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
+  return success();
+}
+
+void populateRewrites(RewritePatternSet &patternSet) {
+  // Add it to a pattern set.
+  patternSet.add(convertTFLeakyRelu);
+}
+```
+
+ODS provides a simple way to define a function-style canonicalization for your
+operation.  In the TableGen definition of the op, specify
+`let hasCanonicalizeMethod = 1;` and then implement the `canonicalize` method in
+your .cpp file:
+
+```c++
+// Example from the CIRCT project which has a variadic integer multiply.
+LogicalResult circt::MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
+  auto inputs = op.inputs();
+  APInt value;
+
+  // mul(x, c) -> shl(x, log2(c)), where c is a power of two.
+  if (inputs.size() == 2 && matchPattern(inputs.back(), m_RConstant(value)) &&
+      value.isPowerOf2()) {
+    auto shift = rewriter.create<rtl::ConstantOp>(op.getLoc(), op.getType(),
+                                                  value.exactLogBase2());
+    auto shlOp =
+        rewriter.create<comb::ShlOp>(op.getLoc(), inputs[0], shift);
+    rewriter.replaceOpWithNewOp<MulOp>(op, op.getType(),
+                                       ArrayRef<Value>(shlOp));
+    return success();
+  }
+
+  return failure();
+}
+```
+
+However, if you want the full generality of canonicalization patterns, you can
+specify an arbitrary list of `RewritePattern`s.
+
+### Fully general C++ `RewritePattern` specifications
+
+In case ODS patterns and `matchAndRewrite`-style functions are not sufficient
+you can also specify rewrites as a general set of `RewritePattern`s:
 
 ```c++
 /// Multi-step rewrite using "match" and "rewrite". This allows for separating
@@ -202,19 +253,6 @@ In the C++ rewrite the static benefit of the rewrite pattern is specified at
 construction. While in the pattern generator a simple heuristic is currently
 employed based around the number of ops matched and replaced.
 
-In the case where you have a registered op and want to use a benefit of 1, you
-can even define the pattern as a C function:
-
-```c++
-static LogicalResult
-convertTFLeakyRelu(TFLeakyReluOp op, PatternRewriter &rewriter) {
-  rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
-      op, op->getResult(0).getType(), op->getOperand(0),
-      /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
-  return success();
-}
-```
-
 The above rule did not capture the matching operands/attributes, but in general
 the `match` function in a multi-step rewrite may populate and return a
 `PatternState` (or class derived from one) to pass information extracted during

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index 48d1834f899c..ce3087a176b3 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -33,6 +33,7 @@ class AffineMap;
 class Builder;
 class FuncOp;
 class OpBuilder;
+class PatternRewriter;
 
 /// Return the list of Range (i.e. offset, size, stride). Each Range
 /// entry contains either the dynamic value or a ConstantIndexOp constructed

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index d551c74da8f9..84c152b351a0 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -339,7 +339,7 @@ def AssertOp : Std_Op<"assert"> {
   // AssertOp is fully verified by its traits.
   let verifier = ?;
 
-  let hasCanonicalizer = 1;
+  let hasCanonicalizeMethod = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -500,7 +500,7 @@ def BranchOp : Std_Op<"br",
     void eraseOperand(unsigned index);
   }];
 
-  let hasCanonicalizer = 1;
+  let hasCanonicalizeMethod = 1;
   let assemblyFormat = [{
     $dest (`(` $destOperands^ `:` type($destOperands) `)`)? attr-dict
   }];
@@ -629,7 +629,7 @@ def CallIndirectOp : Std_Op<"call_indirect", [
   }];
 
   let verifier = ?;
-  let hasCanonicalizer = 1;
+  let hasCanonicalizeMethod = 1;
 
   let assemblyFormat = "$callee `(` $operands `)` attr-dict `:` type($callee)";
 }

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index bdae05f7eea8..88f1427a1922 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2141,11 +2141,12 @@ class Op<Dialect dialect, string mnemonic, list<OpTrait> props = []> {
   code verifier = ?;
 
   // Whether this op has associated canonicalization patterns.
-  // TODO: figure out a better way to write canonicalization patterns in
-  // TableGen rules directly instead of using this marker and C++
-  // implementations.
   bit hasCanonicalizer = 0;
 
+  // Whether this op has a static "canonicalize" method to perform "match and
+  // rewrite patterns".
+  bit hasCanonicalizeMethod = 0;
+
   // Whether this op has a folder.
   bit hasFolder = 0;
 

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 5f331eef241c..2f2f36e502d6 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -308,25 +308,13 @@ OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
 // AssertOp
 //===----------------------------------------------------------------------===//
 
-namespace {
-struct EraseRedundantAssertions : public OpRewritePattern<AssertOp> {
-  using OpRewritePattern<AssertOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(AssertOp op,
-                                PatternRewriter &rewriter) const override {
-    // Erase assertion if argument is constant true.
-    if (matchPattern(op.arg(), m_One())) {
-      rewriter.eraseOp(op);
-      return success();
-    }
-    return failure();
+LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
+  // Erase assertion if argument is constant true.
+  if (matchPattern(op.arg(), m_One())) {
+    rewriter.eraseOp(op);
+    return success();
   }
-};
-} // namespace
-
-void AssertOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
-                                           MLIRContext *context) {
-  patterns.add<EraseRedundantAssertions>(context);
+  return failure();
 }
 
 //===----------------------------------------------------------------------===//
@@ -498,26 +486,21 @@ static LogicalResult collapseBranch(Block *&successor,
   return success();
 }
 
-namespace {
 /// Simplify a branch to a block that has a single predecessor. This effectively
 /// merges the two blocks.
-struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern<BranchOp> {
-  using OpRewritePattern<BranchOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(BranchOp op,
-                                PatternRewriter &rewriter) const override {
-    // Check that the successor block has a single predecessor.
-    Block *succ = op.getDest();
-    Block *opParent = op->getBlock();
-    if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
-      return failure();
+static LogicalResult
+simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
+  // Check that the successor block has a single predecessor.
+  Block *succ = op.getDest();
+  Block *opParent = op->getBlock();
+  if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
+    return failure();
 
-    // Merge the successor into the current block and erase the branch.
-    rewriter.mergeBlocks(succ, opParent, op.getOperands());
-    rewriter.eraseOp(op);
-    return success();
-  }
-};
+  // Merge the successor into the current block and erase the branch.
+  rewriter.mergeBlocks(succ, opParent, op.getOperands());
+  rewriter.eraseOp(op);
+  return success();
+}
 
 ///   br ^bb1
 /// ^bb1
@@ -525,27 +508,27 @@ struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern<BranchOp> {
 ///
 ///  -> br ^bbN(...)
 ///
-struct SimplifyPassThroughBr : public OpRewritePattern<BranchOp> {
-  using OpRewritePattern<BranchOp>::OpRewritePattern;
+static LogicalResult simplifyPassThroughBr(BranchOp op,
+                                           PatternRewriter &rewriter) {
+  Block *dest = op.getDest();
+  ValueRange destOperands = op.getOperands();
+  SmallVector<Value, 4> destOperandStorage;
+
+  // Try to collapse the successor if it points somewhere other than this
+  // block.
+  if (dest == op->getBlock() ||
+      failed(collapseBranch(dest, destOperands, destOperandStorage)))
+    return failure();
 
-  LogicalResult matchAndRewrite(BranchOp op,
-                                PatternRewriter &rewriter) const override {
-    Block *dest = op.getDest();
-    ValueRange destOperands = op.getOperands();
-    SmallVector<Value, 4> destOperandStorage;
-
-    // Try to collapse the successor if it points somewhere other than this
-    // block.
-    if (dest == op->getBlock() ||
-        failed(collapseBranch(dest, destOperands, destOperandStorage)))
-      return failure();
+  // Create a new branch with the collapsed successor.
+  rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
+  return success();
+}
 
-    // Create a new branch with the collapsed successor.
-    rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
-    return success();
-  }
-};
-} // end anonymous namespace.
+LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
+  return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
+                 succeeded(simplifyPassThroughBr(op, rewriter)));
+}
 
 Block *BranchOp::getDest() { return getSuccessor(); }
 
@@ -553,11 +536,6 @@ void BranchOp::setDest(Block *block) { return setSuccessor(block); }
 
 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
 
-void BranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                           MLIRContext *context) {
-  results.add<SimplifyBrToBlockWithSinglePred, SimplifyPassThroughBr>(context);
-}
-
 Optional<MutableOperandRange>
 BranchOp::getMutableSuccessorOperands(unsigned index) {
   assert(index == 0 && "invalid successor index");
@@ -608,31 +586,20 @@ FunctionType CallOp::getCalleeType() {
 //===----------------------------------------------------------------------===//
 // CallIndirectOp
 //===----------------------------------------------------------------------===//
-namespace {
-/// Fold indirect calls that have a constant function as the callee operand.
-struct SimplifyIndirectCallWithKnownCallee
-    : public OpRewritePattern<CallIndirectOp> {
-  using OpRewritePattern<CallIndirectOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(CallIndirectOp indirectCall,
-                                PatternRewriter &rewriter) const override {
-    // Check that the callee is a constant callee.
-    SymbolRefAttr calledFn;
-    if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
-      return failure();
 
-    // Replace with a direct call.
-    rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
-                                        indirectCall.getResultTypes(),
-                                        indirectCall.getArgOperands());
-    return success();
-  }
-};
-} // end anonymous namespace.
+/// Fold indirect calls that have a constant function as the callee operand.
+LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
+                                           PatternRewriter &rewriter) {
+  // Check that the callee is a constant callee.
+  SymbolRefAttr calledFn;
+  if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
+    return failure();
 
-void CallIndirectOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                                 MLIRContext *context) {
-  results.add<SimplifyIndirectCallWithKnownCallee>(context);
+  // Replace with a direct call.
+  rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
+                                      indirectCall.getResultTypes(),
+                                      indirectCall.getArgOperands());
+  return success();
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index a1853362dce2..ac3bd168f9ca 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1674,15 +1674,40 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
 }
 
 void OpEmitter::genCanonicalizerDecls() {
-  if (!def.getValueAsBit("hasCanonicalizer"))
+  bool hasCanonicalizeMethod = def.getValueAsBit("hasCanonicalizeMethod");
+  if (hasCanonicalizeMethod) {
+    // static LogicalResult FooOp::
+    // canonicalize(FooOp op, PatternRewriter &rewriter);
+    SmallVector<OpMethodParameter, 2> paramList;
+    paramList.emplace_back(op.getCppClassName(), "op");
+    paramList.emplace_back("::mlir::PatternRewriter &", "rewriter");
+    opClass.addMethodAndPrune("::mlir::LogicalResult", "canonicalize",
+                              OpMethod::MP_StaticDeclaration,
+                              std::move(paramList));
+  }
+
+  // We get a prototype for 'getCanonicalizationPatterns' if requested directly
+  // or if using a 'canonicalize' method.
+  bool hasCanonicalizer = def.getValueAsBit("hasCanonicalizer");
+  if (!hasCanonicalizeMethod && !hasCanonicalizer)
     return;
 
+  // We get a body for 'getCanonicalizationPatterns' when using a 'canonicalize'
+  // method, but not implementing 'getCanonicalizationPatterns' manually.
+  bool hasBody = hasCanonicalizeMethod && !hasCanonicalizer;
+
+  // Add a signature for getCanonicalizationPatterns if implemented by the
+  // dialect or if synthesized to call 'canonicalize'.
   SmallVector<OpMethodParameter, 2> paramList;
   paramList.emplace_back("::mlir::RewritePatternSet &", "results");
   paramList.emplace_back("::mlir::MLIRContext *", "context");
-  opClass.addMethodAndPrune("void", "getCanonicalizationPatterns",
-                            OpMethod::MP_StaticDeclaration,
-                            std::move(paramList));
+  auto kind = hasBody ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
+  auto *method = opClass.addMethodAndPrune(
+      "void", "getCanonicalizationPatterns", kind, std::move(paramList));
+
+  // If synthesizing the method, fill it in.
+  if (hasBody)
+    method->body() << "  results.add(canonicalize);\n";
 }
 
 void OpEmitter::genFolderDecls() {
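
For readers skimming the `genCanonicalizerDecls` hunk above, the way the two
ODS bits combine can be summarized as a small truth table. The sketch below
restates that logic as a standalone function; the struct and function names
(`CanonicalizerDecls`, `planCanonicalizerDecls`) are made up for this summary
and do not exist in mlir-tblgen.

```cpp
#include <cassert>

// Hypothetical summary of what the generator emits for one op.
struct CanonicalizerDecls {
  // A static 'canonicalize(op, rewriter)' declaration is emitted.
  bool declaresCanonicalize;
  // A 'getCanonicalizationPatterns(results, context)' signature is emitted.
  bool declaresGetPatterns;
  // The body of 'getCanonicalizationPatterns' is synthesized to register
  // 'canonicalize' rather than left for the dialect author to write.
  bool synthesizesGetPatternsBody;
};

// Mirrors the conditions in the hunk: the 'canonicalize' declaration depends
// only on hasCanonicalizeMethod; 'getCanonicalizationPatterns' is declared if
// either bit is set; its body is generated only when the author did not ask
// to write it by hand via hasCanonicalizer.
constexpr CanonicalizerDecls planCanonicalizerDecls(bool hasCanonicalizer,
                                                    bool hasCanonicalizeMethod) {
  return {hasCanonicalizeMethod,
          hasCanonicalizer || hasCanonicalizeMethod,
          hasCanonicalizeMethod && !hasCanonicalizer};
}
```

Read this way, setting both bits means the op author writes
`getCanonicalizationPatterns` by hand and may still declare a `canonicalize`
method to register from it.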


        

