[Mlir-commits] [mlir] [mlir] load dialect in parser for optional parameters (PR #96667)

Jeremy Kun llvmlistbot at llvm.org
Fri Jun 28 15:19:22 PDT 2024


https://github.com/j2kun updated https://github.com/llvm/llvm-project/pull/96667

>From 26dc3a9793975e1249c085828014170f84dd86c6 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Tue, 25 Jun 2024 10:00:44 -0700
Subject: [PATCH 1/2] Support parser dialect loading for attribute
 OptionalParameter

---
 mlir/include/mlir/IR/DialectImplementation.h  | 21 +++++++++++++++++++
 mlir/test/IR/parser.mlir                      |  2 +-
 mlir/test/lib/Dialect/Test/TestAttrDefs.td    |  4 ++--
 mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp   | 11 ++++++++++
 .../tools/mlir-tblgen/AttrOrTypeFormatGen.cpp |  8 ++++---
 5 files changed, 40 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h
index 1e4f7f787a1ee..549b1226166bb 100644
--- a/mlir/include/mlir/IR/DialectImplementation.h
+++ b/mlir/include/mlir/IR/DialectImplementation.h
@@ -14,8 +14,23 @@
 #ifndef MLIR_IR_DIALECTIMPLEMENTATION_H
 #define MLIR_IR_DIALECTIMPLEMENTATION_H
 
 #include "mlir/IR/OpImplementation.h"
+#include <type_traits>
 
+namespace mlir {
+/// Type trait: true iff `T` has a static `dialectName` member whose decayed
+/// type is `::llvm::StringLiteral` (the member mlir-tblgen emits for defs).
+/// Kept out of an anonymous namespace: in a header that would give every TU
+/// its own copy, which inline/template users below would then ODR-violate.
+template <typename T, typename = void>
+struct HasStaticDialectName : std::false_type {};
+template <typename T>
+struct HasStaticDialectName<
+    T, std::enable_if_t<std::is_same_v<
+           ::llvm::StringLiteral, std::decay_t<decltype(T::dialectName)>>>>
+    : std::true_type {};
+} // namespace mlir
+
 namespace mlir {
 
 //===----------------------------------------------------------------------===//
@@ -63,6 +78,9 @@ struct FieldParser<
     AttributeT, std::enable_if_t<std::is_base_of<Attribute, AttributeT>::value,
                                  AttributeT>> {
   static FailureOr<AttributeT> parse(AsmParser &parser) {
+    if constexpr (HasStaticDialectName<AttributeT>::value) {
+      parser.getContext()->getOrLoadDialect(AttributeT::dialectName);
+    }
     AttributeT value;
     if (parser.parseCustomAttributeWithFallback(value))
       return failure();
@@ -112,6 +130,9 @@ struct FieldParser<
     std::enable_if_t<std::is_base_of<Attribute, AttributeT>::value,
                      std::optional<AttributeT>>> {
   static FailureOr<std::optional<AttributeT>> parse(AsmParser &parser) {
+    if constexpr (HasStaticDialectName<AttributeT>::value) {
+      parser.getContext()->getOrLoadDialect(AttributeT::dialectName);
+    }
     AttributeT attr;
     OptionalParseResult result = parser.parseOptionalAttribute(attr);
     if (result.has_value()) {
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index e0d7f2a5dd0cb..dfdfe069d1914 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1470,7 +1470,7 @@ test.format_optional_result_d_op : f80
 
 // This is a testing that a non-qualified attribute in a custom format
 // correctly preload the dialect before creating the attribute.
-#attr = #test.nested_polynomial<<1 + x**2>>
+#attr = #test.nested_polynomial<poly=<1 + x**2>>
 // CHECK-lABLE: @parse_correctly
 llvm.func @parse_correctly() {
   test.containing_int_polynomial_attr #attr
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 9e25acf5f5ba4..471b78e489532 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -354,9 +354,9 @@ def TestCustomFloatAttr : Test_Attr<"TestCustomFloat"> {
 
 def NestedPolynomialAttr : Test_Attr<"NestedPolynomialAttr"> {
   let mnemonic = "nested_polynomial";
-  let parameters = (ins Polynomial_IntPolynomialAttr:$poly);
+  let parameters = (ins OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">:$poly);
   let assemblyFormat = [{
-    `<` $poly `>`
+    `<` struct(params) `>`
   }];
 }
 
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 55bc0714c20ec..f7e31b180c370 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -89,6 +89,8 @@ class DefGen {
   void emitTopLevelDeclarations();
   /// Emit the function that returns the type or attribute name.
   void emitName();
+  /// Emit the dialect name s a static member variable.
+  void emitDialectName();
   /// Emit attribute or type builders.
   void emitBuilders();
   /// Emit a verifier for the def.
@@ -184,6 +186,8 @@ DefGen::DefGen(const AttrOrTypeDef &def)
     emitBuilders();
   // Emit the type name.
   emitName();
+  // Emit the dialect name.
+  emitDialectName();
   // Emit the verifier.
   if (storageCls && def.genVerifyDecl())
     emitVerifier();
@@ -281,6 +285,13 @@ void DefGen::emitName() {
   defCls.declare<ExtraClassDeclaration>(std::move(nameDecl));
 }
 
+void DefGen::emitDialectName() {
+  // Emits: static constexpr ::llvm::StringLiteral dialectName = "<dialect>";
+  defCls.declare<ExtraClassDeclaration>(strfmt(
+      "static constexpr ::llvm::StringLiteral dialectName = \"{0}\";\n",
+      def.getDialect().getName()));
+}
+
 void DefGen::emitBuilders() {
   if (!def.skipDefaultBuilders()) {
     emitDefaultBuilder();
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index 50378feeb14ed..91c4ce35889b5 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -424,9 +424,11 @@ void DefFormat::genVariableParser(ParameterElement *el, FmtContext &ctx,
         Dialect dialect(dialectInit->getDef());
         auto cppNamespace = dialect.getCppNamespace();
         std::string name = dialect.getCppClassName();
-        dialectLoading = ("\nodsParser.getContext()->getOrLoadDialect<" +
-                          cppNamespace + "::" + name + ">();")
-                             .str();
+        if (name != "BuiltinDialect" || cppNamespace != "::mlir") {
+          dialectLoading = ("\nodsParser.getContext()->getOrLoadDialect<" +
+                            cppNamespace + "::" + name + ">();")
+                               .str();
+        }
       }
     }
   }

>From 1d6cb25df368c68b1759561ec546046d9c925c01 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <kun.jeremy at gmail.com>
Date: Fri, 28 Jun 2024 15:19:15 -0700
Subject: [PATCH 2/2] Update mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp

Co-authored-by: Oleksandr "Alex" Zinenko <ftynse at gmail.com>
---
 mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index f7e31b180c370..5e072710eda52 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -89,7 +89,7 @@ class DefGen {
   void emitTopLevelDeclarations();
   /// Emit the function that returns the type or attribute name.
   void emitName();
-  /// Emit the dialect name s a static member variable.
+  /// Emit the dialect name as a static member variable.
   void emitDialectName();
   /// Emit attribute or type builders.
   void emitBuilders();



More information about the Mlir-commits mailing list