[Mlir-commits] [mlir] 108ca7a - [mlir] Support dialect-wide canonicalization pattern registration
Matthias Springer
llvmlistbot at llvm.org
Thu May 27 01:35:36 PDT 2021
Author: Matthias Springer
Date: 2021-05-27T17:35:21+09:00
New Revision: 108ca7a7e73ca6d5f4c17a8291d0e94cd9f740d3
URL: https://github.com/llvm/llvm-project/commit/108ca7a7e73ca6d5f4c17a8291d0e94cd9f740d3
DIFF: https://github.com/llvm/llvm-project/commit/108ca7a7e73ca6d5f4c17a8291d0e94cd9f740d3.diff
LOG: [mlir] Support dialect-wide canonicalization pattern registration
* Add a `hasCanonicalizer` option to the ODS `Dialect` class.
* Initialize the canonicalizer pass with dialect-wide canonicalization patterns (in addition to per-op patterns).
* Add a test case to TestDialect.
Dialect-wide canonicalization patterns are useful when a canonicalization pattern is not conceptually associated with any single operation, i.e., when it should not be registered as part of an operation's `getCanonicalizationPatterns` function. For example, this is the case for canonicalization patterns that match an op interface.
Differential Revision: https://reviews.llvm.org/D103226
Added:
Modified:
mlir/include/mlir/IR/Dialect.h
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/Dialect.h
mlir/lib/TableGen/Dialect.cpp
mlir/lib/Transforms/Canonicalizer.cpp
mlir/test/Transforms/test-canonicalize.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/tools/mlir-tblgen/DialectGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index 26ae5f9d9033a..46782c4353d51 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -68,6 +68,13 @@ class Dialect {
/// These are represented with OpaqueType.
bool allowsUnknownTypes() const { return unknownTypesAllowed; }
+ /// Register dialect-wide canonicalization patterns. This method should only
+ /// be used to register canonicalization patterns that do not conceptually
+ /// belong to any single operation in the dialect. (Patterns that do belong
+ /// to a single op should go in that op's canonicalizer instead.) E.g.,
+ /// canonicalization patterns for op interfaces should be registered here.
+ virtual void getCanonicalizationPatterns(RewritePatternSet &results) const {}
+
/// Registered hook to materialize a single constant operation from a given
/// attribute value with the desired resultant type. This method should use
/// the provided builder to create the operation without changing the
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 885ead981c774..952e3a36e457b 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -275,6 +275,9 @@ class Dialect {
// If this dialect overrides the hook for op interface fallback.
bit hasOperationInterfaceFallback = 0;
+
+ // If this dialect overrides the hook for canonicalization patterns.
+ bit hasCanonicalizer = 0;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h
index 35a9e7ba4c9b6..609bf4e2ec466 100644
--- a/mlir/include/mlir/TableGen/Dialect.h
+++ b/mlir/include/mlir/TableGen/Dialect.h
@@ -51,6 +51,9 @@ class Dialect {
// Returns the dialects extra class declaration code.
llvm::Optional<StringRef> getExtraClassDeclaration() const;
+ /// Returns true if this dialect has a canonicalizer.
+ bool hasCanonicalizer() const;
+
// Returns true if this dialect has a constant materializer.
bool hasConstantMaterializer() const;
diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp
index 6b88f8a40bc1c..0cdd9d6d856e1 100644
--- a/mlir/lib/TableGen/Dialect.cpp
+++ b/mlir/lib/TableGen/Dialect.cpp
@@ -61,6 +61,10 @@ llvm::Optional<StringRef> Dialect::getExtraClassDeclaration() const {
return value.empty() ? llvm::Optional<StringRef>() : value;
}
+bool Dialect::hasCanonicalizer() const {
+ return def->getValueAsBit("hasCanonicalizer");
+}
+
bool Dialect::hasConstantMaterializer() const {
return def->getValueAsBit("hasConstantMaterializer");
}
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 1f27fc7203005..3ddd24b824ee3 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -35,6 +35,8 @@ struct Canonicalizer : public CanonicalizerBase<Canonicalizer> {
/// execution.
LogicalResult initialize(MLIRContext *context) override {
RewritePatternSet owningPatterns(context);
+ for (auto *dialect : context->getLoadedDialects())
+ dialect->getCanonicalizationPatterns(owningPatterns);
for (auto *op : context->getRegisteredOperations())
op->getCanonicalizationPatterns(owningPatterns, context);
patterns = std::move(owningPatterns);
diff --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir
index 0b945136fdd65..d245a560bcfaa 100644
--- a/mlir/test/Transforms/test-canonicalize.mlir
+++ b/mlir/test/Transforms/test-canonicalize.mlir
@@ -104,4 +104,12 @@ func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
return %1, %2, %3, %4, %5 : index, index, index, index, index
-}
\ No newline at end of file
+}
+
+// CHECK-LABEL: test_dialect_canonicalizer
+func @test_dialect_canonicalizer() -> (i32) {
+ %0 = "test.dialect_canonicalizable"() : () -> (i32)
+ // CHECK: %[[CST:.*]] = constant 42 : i32
+ // CHECK: return %[[CST]]
+ return %0 : i32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 79e43eeba0992..8e9e0f2645229 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -287,6 +287,23 @@ TestBranchOp::getMutableSuccessorOperands(unsigned index) {
return targetOperandsMutable();
}
+//===----------------------------------------------------------------------===//
+// TestDialectCanonicalizerOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult
+dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
+ PatternRewriter &rewriter) {
+ rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(),
+ rewriter.getI32IntegerAttr(42));
+ return success();
+}
+
+void TestDialect::getCanonicalizationPatterns(
+ RewritePatternSet &results) const {
+ results.add(&dialectCanonicalizationPattern);
+}
+
//===----------------------------------------------------------------------===//
// TestFoldToCallOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 795a5af35babe..54ea3bff9bfcd 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -25,6 +25,7 @@ include "TestInterfaces.td"
def Test_Dialect : Dialect {
let name = "test";
let cppNamespace = "::mlir::test";
+ let hasCanonicalizer = 1;
let hasConstantMaterializer = 1;
let hasOperationAttrVerify = 1;
let hasRegionArgAttrVerify = 1;
@@ -966,6 +967,11 @@ def TestPassthroughFold : TEST_Op<"passthrough_fold"> {
let hasFolder = 1;
}
+def TestDialectCanonicalizerOp : TEST_Op<"dialect_canonicalizable"> {
+ let arguments = (ins);
+ let results = (outs I32);
+}
+
//===----------------------------------------------------------------------===//
// Test Patterns (Symbol Binding)
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index 350cc6f18f9cb..dbbc32b3b0f4e 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -107,6 +107,13 @@ static const char *const typeParserDecl = R"(
::mlir::DialectAsmPrinter &os) const override;
)";
+/// The code block for the canonicalization pattern registration hook.
+static const char *const canonicalizerDecl = R"(
+ /// Register canonicalization patterns.
+ void getCanonicalizationPatterns(
+ ::mlir::RewritePatternSet &results) const override;
+)";
+
/// The code block for the constant materializer hook.
static const char *const constantMaterializerDecl = R"(
/// Materialize a single constant operation from a given attribute value with
@@ -180,6 +187,8 @@ static void emitDialectDecl(Dialect &dialect,
os << typeParserDecl;
// Add the decls for the various features of the dialect.
+ if (dialect.hasCanonicalizer())
+ os << canonicalizerDecl;
if (dialect.hasConstantMaterializer())
os << constantMaterializerDecl;
if (dialect.hasOperationAttrVerify())
More information about the Mlir-commits
mailing list