[Mlir-commits] [mlir] [mlir] Add `Print(Attr|Type)Qualified` trait (PR #80420)

Markus Böck llvmlistbot at llvm.org
Fri Feb 2 04:12:37 PST 2024


https://github.com/zero9178 updated https://github.com/llvm/llvm-project/pull/80420

>From 94c47995dff5a3a57b7066676dfe40b1caf713dc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Markus=20B=C3=B6ck?= <markus.boeck02 at gmail.com>
Date: Fri, 2 Feb 2024 11:57:16 +0100
Subject: [PATCH 1/2] [mlir] Add `Print(Attr|Type)Qualified` trait

This PR adds a new trait for attributes and types that forces the use of the qualified syntax. More concretely, any attribute or type with the trait must be parsed and printed with the `dialect.mnemonic` prefix.

The motivation for this PR is the dependent PR, where the trait is used to retain backwards compatibility of syntax. Downstream projects may also use the trait if they subjectively prefer the verbose syntax.
---
 mlir/include/mlir/IR/AttrTypeBase.td       |  8 ++++
 mlir/include/mlir/IR/Attributes.h          |  8 +++-
 mlir/include/mlir/IR/OpImplementation.h    | 44 ++++++++++++++++------
 mlir/include/mlir/IR/Types.h               |  5 +++
 mlir/test/IR/always-qualified-trait.mlir   |  4 ++
 mlir/test/lib/Dialect/Test/TestAttrDefs.td |  9 +++++
 mlir/test/lib/Dialect/Test/TestOps.td      |  6 +++
 mlir/test/lib/Dialect/Test/TestTypeDefs.td |  9 +++++
 8 files changed, 80 insertions(+), 13 deletions(-)
 create mode 100644 mlir/test/IR/always-qualified-trait.mlir

diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td
index 91c9283de8bd4..c371ce9e515d8 100644
--- a/mlir/include/mlir/IR/AttrTypeBase.td
+++ b/mlir/include/mlir/IR/AttrTypeBase.td
@@ -38,6 +38,10 @@ class ParamNativeAttrTrait<string prop, string params>
 class GenInternalAttrTrait<string prop> : GenInternalTrait<prop, "Attribute">;
 class PredAttrTrait<string descr, Pred pred> : PredTrait<descr, pred>;
 
+// Trait used to tell the printer and parser to always print and parse
+// instances of the attribute as if it occurs within a `qualified` directive.
+def PrintAttrQualified : NativeAttrTrait<"PrintQualified">;
+
 //===----------------------------------------------------------------------===//
 // TypeTrait definitions
 //===----------------------------------------------------------------------===//
@@ -56,6 +60,10 @@ class ParamNativeTypeTrait<string prop, string params>
 class GenInternalTypeTrait<string prop> : GenInternalTrait<prop, "Type">;
 class PredTypeTrait<string descr, Pred pred> : PredTrait<descr, pred>;
 
+// Trait used to tell the printer and parser to always print and parse
+// instances of the type as if it occurs within a `qualified` directive.
+def PrintTypeQualified : NativeTypeTrait<"PrintQualified">;
+
 //===----------------------------------------------------------------------===//
 // Builders
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index cc0cee6a31183..2122e5f8e1359 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -317,12 +317,18 @@ class AttributeInterface
 // Core AttributeTrait
 //===----------------------------------------------------------------------===//
 
+namespace AttributeTrait {
+
 /// This trait is used to determine if an attribute is mutable or not. It is
 /// attached on an attribute if the corresponding ImplType defines a `mutate`
 /// function with proper signature.
-namespace AttributeTrait {
 template <typename ConcreteType>
 using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
+
+/// Trait used to tell the printer and parser to always print and parse
+/// instances of the attribute as if it occurs within a `qualified` directive.
+template <typename ConcreteAttr>
+struct PrintQualified : TraitBase<ConcreteAttr, PrintQualified> {};
 } // namespace AttributeTrait
 
 } // namespace mlir.
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 5333d7446df5c..402399cf29665 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -132,12 +132,20 @@ class AsmPrinter {
   using detect_has_print_method =
       llvm::is_detected<has_print_method, AttrOrType>;
 
+  /// Constexpr bool that is true if `AttrOrType` should be printed with the
+  /// dialect prefix stripped.
+  template <typename AttrOrType>
+  constexpr static bool shouldPrintStripped =
+      detect_has_print_method<AttrOrType>::value &&
+      (!std::is_base_of_v<AttributeTrait::PrintQualified<AttrOrType>,
+                          AttrOrType> &&
+       !std::is_base_of_v<TypeTrait::PrintQualified<AttrOrType>, AttrOrType>);
+
   /// Print the provided attribute in the context of an operation custom
   /// printer/parser: this will invoke directly the print method on the
   /// attribute class and skip the `#dialect.mnemonic` prefix in most cases.
   template <typename AttrOrType,
-            std::enable_if_t<detect_has_print_method<AttrOrType>::value>
-                *sfinae = nullptr>
+            std::enable_if_t<shouldPrintStripped<AttrOrType>> *sfinae = nullptr>
   void printStrippedAttrOrType(AttrOrType attrOrType) {
     if (succeeded(printAlias(attrOrType)))
       return;
@@ -158,8 +166,7 @@ class AsmPrinter {
   /// method on the attribute class and skip the `#dialect.mnemonic` prefix in
   /// most cases.
   template <typename AttrOrType,
-            std::enable_if_t<detect_has_print_method<AttrOrType>::value>
-                *sfinae = nullptr>
+            std::enable_if_t<shouldPrintStripped<AttrOrType>> *sfinae = nullptr>
   void printStrippedAttrOrType(ArrayRef<AttrOrType> attrOrTypes) {
     llvm::interleaveComma(
         attrOrTypes, getStream(),
@@ -170,8 +177,7 @@ class AsmPrinter {
   /// custom printer in the case where the attribute does not define a print
   /// method.
   template <typename AttrOrType,
-            std::enable_if_t<!detect_has_print_method<AttrOrType>::value>
-                *sfinae = nullptr>
+      std::enable_if_t<!shouldPrintStripped<AttrOrType>> *sfinae = nullptr>
   void printStrippedAttrOrType(AttrOrType attrOrType) {
     *this << attrOrType;
   }
@@ -980,12 +986,19 @@ class AsmParser {
   template <typename AttrType>
   using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>;
 
+  /// Constexpr bool that is true if `AttrType` can be parsed with the dialect
+  /// prefix stripped.
+  template <typename AttrType>
+  constexpr static bool shouldParseAttrStripped =
+      detect_has_parse_method<AttrType>::value &&
+      !std::is_base_of_v<AttributeTrait::PrintQualified<AttrType>, AttrType>;
+
   /// Parse a custom attribute of a given type unless the next token is `#`, in
   /// which case the generic parser is invoked. The parsed attribute is
   /// populated in `result` and also added to the specified attribute list with
   /// the specified name.
   template <typename AttrType>
-  std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
+  std::enable_if_t<shouldParseAttrStripped<AttrType>, ParseResult>
   parseCustomAttributeWithFallback(AttrType &result, Type type,
                                    StringRef attrName, NamedAttrList &attrs) {
     SMLoc loc = getCurrentLocation();
@@ -1012,7 +1025,7 @@ class AsmParser {
 
   /// SFINAE parsing method for Attribute that don't implement a parse method.
   template <typename AttrType>
-  std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
+  std::enable_if_t<!shouldParseAttrStripped<AttrType>, ParseResult>
   parseCustomAttributeWithFallback(AttrType &result, Type type,
                                    StringRef attrName, NamedAttrList &attrs) {
     return parseAttribute(result, type, attrName, attrs);
@@ -1022,7 +1035,7 @@ class AsmParser {
   /// which case the generic parser is invoked. The parsed attribute is
   /// populated in `result`.
   template <typename AttrType>
-  std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
+  std::enable_if_t<shouldParseAttrStripped<AttrType>, ParseResult>
   parseCustomAttributeWithFallback(AttrType &result, Type type = {}) {
     SMLoc loc = getCurrentLocation();
 
@@ -1044,7 +1057,7 @@ class AsmParser {
 
   /// SFINAE parsing method for Attribute that don't implement a parse method.
   template <typename AttrType>
-  std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
+  std::enable_if_t<!shouldParseAttrStripped<AttrType>, ParseResult>
   parseCustomAttributeWithFallback(AttrType &result, Type type = {}) {
     return parseAttribute(result, type);
   }
@@ -1213,11 +1226,18 @@ class AsmParser {
   using detect_type_has_parse_method =
       llvm::is_detected<type_has_parse_method, TypeT>;
 
+  /// Constexpr bool that is true if `TypeT` can be parsed with the dialect
+  /// prefix stripped.
+  template <typename TypeT>
+  constexpr static bool shouldParseTypeStripped =
+      detect_type_has_parse_method<TypeT>::value &&
+      !std::is_base_of_v<TypeTrait::PrintQualified<TypeT>, TypeT>;
+
   /// Parse a custom Type of a given type unless the next token is `#`, in
   /// which case the generic parser is invoked. The parsed Type is
   /// populated in `result`.
   template <typename TypeT>
-  std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult>
+  std::enable_if_t<shouldParseTypeStripped<TypeT>, ParseResult>
   parseCustomTypeWithFallback(TypeT &result) {
     SMLoc loc = getCurrentLocation();
 
@@ -1238,7 +1258,7 @@ class AsmParser {
 
   /// SFINAE parsing method for Type that don't implement a parse method.
   template <typename TypeT>
-  std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult>
+  std::enable_if_t<!shouldParseTypeStripped<TypeT>, ParseResult>
   parseCustomTypeWithFallback(TypeT &result) {
     return parseType(result);
   }
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 46bb733101c12..5c647504f7a76 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -304,6 +304,11 @@ class TypeInterface : public detail::Interface<ConcreteType, Type, Traits, Type,
 namespace TypeTrait {
 template <typename ConcreteType>
 using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
+
+/// Trait used to tell the printer and parser to always print and parse
+/// instances of the type as if it occurs within a `qualified` directive.
+template <typename ConcreteType>
+struct PrintQualified : TraitBase<ConcreteType, PrintQualified> {};
 } // namespace TypeTrait
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/always-qualified-trait.mlir b/mlir/test/IR/always-qualified-trait.mlir
new file mode 100644
index 0000000000000..c029cb20df340
--- /dev/null
+++ b/mlir/test/IR/always-qualified-trait.mlir
@@ -0,0 +1,4 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// CHECK: test.would_print_unqualified #test.always_qualified<5> -> !test.always_qualified<7>
+%0 = test.would_print_unqualified #test.always_qualified<5> -> !test.always_qualified<7>
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 40f035a3e3a4e..5e7f8e290dbdb 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -340,4 +340,13 @@ def TestConditionalAliasAttr : Test_Attr<"TestConditionalAlias"> {
   }];
 }
 
+def TestAlwaysQualifiedAttr : Test_Attr<"TestAlwaysQualified",
+  [PrintAttrQualified]> {
+  let mnemonic = "always_qualified";
+  let parameters = (ins "int":$value);
+  let assemblyFormat = [{
+    `<` $value `>`
+  }];
+}
+
 #endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 91ce0af9cd7fd..e62322e8ad0d0 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3056,4 +3056,10 @@ def TestOpOptionallyImplementingInterface
   let arguments = (ins BoolAttr:$implementsInterface);
 }
 
+def TestOpWouldPrintUnqualified : TEST_Op<"would_print_unqualified"> {
+  let arguments = (ins TestAlwaysQualifiedAttr:$attr);
+  let results = (outs TestAlwaysQualifiedType:$result);
+  let assemblyFormat = "$attr `->` type($result) attr-dict";
+}
+
 #endif // TEST_OPS
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 1957845c842f2..681ab3de440a9 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -391,4 +391,13 @@ def TestRecursiveAlias
   }];
 }
 
+def TestAlwaysQualifiedType : Test_Type<"TestAlwaysQualified",
+  [PrintTypeQualified]> {
+  let mnemonic = "always_qualified";
+  let parameters = (ins "int":$value);
+  let assemblyFormat = [{
+    `<` $value `>`
+  }];
+}
+
 #endif // TEST_TYPEDEFS

>From f9f2f00dc37d979b459c7d9d1dc6f52ec3466206 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Markus=20B=C3=B6ck?= <markus.boeck02 at gmail.com>
Date: Fri, 2 Feb 2024 13:13:01 +0100
Subject: [PATCH 2/2] clang-format

---
 mlir/include/mlir/IR/OpImplementation.h | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 402399cf29665..50e6cc59ca459 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -176,7 +176,8 @@ class AsmPrinter {
   /// SFINAE for printing the provided attribute in the context of an operation
   /// custom printer in the case where the attribute does not define a print
   /// method.
-  template <typename AttrOrType,
+  template <
+      typename AttrOrType,
       std::enable_if_t<!shouldPrintStripped<AttrOrType>> *sfinae = nullptr>
   void printStrippedAttrOrType(AttrOrType attrOrType) {
     *this << attrOrType;



More information about the Mlir-commits mailing list