[Mlir-commits] [mlir] [mlir] Add the ability to override attribute parsing/printing in attr-dicts (PR #103304)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Aug 13 09:11:21 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

<details>
<summary>Changes</summary>

This adds a `parseNamedAttrFn` callback to `AsmParser::parseOptionalAttrDict()`.

If `parseNamedAttrFn` is provided, the default parsing can be overridden for a named attribute. `parseNamedAttrFn` is passed the name of an attribute; if it can parse the attribute, it returns the parsed attribute; otherwise, it returns `failure()`, which indicates that generic parsing should be used. Note: Returning a null Attribute from `parseNamedAttrFn` indicates a parser error.

It also adds `printNamedAttrFn` to `AsmPrinter::printOptionalAttrDict()`.

If `printNamedAttrFn` is provided, the default printing can be overridden for a named attribute. `printNamedAttrFn` is passed a NamedAttribute; if it prints the attribute, it returns `success()`; otherwise, it returns `failure()`, which indicates that generic printing should be used.

---
Full diff: https://github.com/llvm/llvm-project/pull/103304.diff


5 Files Affected:

- (modified) mlir/include/mlir/IR/OpImplementation.h (+19-4) 
- (modified) mlir/lib/AsmParser/AsmParserImpl.h (+5-2) 
- (modified) mlir/lib/AsmParser/AttributeParser.cpp (+14-2) 
- (modified) mlir/lib/AsmParser/Parser.h (+3-1) 
- (modified) mlir/lib/IR/AsmPrinter.cpp (+30-15) 


``````````diff
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index ae412c7227f8ea..5891cbffc9542d 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -463,9 +463,15 @@ class OpAsmPrinter : public AsmPrinter {
   /// If the specified operation has attributes, print out an attribute
   /// dictionary with their values.  elidedAttrs allows the client to ignore
   /// specific well known attributes, commonly used if the attribute value is
-  /// printed some other way (like as a fixed operand).
+  /// printed some other way (like as a fixed operand). If printNamedAttrFn is
+  /// provided the default printing can be overridden for a named attribute.
+  /// printNamedAttrFn is passed a NamedAttribute, if it prints the attribute
+  /// it returns `success()`, otherwise, it returns `failure()` which indicates
+  /// that generic printing should be used.
   virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
-                                     ArrayRef<StringRef> elidedAttrs = {}) = 0;
+                                     ArrayRef<StringRef> elidedAttrs = {},
+                                     function_ref<LogicalResult(NamedAttribute)>
+                                         printNamedAttrFn = nullptr) = 0;
 
   /// If the specified operation has attributes, print out an attribute
   /// dictionary prefixed with 'attributes'.
@@ -1116,8 +1122,17 @@ class AsmParser {
     return parseResult;
   }
 
-  /// Parse a named dictionary into 'result' if it is present.
-  virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0;
+  /// Parse a named dictionary into 'result' if it is present. If
+  /// parseNamedAttrFn is provided the default parsing can be overridden for a
+  /// named attribute. parseNamedAttrFn is passed the name of an attribute, if
+  /// it can parse the attribute it returns the parsed attribute, otherwise, it
+  /// returns `failure()` which indicates that generic parsing should be used.
+  /// Note: Returning a null Attribute from parseNamedAttrFn indicates a parser
+  /// error.
+  virtual ParseResult parseOptionalAttrDict(
+      NamedAttrList &result,
+      function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn =
+          nullptr) = 0;
 
   /// Parse a named dictionary into 'result' if the `attributes` keyword is
   /// present.
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index b12687833e3fde..808b2ca282f64b 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -458,10 +458,13 @@ class AsmParserImpl : public BaseT {
   }
 
   /// Parse a named dictionary into 'result' if it is present.
