[Mlir-commits] [mlir] 07c157a - [mlir] load dialect in parser for optional parameters (#96667)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jul 7 09:44:10 PDT 2024


Author: Jeremy Kun
Date: 2024-07-07T09:44:07-07:00
New Revision: 07c157a43534744bff8b9cf03a5ec8d19717ba72

URL: https://github.com/llvm/llvm-project/commit/07c157a43534744bff8b9cf03a5ec8d19717ba72
DIFF: https://github.com/llvm/llvm-project/commit/07c157a43534744bff8b9cf03a5ec8d19717ba72.diff

LOG: [mlir] load dialect in parser for optional parameters (#96667)

https://github.com/llvm/llvm-project/pull/96242 fixed an issue where the
auto-generated parsers were not loading dialects whose namespaces are
not present in the textual IR. This required the attribute parameter to
be a tablegen def with its dialect information attached.

This fails when using parameter wrapper classes like
`OptionalParameter`. This came up because `RingAttr` uses
`OptionalParameter` for its second and third attributes.
`OptionalParameter` takes as input the C++ type as a string instead of
the tablegen def, and so it doesn't have a dialect member value to
trigger the fix from https://github.com/llvm/llvm-project/pull/96242.
The docs on this topic say the appropriate solution is to overload
`FieldParser` for a particular type.

This PR updates `FieldParser` for generic attributes to load the dialect
on demand. This requires `mlir-tblgen` to emit a `dialectName` static
field on the generated attribute class, and check for it with template
metaprogramming, since not all attribute types go through `mlir-tblgen`.

---------

Co-authored-by: Jeremy Kun <j2kun at users.noreply.github.com>
Co-authored-by: Oleksandr "Alex" Zinenko <ftynse at gmail.com>

Added: 
    mlir/test/IR/parser_dialect_loading.mlir

Modified: 
    mlir/include/mlir/IR/DialectImplementation.h
    mlir/test/IR/parser.mlir
    mlir/test/lib/Dialect/Test/TestAttrDefs.td
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
    mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h
index 1e4f7f787a1ee..303564bf66470 100644
--- a/mlir/include/mlir/IR/DialectImplementation.h
+++ b/mlir/include/mlir/IR/DialectImplementation.h
@@ -15,6 +15,22 @@
 #define MLIR_IR_DIALECTIMPLEMENTATION_H
 
 #include "mlir/IR/OpImplementation.h"
+#include <type_traits>
+
+namespace {
+
+// reference https://stackoverflow.com/a/16000226
+template <typename T, typename = void>
+struct HasStaticDialectName : std::false_type {};
+
+template <typename T>
+struct HasStaticDialectName<
+    T, typename std::enable_if<
+           std::is_same<::llvm::StringLiteral,
+                        std::decay_t<decltype(T::dialectName)>>::value,
+           void>::type> : std::true_type {};
+
+} // namespace
 
 namespace mlir {
 
@@ -63,6 +79,9 @@ struct FieldParser<
     AttributeT, std::enable_if_t<std::is_base_of<Attribute, AttributeT>::value,
                                  AttributeT>> {
   static FailureOr<AttributeT> parse(AsmParser &parser) {
+    if constexpr (HasStaticDialectName<AttributeT>::value) {
+      parser.getContext()->getOrLoadDialect(AttributeT::dialectName);
+    }
     AttributeT value;
     if (parser.parseCustomAttributeWithFallback(value))
       return failure();
@@ -112,6 +131,9 @@ struct FieldParser<
     std::enable_if_t<std::is_base_of<Attribute, AttributeT>::value,
                      std::optional<AttributeT>>> {
   static FailureOr<std::optional<AttributeT>> parse(AsmParser &parser) {
+    if constexpr (HasStaticDialectName<AttributeT>::value) {
+      parser.getContext()->getOrLoadDialect(AttributeT::dialectName);
+    }
     AttributeT attr;
     OptionalParseResult result = parser.parseOptionalAttribute(attr);
     if (result.has_value()) {

diff  --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index e0d7f2a5dd0cb..cace1fefa43d6 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1464,15 +1464,3 @@ test.dialect_custom_format_fallback custom_format_fallback
 // Check that an op with an optional result parses f80 as type.
 // CHECK: test.format_optional_result_d_op : f80
 test.format_optional_result_d_op : f80
-
-
-// -----
-
-// This is a testing that a non-qualified attribute in a custom format
-// correctly preload the dialect before creating the attribute.
-#attr = #test.nested_polynomial<<1 + x**2>>
-// CHECK-lABLE: @parse_correctly
-llvm.func @parse_correctly() {
-  test.containing_int_polynomial_attr #attr
-  llvm.return
-}

diff  --git a/mlir/test/IR/parser_dialect_loading.mlir b/mlir/test/IR/parser_dialect_loading.mlir
new file mode 100644
index 0000000000000..b9c2d30cf3c98
--- /dev/null
+++ b/mlir/test/IR/parser_dialect_loading.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt -allow-unregistered-dialect --split-input-file %s | FileCheck %s
+
+// This is a testing that a non-qualified attribute in a custom format
+// correctly preload the dialect before creating the attribute.
+#attr = #test.nested_polynomial<poly=<1 + x**2>>
+// CHECK-LABEL: @parse_correctly
+llvm.func @parse_correctly() {
+  test.containing_int_polynomial_attr #attr
+  llvm.return
+}
+
+// -----
+
+#attr2 = #test.nested_polynomial2<poly=<1 + x**2>>
+// CHECK-LABEL: @parse_correctly_2
+llvm.func @parse_correctly_2() {
+  test.containing_int_polynomial_attr2 #attr2
+  llvm.return
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 9e25acf5f5ba4..a0a1cd30ed8ae 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -356,8 +356,17 @@ def NestedPolynomialAttr : Test_Attr<"NestedPolynomialAttr"> {
   let mnemonic = "nested_polynomial";
   let parameters = (ins Polynomial_IntPolynomialAttr:$poly);
   let assemblyFormat = [{
-    `<` $poly `>`
+    `<` struct(params) `>`
   }];
 }
 
+def NestedPolynomialAttr2 : Test_Attr<"NestedPolynomialAttr2"> {
+  let mnemonic = "nested_polynomial2";
+  let parameters = (ins OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">:$poly);
+  let assemblyFormat = [{
+    `<` struct(params) `>`
+  }];
+}
+
+
 #endif // TEST_ATTRDEFS

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index e1ec1428ee6d6..9450764fcb1d5 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -237,6 +237,11 @@ def ContainingIntPolynomialAttrOp : TEST_Op<"containing_int_polynomial_attr"> {
   let assemblyFormat = "$attr attr-dict";
 }
 
+def ContainingIntPolynomialAttr2Op : TEST_Op<"containing_int_polynomial_attr2"> {
+  let arguments = (ins NestedPolynomialAttr2:$attr);
+  let assemblyFormat = "$attr attr-dict";
+}
+
 // A pattern that updates dense<[3.0, 4.0]> to dense<[5.0, 6.0]>.
 // This tests both matching and generating float elements attributes.
 def UpdateFloatElementsAttr : Pat<

diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index ea0d152cc94d4..8cc8314418104 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -89,6 +89,8 @@ class DefGen {
   void emitTopLevelDeclarations();
   /// Emit the function that returns the type or attribute name.
   void emitName();
+  /// Emit the dialect name as a static member variable.
+  void emitDialectName();
   /// Emit attribute or type builders.
   void emitBuilders();
   /// Emit a verifier for the def.
@@ -184,6 +186,8 @@ DefGen::DefGen(const AttrOrTypeDef &def)
     emitBuilders();
   // Emit the type name.
   emitName();
+  // Emit the dialect name.
+  emitDialectName();
   // Emit the verifier.
   if (storageCls && def.genVerifyDecl())
     emitVerifier();
@@ -281,6 +285,13 @@ void DefGen::emitName() {
   defCls.declare<ExtraClassDeclaration>(std::move(nameDecl));
 }
 
+void DefGen::emitDialectName() {
+  std::string decl =
+      strfmt("static constexpr ::llvm::StringLiteral dialectName = \"{0}\";\n",
+             def.getDialect().getName());
+  defCls.declare<ExtraClassDeclaration>(std::move(decl));
+}
+
 void DefGen::emitBuilders() {
   if (!def.skipDefaultBuilders()) {
     emitDefaultBuilder();

diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index ffd5a3913cf18..dacc20b6ba208 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -423,9 +423,11 @@ void DefFormat::genVariableParser(ParameterElement *el, FmtContext &ctx,
         Dialect dialect(dialectInit->getDef());
         auto cppNamespace = dialect.getCppNamespace();
         std::string name = dialect.getCppClassName();
-        dialectLoading = ("\nodsParser.getContext()->getOrLoadDialect<" +
-                          cppNamespace + "::" + name + ">();")
-                             .str();
+        if (name != "BuiltinDialect" || cppNamespace != "::mlir") {
+          dialectLoading = ("\nodsParser.getContext()->getOrLoadDialect<" +
+                            cppNamespace + "::" + name + ">();")
+                               .str();
+        }
       }
     }
   }


        


More information about the Mlir-commits mailing list