[Mlir-commits] [mlir] [mlir] load dialect in parser for optional parameters (PR #96667)
Jeremy Kun
llvmlistbot at llvm.org
Wed Jul 3 10:35:17 PDT 2024
https://github.com/j2kun updated https://github.com/llvm/llvm-project/pull/96667
>From 26dc3a9793975e1249c085828014170f84dd86c6 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Tue, 25 Jun 2024 10:00:44 -0700
Subject: [PATCH 1/3] Support parser dialect loading for attribute
OptionalParameter
---
mlir/include/mlir/IR/DialectImplementation.h | 21 +++++++++++++++++++
mlir/test/IR/parser.mlir | 2 +-
mlir/test/lib/Dialect/Test/TestAttrDefs.td | 4 ++--
mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp | 11 ++++++++++
.../tools/mlir-tblgen/AttrOrTypeFormatGen.cpp | 8 ++++---
5 files changed, 40 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h
index 1e4f7f787a1ee..549b1226166bb 100644
--- a/mlir/include/mlir/IR/DialectImplementation.h
+++ b/mlir/include/mlir/IR/DialectImplementation.h
@@ -14,8 +14,23 @@
#ifndef MLIR_IR_DIALECTIMPLEMENTATION_H
#define MLIR_IR_DIALECTIMPLEMENTATION_H
#include "mlir/IR/OpImplementation.h"
+#include <type_traits>
+namespace mlir {
+
+// True iff T has a static ::llvm::StringLiteral member named `dialectName`.
+template <typename T, typename = void>
+struct HasStaticDialectName : std::false_type {};
+
+template <typename T>
+struct HasStaticDialectName<
+    T, std::enable_if_t<std::is_same_v<
+           ::llvm::StringLiteral, std::decay_t<decltype(T::dialectName)>>>>
+    : std::true_type {};
+
+} // namespace mlir
+
namespace mlir {
//===----------------------------------------------------------------------===//
@@ -63,6 +78,9 @@ struct FieldParser<
AttributeT, std::enable_if_t<std::is_base_of<Attribute, AttributeT>::value,
AttributeT>> {
static FailureOr<AttributeT> parse(AsmParser &parser) {
+ if constexpr (HasStaticDialectName<AttributeT>::value) {
+ parser.getContext()->getOrLoadDialect(AttributeT::dialectName);
+ }
AttributeT value;
if (parser.parseCustomAttributeWithFallback(value))
return failure();
@@ -112,6 +130,9 @@ struct FieldParser<
std::enable_if_t<std::is_base_of<Attribute, AttributeT>::value,
std::optional<AttributeT>>> {
static FailureOr<std::optional<AttributeT>> parse(AsmParser &parser) {
+ if constexpr (HasStaticDialectName<AttributeT>::value) {
+ parser.getContext()->getOrLoadDialect(AttributeT::dialectName);
+ }
AttributeT attr;
OptionalParseResult result = parser.parseOptionalAttribute(attr);
if (result.has_value()) {
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index e0d7f2a5dd0cb..dfdfe069d1914 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1470,7 +1470,7 @@ test.format_optional_result_d_op : f80
// This is a testing that a non-qualified attribute in a custom format
// correctly preload the dialect before creating the attribute.
-#attr = #test.nested_polynomial<<1 + x**2>>
+#attr = #test.nested_polynomial<poly=<1 + x**2>>
// CHECK-lABLE: @parse_correctly
llvm.func @parse_correctly() {
test.containing_int_polynomial_attr #attr
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 9e25acf5f5ba4..471b78e489532 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -354,9 +354,9 @@ def TestCustomFloatAttr : Test_Attr<"TestCustomFloat"> {
def NestedPolynomialAttr : Test_Attr<"NestedPolynomialAttr"> {
let mnemonic = "nested_polynomial";
- let parameters = (ins Polynomial_IntPolynomialAttr:$poly);
+ let parameters = (ins OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">:$poly);
let assemblyFormat = [{
- `<` $poly `>`
+ `<` struct(params) `>`
}];
}
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 55bc0714c20ec..f7e31b180c370 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -89,6 +89,8 @@ class DefGen {
void emitTopLevelDeclarations();
/// Emit the function that returns the type or attribute name.
void emitName();
+ /// Emit the dialect name s a static member variable.
+ void emitDialectName();
/// Emit attribute or type builders.
void emitBuilders();
/// Emit a verifier for the def.
@@ -184,6 +186,8 @@ DefGen::DefGen(const AttrOrTypeDef &def)
emitBuilders();
// Emit the type name.
emitName();
+ // Emit the dialect name.
+ emitDialectName();
// Emit the verifier.
if (storageCls && def.genVerifyDecl())
emitVerifier();
@@ -281,6 +285,13 @@ void DefGen::emitName() {
defCls.declare<ExtraClassDeclaration>(std::move(nameDecl));
}
+void DefGen::emitDialectName() {
+  // Lets generic parsers (e.g. FieldParser) load the owning dialect first.
+  defCls.declare<ExtraClassDeclaration>(strfmt(
+      "static constexpr ::llvm::StringLiteral dialectName = \"{0}\";\n",
+      def.getDialect().getName()));
+}
+
void DefGen::emitBuilders() {
if (!def.skipDefaultBuilders()) {
emitDefaultBuilder();
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index 50378feeb14ed..91c4ce35889b5 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -424,9 +424,11 @@ void DefFormat::genVariableParser(ParameterElement *el, FmtContext &ctx,
Dialect dialect(dialectInit->getDef());
auto cppNamespace = dialect.getCppNamespace();
std::string name = dialect.getCppClassName();
- dialectLoading = ("\nodsParser.getContext()->getOrLoadDialect<" +
- cppNamespace + "::" + name + ">();")
- .str();
+ if (name != "BuiltinDialect" || cppNamespace != "::mlir") {
+ dialectLoading = ("\nodsParser.getContext()->getOrLoadDialect<" +
+ cppNamespace + "::" + name + ">();")
+ .str();
+ }
}
}
}
>From 1d6cb25df368c68b1759561ec546046d9c925c01 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <kun.jeremy at gmail.com>
Date: Fri, 28 Jun 2024 15:19:15 -0700
Subject: [PATCH 2/3] Update mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
Co-authored-by: Oleksandr "Alex" Zinenko <ftynse at gmail.com>
---
mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index f7e31b180c370..5e072710eda52 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -89,7 +89,7 @@ class DefGen {
void emitTopLevelDeclarations();
/// Emit the function that returns the type or attribute name.
void emitName();
- /// Emit the dialect name s a static member variable.
+ /// Emit the dialect name as a static member variable.
void emitDialectName();
/// Emit attribute or type builders.
void emitBuilders();
>From 7b59a7d2e6506e44debd4a212a35b8ede4817042 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Wed, 3 Jul 2024 10:34:26 -0700
Subject: [PATCH 3/3] make second test case
---
mlir/test/IR/parser.mlir | 12 ------------
mlir/test/IR/parser_dialect_loading.mlir | 19 +++++++++++++++++++
mlir/test/lib/Dialect/Test/TestAttrDefs.td | 9 +++++++++
mlir/test/lib/Dialect/Test/TestOps.td | 5 +++++
4 files changed, 33 insertions(+), 12 deletions(-)
create mode 100644 mlir/test/IR/parser_dialect_loading.mlir
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index dfdfe069d1914..cace1fefa43d6 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1464,15 +1464,3 @@ test.dialect_custom_format_fallback custom_format_fallback
// Check that an op with an optional result parses f80 as type.
// CHECK: test.format_optional_result_d_op : f80
test.format_optional_result_d_op : f80
-
-
-// -----
-
-// This is a testing that a non-qualified attribute in a custom format
-// correctly preload the dialect before creating the attribute.
-#attr = #test.nested_polynomial<poly=<1 + x**2>>
-// CHECK-lABLE: @parse_correctly
-llvm.func @parse_correctly() {
- test.containing_int_polynomial_attr #attr
- llvm.return
-}
diff --git a/mlir/test/IR/parser_dialect_loading.mlir b/mlir/test/IR/parser_dialect_loading.mlir
new file mode 100644
index 0000000000000..b9c2d30cf3c98
--- /dev/null
+++ b/mlir/test/IR/parser_dialect_loading.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt -allow-unregistered-dialect --split-input-file %s | FileCheck %s
+
+// Test that a non-qualified attribute in a custom format correctly
+// preloads the dialect before creating the attribute.
+#attr = #test.nested_polynomial<poly=<1 + x**2>>
+// CHECK-LABEL: @parse_correctly
+llvm.func @parse_correctly() {
+ test.containing_int_polynomial_attr #attr
+ llvm.return
+}
+
+// -----
+
+#attr2 = #test.nested_polynomial2<poly=<1 + x**2>> // Same check; `poly` is an OptionalParameter here.
+// CHECK-LABEL: @parse_correctly_2
+llvm.func @parse_correctly_2() {
+ test.containing_int_polynomial_attr2 #attr2
+ llvm.return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 471b78e489532..a0a1cd30ed8ae 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -354,10 +354,19 @@ def TestCustomFloatAttr : Test_Attr<"TestCustomFloat"> {
+// Test attribute nesting a polynomial-dialect attribute as a required
+// parameter, written without its dialect prefix (e.g. `poly=<1 + x**2>`).
def NestedPolynomialAttr : Test_Attr<"NestedPolynomialAttr"> {
let mnemonic = "nested_polynomial";
+ let parameters = (ins Polynomial_IntPolynomialAttr:$poly);
+ let assemblyFormat = [{
+ `<` struct(params) `>`
+ }];
+}
+
+// Same as NestedPolynomialAttr, but `poly` is declared as an OptionalParameter.
+def NestedPolynomialAttr2 : Test_Attr<"NestedPolynomialAttr2"> {
+ let mnemonic = "nested_polynomial2";
let parameters = (ins OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">:$poly);
let assemblyFormat = [{
`<` struct(params) `>`
}];
}
+
#endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index c77f76def1f06..78e12d33503ab 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -237,6 +237,11 @@ def ContainingIntPolynomialAttrOp : TEST_Op<"containing_int_polynomial_attr"> {
let assemblyFormat = "$attr attr-dict";
}
+// Test op carrying a NestedPolynomialAttr2 (optional nested polynomial attr).
+def ContainingIntPolynomialAttr2Op : TEST_Op<"containing_int_polynomial_attr2"> {
+ let arguments = (ins NestedPolynomialAttr2:$attr);
+ let assemblyFormat = "$attr attr-dict";
+}
+
// A pattern that updates dense<[3.0, 4.0]> to dense<[5.0, 6.0]>.
// This tests both matching and generating float elements attributes.
def UpdateFloatElementsAttr : Pat<
More information about the Mlir-commits
mailing list