[Mlir-commits] [mlir] 108ca7a - [mlir] Support dialect-wide canonicalization pattern registration

Matthias Springer llvmlistbot at llvm.org
Thu May 27 01:35:36 PDT 2021


Author: Matthias Springer
Date: 2021-05-27T17:35:21+09:00
New Revision: 108ca7a7e73ca6d5f4c17a8291d0e94cd9f740d3

URL: https://github.com/llvm/llvm-project/commit/108ca7a7e73ca6d5f4c17a8291d0e94cd9f740d3
DIFF: https://github.com/llvm/llvm-project/commit/108ca7a7e73ca6d5f4c17a8291d0e94cd9f740d3.diff

LOG: [mlir] Support dialect-wide canonicalization pattern registration

* Add `hasCanonicalizer` option to Dialect.
* Initialize canonicalizer with dialect-wide canonicalization patterns.
* Add test case to TestDialect.

Dialect-wide canonicalization patterns are useful when a canonicalization pattern is not conceptually associated with any single operation, i.e., when it should not be registered as part of an operation's `getCanonicalizationPatterns` function. This is the case, for example, for canonicalization patterns that match an op interface.

Differential Revision: https://reviews.llvm.org/D103226

Added: 
    

Modified: 
    mlir/include/mlir/IR/Dialect.h
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/TableGen/Dialect.h
    mlir/lib/TableGen/Dialect.cpp
    mlir/lib/Transforms/Canonicalizer.cpp
    mlir/test/Transforms/test-canonicalize.mlir
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/tools/mlir-tblgen/DialectGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index 26ae5f9d9033a..46782c4353d51 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -68,6 +68,13 @@ class Dialect {
   /// These are represented with OpaqueType.
   bool allowsUnknownTypes() const { return unknownTypesAllowed; }
 
+  /// Register dialect-wide canonicalization patterns. This method should only
+  /// be used to register canonicalization patterns that do not conceptually
+  /// belong to any single operation in the dialect. (Patterns that do belong
+  /// to a single operation should use that op's canonicalizer.) E.g.,
+  /// canonicalization patterns for op interfaces should be registered here.
+  virtual void getCanonicalizationPatterns(RewritePatternSet &results) const {}
+
   /// Registered hook to materialize a single constant operation from a given
   /// attribute value with the desired resultant type. This method should use
   /// the provided builder to create the operation without changing the

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 885ead981c774..952e3a36e457b 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -275,6 +275,9 @@ class Dialect {
 
   // If this dialect overrides the hook for op interface fallback.
   bit hasOperationInterfaceFallback = 0;
+
+  // If this dialect overrides the hook for canonicalization patterns.
+  bit hasCanonicalizer = 0;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h
index 35a9e7ba4c9b6..609bf4e2ec466 100644
--- a/mlir/include/mlir/TableGen/Dialect.h
+++ b/mlir/include/mlir/TableGen/Dialect.h
@@ -51,6 +51,9 @@ class Dialect {
   // Returns the dialects extra class declaration code.
   llvm::Optional<StringRef> getExtraClassDeclaration() const;
 
+  /// Returns true if this dialect has a canonicalizer.
+  bool hasCanonicalizer() const;
+
   // Returns true if this dialect has a constant materializer.
   bool hasConstantMaterializer() const;
 

diff  --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp
index 6b88f8a40bc1c..0cdd9d6d856e1 100644
--- a/mlir/lib/TableGen/Dialect.cpp
+++ b/mlir/lib/TableGen/Dialect.cpp
@@ -61,6 +61,10 @@ llvm::Optional<StringRef> Dialect::getExtraClassDeclaration() const {
   return value.empty() ? llvm::Optional<StringRef>() : value;
 }
 
+bool Dialect::hasCanonicalizer() const {
+  return def->getValueAsBit("hasCanonicalizer");
+}
+
 bool Dialect::hasConstantMaterializer() const {
   return def->getValueAsBit("hasConstantMaterializer");
 }

diff  --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 1f27fc7203005..3ddd24b824ee3 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -35,6 +35,8 @@ struct Canonicalizer : public CanonicalizerBase<Canonicalizer> {
   /// execution.
   LogicalResult initialize(MLIRContext *context) override {
     RewritePatternSet owningPatterns(context);
+    for (auto *dialect : context->getLoadedDialects())
+      dialect->getCanonicalizationPatterns(owningPatterns);
     for (auto *op : context->getRegisteredOperations())
       op->getCanonicalizationPatterns(owningPatterns, context);
     patterns = std::move(owningPatterns);

diff  --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir
index 0b945136fdd65..d245a560bcfaa 100644
--- a/mlir/test/Transforms/test-canonicalize.mlir
+++ b/mlir/test/Transforms/test-canonicalize.mlir
@@ -104,4 +104,12 @@ func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
   // CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
   // CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
   return %1, %2, %3, %4, %5 : index, index, index, index, index
-}
\ No newline at end of file
+}
+
+// CHECK-LABEL: test_dialect_canonicalizer
+func @test_dialect_canonicalizer() -> (i32) {
+  %0 = "test.dialect_canonicalizable"() : () -> (i32)
+  // CHECK: %[[CST:.*]] = constant 42 : i32
+  // CHECK: return %[[CST]]
+  return %0 : i32
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 79e43eeba0992..8e9e0f2645229 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -287,6 +287,23 @@ TestBranchOp::getMutableSuccessorOperands(unsigned index) {
   return targetOperandsMutable();
 }
 
+//===----------------------------------------------------------------------===//
+// TestDialectCanonicalizerOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult
+dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
+                               PatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(),
+                                          rewriter.getI32IntegerAttr(42));
+  return success();
+}
+
+void TestDialect::getCanonicalizationPatterns(
+    RewritePatternSet &results) const {
+  results.add(&dialectCanonicalizationPattern);
+}
+
 //===----------------------------------------------------------------------===//
 // TestFoldToCallOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 795a5af35babe..54ea3bff9bfcd 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -25,6 +25,7 @@ include "TestInterfaces.td"
 def Test_Dialect : Dialect {
   let name = "test";
   let cppNamespace = "::mlir::test";
+  let hasCanonicalizer = 1;
   let hasConstantMaterializer = 1;
   let hasOperationAttrVerify = 1;
   let hasRegionArgAttrVerify = 1;
@@ -966,6 +967,11 @@ def TestPassthroughFold : TEST_Op<"passthrough_fold"> {
   let hasFolder = 1;
 }
 
+def TestDialectCanonicalizerOp : TEST_Op<"dialect_canonicalizable"> {
+  let arguments = (ins);
+  let results = (outs I32);
+}
+
 //===----------------------------------------------------------------------===//
 // Test Patterns (Symbol Binding)
 

diff  --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index 350cc6f18f9cb..dbbc32b3b0f4e 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -107,6 +107,13 @@ static const char *const typeParserDecl = R"(
                  ::mlir::DialectAsmPrinter &os) const override;
 )";
 
+/// The code block for the canonicalization pattern registration hook.
+static const char *const canonicalizerDecl = R"(
+  /// Register canonicalization patterns.
+  void getCanonicalizationPatterns(
+      ::mlir::RewritePatternSet &results) const override;
+)";
+
 /// The code block for the constant materializer hook.
 static const char *const constantMaterializerDecl = R"(
   /// Materialize a single constant operation from a given attribute value with
@@ -180,6 +187,8 @@ static void emitDialectDecl(Dialect &dialect,
     os << typeParserDecl;
 
   // Add the decls for the various features of the dialect.
+  if (dialect.hasCanonicalizer())
+    os << canonicalizerDecl;
   if (dialect.hasConstantMaterializer())
     os << constantMaterializerDecl;
   if (dialect.hasOperationAttrVerify())


        


More information about the Mlir-commits mailing list