[Mlir-commits] [mlir] [mlir] Add the ability to override attribute parsing/printing in attr-dicts (PR #103304)

Benjamin Maxwell llvmlistbot at llvm.org
Wed Aug 14 08:00:54 PDT 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/103304

>From 06aefb23c1feb4ec7d48697d2852d2a3b33de0aa Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 13 Aug 2024 12:08:52 +0000
Subject: [PATCH 1/3] [mlir] Add the ability to override attribute
 parsing/printing in attr-dicts

This adds a `parseNamedAttrFn` callback to
`AsmParser::parseOptionalAttrDict()`.

If parseNamedAttrFn is provided, the default parsing can be overridden
for a named attribute. parseNamedAttrFn is passed the name of an
attribute; if it can parse the attribute it returns the parsed
attribute, otherwise it returns `failure()`, which indicates that
generic parsing should be used. Because generic parsing resumes at the
same position, parseNamedAttrFn must not consume any input before
returning `failure()`. Note: Returning a null Attribute from
parseNamedAttrFn indicates a parser error.

It also adds `printNamedAttrFn` to
`AsmPrinter::printOptionalAttrDict()`.

If printNamedAttrFn is provided, the default printing can be overridden
for a named attribute. printNamedAttrFn is passed a NamedAttribute; the
attribute name and ` = ` have already been printed, so the callback only
prints the value. If it prints the attribute it returns `success()`,
otherwise it returns `failure()`, which indicates that generic printing
should be used.
---
 mlir/include/mlir/IR/OpImplementation.h | 23 ++++++++++---
 mlir/lib/AsmParser/AsmParserImpl.h      |  7 ++--
 mlir/lib/AsmParser/AttributeParser.cpp  | 16 +++++++--
 mlir/lib/AsmParser/Parser.h             |  4 ++-
 mlir/lib/IR/AsmPrinter.cpp              | 45 ++++++++++++++++---------
 5 files changed, 71 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index ae412c7227f8ea..5891cbffc9542d 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -463,9 +463,15 @@ class OpAsmPrinter : public AsmPrinter {
   /// If the specified operation has attributes, print out an attribute
   /// dictionary with their values.  elidedAttrs allows the client to ignore
   /// specific well known attributes, commonly used if the attribute value is
-  /// printed some other way (like as a fixed operand).
+  /// printed some other way (like as a fixed operand). If printNamedAttrFn is
+  /// provided the default printing can be overridden for a named attribute.
+  /// printNamedAttrFn is passed a NamedAttribute, if it prints the attribute
+  /// it returns `success()`, otherwise, it returns `failure()` which indicates
+  /// that generic printing should be used.
   virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
-                                     ArrayRef<StringRef> elidedAttrs = {}) = 0;
+                                     ArrayRef<StringRef> elidedAttrs = {},
+                                     function_ref<LogicalResult(NamedAttribute)>
+                                         printNamedAttrFn = nullptr) = 0;
 
   /// If the specified operation has attributes, print out an attribute
   /// dictionary prefixed with 'attributes'.
@@ -1116,8 +1122,17 @@ class AsmParser {
     return parseResult;
   }
 
-  /// Parse a named dictionary into 'result' if it is present.
-  virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0;
+  /// Parse a named dictionary into 'result' if it is present. If
+  /// parseNamedAttrFn is provided the default parsing can be overridden for a
+  /// named attribute. parseNamedAttrFn is passed the name of an attribute, if
+  /// it can parse the attribute it returns the parsed attribute, otherwise, it
+  /// returns `failure()` which indicates that generic parsing should be used.
+  /// Note: Returning a null Attribute from parseNamedAttrFn indicates a parser
+  /// error.
+  virtual ParseResult parseOptionalAttrDict(
+      NamedAttrList &result,
+      function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn =
+          nullptr) = 0;
 
   /// Parse a named dictionary into 'result' if the `attributes` keyword is
   /// present.
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index b12687833e3fde..808b2ca282f64b 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -458,10 +458,13 @@ class AsmParserImpl : public BaseT {
   }
 
   /// Parse a named dictionary into 'result' if it is present.