-  ParseResult parseOptionalAttrDict(NamedAttrList &result) override {
+  ParseResult parseOptionalAttrDict(
+      NamedAttrList &result,
+      function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn =
+          nullptr) override {
     if (parser.getToken().isNot(Token::l_brace))
       return success();
-    return parser.parseAttributeDict(result);
+    return parser.parseAttributeDict(result, parseNamedAttrFn);
   }
 
   /// Parse a named dictionary into 'result' if the `attributes` keyword is
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index efa65e49abc33b..b687d822e7cb7d 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -296,7 +296,9 @@ OptionalParseResult Parser::parseOptionalAttribute(SymbolRefAttr &result,
 ///                    | `{` attribute-entry (`,` attribute-entry)* `}`
 ///   attribute-entry ::= (bare-id | string-literal) `=` attribute-value
 ///
-ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
+ParseResult Parser::parseAttributeDict(
+    NamedAttrList &attributes,
+    function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn) {
   llvm::SmallDenseSet<StringAttr> seenKeys;
   auto parseElt = [&]() -> ParseResult {
     // The name of an attribute can either be a bare identifier, or a string.
@@ -329,7 +331,17 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
       return success();
     }
 
-    auto attr = parseAttribute();
+    Attribute attr = nullptr;
+    FailureOr<Attribute> customParsedAttribute;
+    // Try to parse with `printNamedAttrFn` callback.
+    if (parseNamedAttrFn &&
+        succeeded(customParsedAttribute = parseNamedAttrFn(*nameId))) {
+      attr = *customParsedAttribute;
+    } else {
+      // Otherwise, use generic attribute parser.
+      attr = parseAttribute();
+    }
+
     if (!attr)
       return failure();
     attributes.push_back({*nameId, attr});
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index 4caab499e1a0e4..d5d90f391fd391 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -256,7 +256,9 @@ class Parser {
   }
 
   /// Parse an attribute dictionary.
-  ParseResult parseAttributeDict(NamedAttrList &attributes);
+  ParseResult parseAttributeDict(
+      NamedAttrList &attributes,
+      function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn = nullptr);
 
   /// Parse a distinct attribute.
   Attribute parseDistinctAttr(Type type);
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 02acc8c3f4659e..cd9f70c8868b83 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -452,10 +452,13 @@ class AsmPrinter::Impl {
   void printDimensionList(ArrayRef<int64_t> shape);
 
 protected:
-  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
-                             ArrayRef<StringRef> elidedAttrs = {},
-                             bool withKeyword = false);
-  void printNamedAttribute(NamedAttribute attr);
+  void printOptionalAttrDict(
+      ArrayRef<NamedAttribute> attrs, ArrayRef<StringRef> elidedAttrs = {},
+      bool withKeyword = false,
+      function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn = nullptr);
+  void printNamedAttribute(
+      NamedAttribute attr,
+      function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn = nullptr);
   void printTrailingLocation(Location loc, bool allowAlias = true);
   void printLocationInternal(LocationAttr loc, bool pretty = false,
                              bool isTopLevel = false);
@@ -780,9 +783,10 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
   /// Print the given set of attributes with names not included within
   /// 'elidedAttrs'.
   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
-                             ArrayRef<StringRef> elidedAttrs = {}) override {
-    if (attrs.empty())
-      return;
+                             ArrayRef<StringRef> elidedAttrs = {},
+                             function_ref<LogicalResult(NamedAttribute)>
+                                 printNamedAttrFn = nullptr) override {
+    (void)printNamedAttrFn;
     if (elidedAttrs.empty()) {
       for (const NamedAttribute &attr : attrs)
         printAttribute(attr.getValue());
@@ -2687,9 +2691,10 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
       .Default([&](Type type) { return printDialectType(type); });
 }
 
-void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
-                                             ArrayRef<StringRef> elidedAttrs,
-                                             bool withKeyword) {
+void AsmPrinter::Impl::printOptionalAttrDict(
+    ArrayRef<NamedAttribute> attrs, ArrayRef<StringRef> elidedAttrs,
+    bool withKeyword,
+    function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn) {
   // If there are no attributes, then there is nothing to be done.
   if (attrs.empty())
     return;
@@ -2702,8 +2707,9 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
 
     // Otherwise, print them all out in braces.
     os << " {";
-    interleaveComma(filteredAttrs,
-                    [&](NamedAttribute attr) { printNamedAttribute(attr); });
+    interleaveComma(filteredAttrs, [&](NamedAttribute attr) {
+      printNamedAttribute(attr, printNamedAttrFn);
+    });
     os << '}';
   };
 
@@ -2720,7 +2726,9 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
   if (!filteredAttrs.empty())
     printFilteredAttributesFn(filteredAttrs);
 }
-void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
+void AsmPrinter::Impl::printNamedAttribute(
+    NamedAttribute attr,
+    function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn) {
   // Print the name without quotes if possible.
   ::printKeywordOrString(attr.getName().strref(), os);
 
@@ -2729,6 +2737,10 @@ void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
     return;
 
   os << " = ";
+  if (printNamedAttrFn && succeeded(printNamedAttrFn(attr))) {
+    /// If we print via the `printNamedAttrFn` callback skip printing.
+    return;
+  }
   printAttribute(attr.getValue());
 }
 
@@ -3149,8 +3161,11 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
 
   /// Print an optional attribute dictionary with a given set of elided values.
   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
-                             ArrayRef<StringRef> elidedAttrs = {}) override {
-    Impl::printOptionalAttrDict(attrs, elidedAttrs);
+                             ArrayRef<StringRef> elidedAttrs = {},
+                             function_ref<LogicalResult(NamedAttribute)>
+                                 printNamedAttrFn = nullptr) override {
+    Impl::printOptionalAttrDict(attrs, elidedAttrs, /*withKeyword=*/false,
+                                printNamedAttrFn);
   }
   void printOptionalAttrDictWithKeyword(
       ArrayRef<NamedAttribute> attrs,

``````````

</details>


https://github.com/llvm/llvm-project/pull/103304


More information about the Mlir-commits mailing list