[Mlir-commits] [mlir] [mlir] load dialects for non-namespaced attrs (PR #96242)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 20 15:01:38 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Mehdi Amini (joker-eph)

<details>
<summary>Changes</summary>

The mlir-translate tool calls into the parser without loading registered dependent dialects, and the parser only loads a dialect if a fully-namespaced attribute from that dialect is present in the textual IR. This causes parsing to break when an op has an attribute that prints/parses without its dialect namespace.

---
Full diff: https://github.com/llvm/llvm-project/pull/96242.diff


6 Files Affected:

- (modified) mlir/test/Target/LLVMIR/test.mlir (+13) 
- (modified) mlir/test/lib/Dialect/Test/CMakeLists.txt (+3-2) 
- (modified) mlir/test/lib/Dialect/Test/TestAttrDefs.td (+9) 
- (modified) mlir/test/lib/Dialect/Test/TestAttributes.h (+1) 
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+6-1) 
- (modified) mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp (+22-2) 


``````````diff
diff --git a/mlir/test/Target/LLVMIR/test.mlir b/mlir/test/Target/LLVMIR/test.mlir
index 0ab1b7267d959..b0834526d7d1b 100644
--- a/mlir/test/Target/LLVMIR/test.mlir
+++ b/mlir/test/Target/LLVMIR/test.mlir
@@ -40,3 +40,16 @@ llvm.func @dialect_attr_translation_multi(%a: i64, %b: i64, %c: i64) -> i64 {
 // CHECK-DAG: ![[MD_ID_ADD]] = !{!"annotation_from_test: add"}
 // CHECK-DAG: ![[MD_ID_MUL]] = !{!"annotation_from_test: mul"}
 // CHECK-DAG: ![[MD_ID_RET]] = !{!"annotation_from_test: ret"}
+
+
+// -----
+
+// This is a regression test for a bug where, during an mlir-translate call, the
+// parser would only load the dialect if the fully namespaced attribute was
+// present in the IR.
+#attr = #test.nested_polynomial<<1 + x**2>>
+// CHECK-LABEL: @parse_correctly
+llvm.func @parse_correctly() {
+  test.containing_int_polynomial_attr #attr
+  llvm.return
+}
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index fab8937809332..967101242e26b 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -16,8 +16,8 @@ mlir_tablegen(TestOpInterfaces.cpp.inc -gen-op-interface-defs)
 add_public_tablegen_target(MLIRTestInterfaceIncGen)
 
 set(LLVM_TARGET_DEFINITIONS TestOps.td)
-mlir_tablegen(TestAttrDefs.h.inc -gen-attrdef-decls)
-mlir_tablegen(TestAttrDefs.cpp.inc -gen-attrdef-defs)
+mlir_tablegen(TestAttrDefs.h.inc -gen-attrdef-decls -attrdefs-dialect=test)
+mlir_tablegen(TestAttrDefs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=test)
 add_public_tablegen_target(MLIRTestAttrDefIncGen)
 
 set(LLVM_TARGET_DEFINITIONS TestTypeDefs.td)
@@ -86,6 +86,7 @@ add_mlir_library(MLIRTestDialect
   MLIRLinalgTransforms
   MLIRLLVMDialect
   MLIRPass
+  MLIRPolynomialDialect
   MLIRReduce
   MLIRTensorDialect
   MLIRTransformUtils
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 12635e107bd42..9e25acf5f5ba4 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -16,6 +16,7 @@
 // To get the test dialect definition.
 include "TestDialect.td"
 include "TestEnumDefs.td"
+include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.td"
 include "mlir/Dialect/Utils/StructuredOpsUtils.td"
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/BuiltinAttributeInterfaces.td"
@@ -351,4 +352,12 @@ def TestCustomFloatAttr : Test_Attr<"TestCustomFloat"> {
   }];
 }
 