-  ParseResult parseOptionalAttrDict(NamedAttrList &result) override {
+  ParseResult parseOptionalAttrDict(
+      NamedAttrList &result,
+      function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn =
+          nullptr) override {
     if (parser.getToken().isNot(Token::l_brace))
       return success();
-    return parser.parseAttributeDict(result);
+    return parser.parseAttributeDict(result, parseNamedAttrFn);
   }
 
   /// Parse a named dictionary into 'result' if the `attributes` keyword is
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index efa65e49abc33b..b687d822e7cb7d 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -296,7 +296,9 @@ OptionalParseResult Parser::parseOptionalAttribute(SymbolRefAttr &result,
 ///                    | `{` attribute-entry (`,` attribute-entry)* `}`
 ///   attribute-entry ::= (bare-id | string-literal) `=` attribute-value
 ///
-ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
+ParseResult Parser::parseAttributeDict(
+    NamedAttrList &attributes,
+    function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn) {
   llvm::SmallDenseSet<StringAttr> seenKeys;
   auto parseElt = [&]() -> ParseResult {
     // The name of an attribute can either be a bare identifier, or a string.
@@ -329,7 +331,17 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
       return success();
     }
 
-    auto attr = parseAttribute();
+    Attribute attr = nullptr;
+    FailureOr<Attribute> customParsedAttribute;
+    // Try to parse with the `parseNamedAttrFn` callback.
+    if (parseNamedAttrFn &&
+        succeeded(customParsedAttribute = parseNamedAttrFn(*nameId))) {
+      attr = *customParsedAttribute;
+    } else {
+      // Otherwise, use generic attribute parser.
+      attr = parseAttribute();
+    }
+
     if (!attr)
       return failure();
     attributes.push_back({*nameId, attr});
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index 4caab499e1a0e4..d5d90f391fd391 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -256,7 +256,9 @@ class Parser {
   }
 
   /// Parse an attribute dictionary.
-  ParseResult parseAttributeDict(NamedAttrList &attributes);
+  ParseResult parseAttributeDict(
+      NamedAttrList &attributes,
+      function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn = nullptr);
 
   /// Parse a distinct attribute.
   Attribute parseDistinctAttr(Type type);
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 02acc8c3f4659e..cd9f70c8868b83 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -452,10 +452,13 @@ class AsmPrinter::Impl {
   void printDimensionList(ArrayRef<int64_t> shape);
 
 protected:
-  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
-                             ArrayRef<StringRef> elidedAttrs = {},
-                             bool withKeyword = false);
-  void printNamedAttribute(NamedAttribute attr);
+  void printOptionalAttrDict(
+      ArrayRef<NamedAttribute> attrs, ArrayRef<StringRef> elidedAttrs = {},
+      bool withKeyword = false,
+      function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn = nullptr);
+  void printNamedAttribute(
+      NamedAttribute attr,
+      function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn = nullptr);
   void printTrailingLocation(Location loc, bool allowAlias = true);
   void printLocationInternal(LocationAttr loc, bool pretty = false,
                              bool isTopLevel = false);
@@ -780,9 +783,10 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
   /// Print the given set of attributes with names not included within
   /// 'elidedAttrs'.
   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
-                             ArrayRef<StringRef> elidedAttrs = {}) override {
-    if (attrs.empty())
-      return;
+                             ArrayRef<StringRef> elidedAttrs = {},
+                             function_ref<LogicalResult(NamedAttribute)>
+                                 printNamedAttrFn = nullptr) override {
+    (void)printNamedAttrFn;
     if (elidedAttrs.empty()) {
       for (const NamedAttribute &attr : attrs)
         printAttribute(attr.getValue());
@@ -2687,9 +2691,10 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
       .Default([&](Type type) { return printDialectType(type); });
 }
 
-void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
-                                             ArrayRef<StringRef> elidedAttrs,
-                                             bool withKeyword) {
+void AsmPrinter::Impl::printOptionalAttrDict(
+    ArrayRef<NamedAttribute> attrs, ArrayRef<StringRef> elidedAttrs,
+    bool withKeyword,
+    function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn) {
   // If there are no attributes, then there is nothing to be done.
   if (attrs.empty())
     return;
@@ -2702,8 +2707,9 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
 
     // Otherwise, print them all out in braces.
     os << " {";
-    interleaveComma(filteredAttrs,
-                    [&](NamedAttribute attr) { printNamedAttribute(attr); });
+    interleaveComma(filteredAttrs, [&](NamedAttribute attr) {
+      printNamedAttribute(attr, printNamedAttrFn);
+    });
     os << '}';
   };
 
@@ -2720,7 +2726,9 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
   if (!filteredAttrs.empty())
     printFilteredAttributesFn(filteredAttrs);
 }
-void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
+void AsmPrinter::Impl::printNamedAttribute(
+    NamedAttribute attr,
+    function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn) {
   // Print the name without quotes if possible.
   ::printKeywordOrString(attr.getName().strref(), os);
 
@@ -2729,6 +2737,10 @@ void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
     return;
 
   os << " = ";
+  if (printNamedAttrFn && succeeded(printNamedAttrFn(attr))) {
+    /// If we print via the `printNamedAttrFn` callback skip printing.
+    return;
+  }
   printAttribute(attr.getValue());
 }
 
@@ -3149,8 +3161,11 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
 
   /// Print an optional attribute dictionary with a given set of elided values.
   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
-                             ArrayRef<StringRef> elidedAttrs = {}) override {
-    Impl::printOptionalAttrDict(attrs, elidedAttrs);
+                             ArrayRef<StringRef> elidedAttrs = {},
+                             function_ref<LogicalResult(NamedAttribute)>
+                                 printNamedAttrFn = nullptr) override {
+    Impl::printOptionalAttrDict(attrs, elidedAttrs, /*withKeyword=*/false,
+                                printNamedAttrFn);
   }
   void printOptionalAttrDictWithKeyword(
       ArrayRef<NamedAttribute> attrs,

>From 13e51d18ab35c5853a3ec2478d96b7e090d18136 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 14 Aug 2024 10:20:15 +0000
Subject: [PATCH 2/3] Add test

---
 .../IR/custom-attr-syntax-in-attr-dict.mlir   | 30 ++++++++++++++++++
 ...verride_attribute_syntax_in_attr_dict.mlir |  0
 mlir/test/lib/Dialect/Test/TestOpDefs.cpp     | 31 +++++++++++++++++++
 mlir/test/lib/Dialect/Test/TestOps.td         | 10 ++++++
 4 files changed, 71 insertions(+)
 create mode 100644 mlir/test/IR/custom-attr-syntax-in-attr-dict.mlir
 create mode 100644 mlir/test/IR/override_attribute_syntax_in_attr_dict.mlir

diff --git a/mlir/test/IR/custom-attr-syntax-in-attr-dict.mlir b/mlir/test/IR/custom-attr-syntax-in-attr-dict.mlir
new file mode 100644
index 00000000000000..5c62430edde223
--- /dev/null
+++ b/mlir/test/IR/custom-attr-syntax-in-attr-dict.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-opt %s | FileCheck %s --check-prefix=CHECK-ROUNDTRIP
+// RUN: mlir-opt %s -mlir-print-op-generic | FileCheck %s --check-prefix=CHECK-GENERIC-SYNTAX
+
+/// This file tests that "custom_dense_array" (which is a DenseArrayAttribute
+/// stored within the attr-dict) is parsed and printed with the "pretty" array
+/// syntax (i.e. `[1, 2, 3, 4]`), rather than with the generic dense array
+/// syntax (`array<i64: 1, 2, 3, 4>`).
+///
+/// This is done by injecting custom parsing and printing callbacks into
+/// parseOptionalAttrDict() and printOptionalAttrDict().
+
+func.func @custom_attr_dict_syntax() {
+  // CHECK-ROUNDTRIP: test.custom_attr_parse_and_print_in_attr_dict {custom_dense_array = [1, 2, 3, 4]}
+  // CHECK-GENERIC-SYNTAX: "test.custom_attr_parse_and_print_in_attr_dict"() <{custom_dense_array = array<i64: 1, 2, 3, 4>}> : () -> ()
+  test.custom_attr_parse_and_print_in_attr_dict {custom_dense_array = [1, 2, 3, 4]}
+
+  // CHECK-ROUNDTRIP: test.custom_attr_parse_and_print_in_attr_dict {another_attr = "foo", custom_dense_array = [1, 2, 3, 4]}
+  // CHECK-GENERIC-SYNTAX: "test.custom_attr_parse_and_print_in_attr_dict"() <{custom_dense_array = array<i64: 1, 2, 3, 4>}> {another_attr = "foo"} : () -> ()
+  test.custom_attr_parse_and_print_in_attr_dict {another_attr = "foo", custom_dense_array = [1, 2, 3, 4]}
+
+  // CHECK-ROUNDTRIP: test.custom_attr_parse_and_print_in_attr_dict {custom_dense_array = [1, 2, 3, 4], default_array = [1, 2, 3, 4]}
+  // CHECK-GENERIC-SYNTAX: "test.custom_attr_parse_and_print_in_attr_dict"() <{custom_dense_array = array<i64: 1, 2, 3, 4>}> {default_array = [1, 2, 3, 4]} : () -> ()
+  test.custom_attr_parse_and_print_in_attr_dict {custom_dense_array = [1, 2, 3, 4], default_array = [1, 2, 3, 4]}
+  // Attributes are printed in sorted order, so custom_dense_array comes first.
+  // CHECK-ROUNDTRIP: test.custom_attr_parse_and_print_in_attr_dict {custom_dense_array = [1, 2, 3, 4], default_dense_array = array<i64: 1, 2, 3, 4>}
+  // CHECK-GENERIC-SYNTAX: "test.custom_attr_parse_and_print_in_attr_dict"() <{custom_dense_array = array<i64: 1, 2, 3, 4>}> {default_dense_array = array<i64: 1, 2, 3, 4>} : () -> ()
+  test.custom_attr_parse_and_print_in_attr_dict {default_dense_array = array<i64: 1, 2, 3, 4>, custom_dense_array = [1, 2, 3, 4]}
+
+  return
+}
diff --git a/mlir/test/IR/override_attribute_syntax_in_attr_dict.mlir b/mlir/test/IR/override_attribute_syntax_in_attr_dict.mlir
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index fbaa102d3e33cc..a6994c402bd3e6 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -141,6 +141,37 @@ void AffineScopeOp::print(OpAsmPrinter &p) {
   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
 }
 
+//===----------------------------------------------------------------------===//
+// CustomAttrParseAndPrintInAttrDict
+//===----------------------------------------------------------------------===//
+
+ParseResult CustomAttrParseAndPrintInAttrDict::parse(OpAsmParser &parser,
+                                                     OperationState &result) {
+  return parser.parseOptionalAttrDict(
+      result.attributes, [&](StringRef name) -> FailureOr<Attribute> {
+        // Parse "custom_dense_array" as [0, 1, 2, ...] rather than as
+        // array<i64: 0, 1, 2, ...>. For any other name, return failure()
+        // without consuming input so the generic attribute parser is used.
+        if (name != getCustomDenseArrayAttrName(result.name))
+          return failure();
+        return DenseI64ArrayAttr::parse(parser, {});
+      });
+}
+
+void CustomAttrParseAndPrintInAttrDict::print(OpAsmPrinter &p) {
+  p.printOptionalAttrDict((*this)->getAttrs(), {},
+                          [&](NamedAttribute attr) -> LogicalResult {
+                            // Print "custom_dense_array" as [0, 1, 2, ...]
+                            // rather than array<i64: 0, 1, 2, ...>; returning
+                            // failure() for other names selects the generic
+                            // printer. The "name = " prefix is already printed.
+                            if (attr.getName() != getCustomDenseArrayAttrName())
+                              return failure();
+                            cast<DenseI64ArrayAttr>(attr.getValue()).print(p);
+                            return success();
+                          });
+}
+
 //===----------------------------------------------------------------------===//
 // TestRemoveOpWithInnerOps
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 2b55bff3538d39..1c8cdc3fef74ec 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2083,6 +2083,16 @@ def OptionalCustomAttrOp : TEST_Op<"optional_custom_attr"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Test overriding attribute parsing/printing in the attr-dict via callbacks
+// on parseOptionalAttrDict() and printOptionalAttrDict().
+
+def CustomAttrParseAndPrintInAttrDict : TEST_Op<"custom_attr_parse_and_print_in_attr_dict">
+{
+  let arguments = (ins DenseI64ArrayAttr:$custom_dense_array);
+  let hasCustomAssemblyFormat = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Test OpAsmInterface.
 

>From e1e318dc63430205774ea9ec622fcde6fecb22db Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 14 Aug 2024 14:58:26 +0000
Subject: [PATCH 3/3] Fixup comment

---
 mlir/lib/IR/AsmPrinter.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index cd9f70c8868b83..d931fb62cd7bdb 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2738,7 +2738,8 @@ void AsmPrinter::Impl::printNamedAttribute(
 
   os << " = ";
   if (printNamedAttrFn && succeeded(printNamedAttrFn(attr))) {
-    /// If we print via the `printNamedAttrFn` callback skip printing.
+    // If we print via the `printNamedAttrFn` callback, skip the generic
+    // attribute printing (i.e. the call to `printAttribute`).
     return;
   }
   printAttribute(attr.getValue());



More information about the Mlir-commits mailing list