+def NestedPolynomialAttr : Test_Attr<"NestedPolynomialAttr"> {
+  let mnemonic = "nested_polynomial";
+  let parameters = (ins Polynomial_IntPolynomialAttr:$poly);
+  let assemblyFormat = [{
+    `<` $poly `>`
+  }];
+}
+
 #endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h
index ef6eae51fdd62..7099bcf317294 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.h
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.h
@@ -17,6 +17,7 @@
 #include <tuple>
 
 #include "TestTraits.h"
+#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Diagnostics.h"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index cce926d2cf63d..c77f76def1f06 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -232,6 +232,11 @@ def FloatElementsAttrOp : TEST_Op<"float_elements_attr"> {
   );
 }
 
+def ContainingIntPolynomialAttrOp : TEST_Op<"containing_int_polynomial_attr"> {
+  let arguments = (ins NestedPolynomialAttr:$attr);
+  let assemblyFormat = "$attr attr-dict";
+}
+
 // A pattern that updates dense<[3.0, 4.0]> to dense<[5.0, 6.0]>.
 // This tests both matching and generating float elements attributes.
 def UpdateFloatElementsAttr : Pat<
@@ -2215,7 +2220,7 @@ def ForwardBufferOp : TEST_Op<"forward_buffer", [Pure]> {
 def ReifyBoundOp : TEST_Op<"reify_bound", [Pure]> {
   let description = [{
     Reify a bound for the given index-typed value or dimension size of a shaped
-    value. "LB", "EQ" and "UB" bounds are supported. If `scalable` is set, 
+    value. "LB", "EQ" and "UB" bounds are supported. If `scalable` is set,
     `vscale_min` and `vscale_max` must be provided, which allows computing
     a bound in terms of "vector.vscale" for a given range of vscale.
   }];
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index abd1fbdaf8c64..50378feeb14ed 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -164,8 +164,9 @@ static const char *const parserErrorStr =
 /// {2}: Code template for printing an error.
 /// {3}: Name of the attribute or type.
 /// {4}: C++ class of the parameter.
+/// {5}: Optional code to preload the dialect for this variable.
 static const char *const variableParser = R"(
-// Parse variable '{0}'
+// Parse variable '{0}'{5}
 _result_{0} = {1};
 if (::mlir::failed(_result_{0})) {{
   {2}"failed to parse {3} parameter '{0}' which is to be a `{4}`");
@@ -411,9 +412,28 @@ void DefFormat::genVariableParser(ParameterElement *el, FmtContext &ctx,
   auto customParser = param.getParser();
   auto parser =
       customParser ? *customParser : StringRef(defaultParameterParser);
+
+  // If the variable points to a dialect specific entity (type or attribute),
+  // we force load the dialect now before trying to parse it.
+  std::string dialectLoading;
+  if (auto *defInit = dyn_cast<llvm::DefInit>(param.getDef())) {
+    auto *dialectValue = defInit->getDef()->getValue("dialect");
+    if (dialectValue) {
+      if (auto *dialectInit =
+              dyn_cast<llvm::DefInit>(dialectValue->getValue())) {
+        Dialect dialect(dialectInit->getDef());
+        auto cppNamespace = dialect.getCppNamespace();
+        std::string name = dialect.getCppClassName();
+        dialectLoading = ("\nodsParser.getContext()->getOrLoadDialect<" +
+                          cppNamespace + "::" + name + ">();")
+                             .str();
+      }
+    }
+  }
   os << formatv(variableParser, param.getName(),
                 tgfmt(parser, &ctx, param.getCppStorageType()),
-                tgfmt(parserErrorStr, &ctx), def.getName(), param.getCppType());
+                tgfmt(parserErrorStr, &ctx), def.getName(), param.getCppType(),
+                dialectLoading);
 }
 
 void DefFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,

``````````

</details>


https://github.com/llvm/llvm-project/pull/96242


More information about the Mlir-commits mailing list