[Mlir-commits] [mlir] [mlir] Add the ability to override attribute parsing/printing in attr-dicts (PR #103304)
Benjamin Maxwell
llvmlistbot at llvm.org
Wed Aug 14 08:00:54 PDT 2024
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/103304
>From 06aefb23c1feb4ec7d48697d2852d2a3b33de0aa Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 13 Aug 2024 12:08:52 +0000
Subject: [PATCH 1/3] [mlir] Add the ability to override attribute
parsing/printing in attr-dicts
This adds a `parseNamedAttrFn` callback to
`AsmParser::parseOptionalAttrDict()`.
If parseNamedAttrFn is provided, the default parsing can be overridden
for a named attribute. parseNamedAttrFn is passed the name of an
attribute; if it can parse the attribute, it returns the parsed
attribute, otherwise it returns `failure()`, which indicates that
generic parsing should be used. Note: Returning a null Attribute from
parseNamedAttrFn indicates a parser error.
It also adds `printNamedAttrFn` to
`AsmPrinter::printOptionalAttrDict()`.
If printNamedAttrFn is provided, the default printing can be overridden
for a named attribute. printNamedAttrFn is passed a NamedAttribute; if
it prints the attribute, it returns `success()`, otherwise it returns
`failure()`, which indicates that generic printing should be used.
---
mlir/include/mlir/IR/OpImplementation.h | 23 ++++++++++---
mlir/lib/AsmParser/AsmParserImpl.h | 7 ++--
mlir/lib/AsmParser/AttributeParser.cpp | 16 +++++++--
mlir/lib/AsmParser/Parser.h | 4 ++-
mlir/lib/IR/AsmPrinter.cpp | 45 ++++++++++++++++---------
5 files changed, 71 insertions(+), 24 deletions(-)
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index ae412c7227f8ea..5891cbffc9542d 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -463,9 +463,15 @@ class OpAsmPrinter : public AsmPrinter {
/// If the specified operation has attributes, print out an attribute
/// dictionary with their values. elidedAttrs allows the client to ignore
/// specific well known attributes, commonly used if the attribute value is
- /// printed some other way (like as a fixed operand).
+ /// printed some other way (like as a fixed operand). If printNamedAttrFn is
+ /// provided the default printing can be overridden for a named attribute.
+ /// printNamedAttrFn is passed a NamedAttribute, if it prints the attribute
+ /// it returns `success()`, otherwise, it returns `failure()` which indicates
+ /// that generic printing should be used.
virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
- ArrayRef<StringRef> elidedAttrs = {}) = 0;
+ ArrayRef<StringRef> elidedAttrs = {},
+ function_ref<LogicalResult(NamedAttribute)>
+ printNamedAttrFn = nullptr) = 0;
/// If the specified operation has attributes, print out an attribute
/// dictionary prefixed with 'attributes'.
@@ -1116,8 +1122,17 @@ class AsmParser {
return parseResult;
}
- /// Parse a named dictionary into 'result' if it is present.
- virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0;
+ /// Parse a named dictionary into 'result' if it is present. If
+ /// parseNamedAttrFn is provided the default parsing can be overridden for a
+ /// named attribute. parseNamedAttrFn is passed the name of an attribute, if
+ /// it can parse the attribute it returns the parsed attribute, otherwise, it
+ /// returns `failure()` which indicates that generic parsing should be used.
+ /// Note: Returning a null Attribute from parseNamedAttrFn indicates a parser
+ /// error.
+ virtual ParseResult parseOptionalAttrDict(
+ NamedAttrList &result,
+ function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn =
+ nullptr) = 0;
/// Parse a named dictionary into 'result' if the `attributes` keyword is
/// present.
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index b12687833e3fde..808b2ca282f64b 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -458,10 +458,13 @@ class AsmParserImpl : public BaseT {
}
/// Parse a named dictionary into 'result' if it is present.
- ParseResult parseOptionalAttrDict(NamedAttrList &result) override {
+ ParseResult parseOptionalAttrDict(
+ NamedAttrList &result,
+ function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn =
+ nullptr) override {
if (parser.getToken().isNot(Token::l_brace))
return success();
- return parser.parseAttributeDict(result);
+ return parser.parseAttributeDict(result, parseNamedAttrFn);
}
/// Parse a named dictionary into 'result' if the `attributes` keyword is
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index efa65e49abc33b..b687d822e7cb7d 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -296,7 +296,9 @@ OptionalParseResult Parser::parseOptionalAttribute(SymbolRefAttr &result,
/// | `{` attribute-entry (`,` attribute-entry)* `}`
/// attribute-entry ::= (bare-id | string-literal) `=` attribute-value
///
-ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
+ParseResult Parser::parseAttributeDict(
+ NamedAttrList &attributes,
+ function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn) {
llvm::SmallDenseSet<StringAttr> seenKeys;
auto parseElt = [&]() -> ParseResult {
// The name of an attribute can either be a bare identifier, or a string.
@@ -329,7 +331,17 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
return success();
}
- auto attr = parseAttribute();
+ Attribute attr = nullptr;
+ FailureOr<Attribute> customParsedAttribute;
+ // Try to parse with the `parseNamedAttrFn` callback.
+ if (parseNamedAttrFn &&
+ succeeded(customParsedAttribute = parseNamedAttrFn(*nameId))) {
+ attr = *customParsedAttribute;
+ } else {
+ // Otherwise, use generic attribute parser.
+ attr = parseAttribute();
+ }
+
if (!attr)
return failure();
attributes.push_back({*nameId, attr});
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index 4caab499e1a0e4..d5d90f391fd391 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -256,7 +256,9 @@ class Parser {
}
/// Parse an attribute dictionary.
- ParseResult parseAttributeDict(NamedAttrList &attributes);
+ ParseResult parseAttributeDict(
+ NamedAttrList &attributes,
+ function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn = nullptr);
/// Parse a distinct attribute.
Attribute parseDistinctAttr(Type type);
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 02acc8c3f4659e..cd9f70c8868b83 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -452,10 +452,13 @@ class AsmPrinter::Impl {
void printDimensionList(ArrayRef<int64_t> shape);
protected:
- void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
- ArrayRef<StringRef> elidedAttrs = {},
- bool withKeyword = false);
- void printNamedAttribute(NamedAttribute attr);
+ void printOptionalAttrDict(
+ ArrayRef<NamedAttribute> attrs, ArrayRef<StringRef> elidedAttrs = {},
+ bool withKeyword = false,
+ function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn = nullptr);
+ void printNamedAttribute(
+ NamedAttribute attr,
+ function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn = nullptr);
void printTrailingLocation(Location loc, bool allowAlias = true);
void printLocationInternal(LocationAttr loc, bool pretty = false,
bool isTopLevel = false);
@@ -780,9 +783,10 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
/// Print the given set of attributes with names not included within
/// 'elidedAttrs'.
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
- ArrayRef<StringRef> elidedAttrs = {}) override {
- if (attrs.empty())
- return;
+ ArrayRef<StringRef> elidedAttrs = {},
+ function_ref<LogicalResult(NamedAttribute)>
+ printNamedAttrFn = nullptr) override {
+ (void)printNamedAttrFn;
if (elidedAttrs.empty()) {
for (const NamedAttribute &attr : attrs)
printAttribute(attr.getValue());
@@ -2687,9 +2691,10 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
.Default([&](Type type) { return printDialectType(type); });
}
-void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
- ArrayRef<StringRef> elidedAttrs,
- bool withKeyword) {
+void AsmPrinter::Impl::printOptionalAttrDict(
+ ArrayRef<NamedAttribute> attrs, ArrayRef<StringRef> elidedAttrs,
+ bool withKeyword,
+ function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn) {
// If there are no attributes, then there is nothing to be done.
if (attrs.empty())
return;
@@ -2702,8 +2707,9 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
// Otherwise, print them all out in braces.
os << " {";
- interleaveComma(filteredAttrs,
- [&](NamedAttribute attr) { printNamedAttribute(attr); });
+ interleaveComma(filteredAttrs, [&](NamedAttribute attr) {
+ printNamedAttribute(attr, printNamedAttrFn);
+ });
os << '}';
};
@@ -2720,7 +2726,9 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
if (!filteredAttrs.empty())
printFilteredAttributesFn(filteredAttrs);
}
-void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
+void AsmPrinter::Impl::printNamedAttribute(
+ NamedAttribute attr,
+ function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn) {
// Print the name without quotes if possible.
::printKeywordOrString(attr.getName().strref(), os);
@@ -2729,6 +2737,10 @@ void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
return;
os << " = ";
+ if (printNamedAttrFn && succeeded(printNamedAttrFn(attr))) {
+ /// If we print via the `printNamedAttrFn` callback skip printing.
+ return;
+ }
printAttribute(attr.getValue());
}
@@ -3149,8 +3161,11 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
/// Print an optional attribute dictionary with a given set of elided values.
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
- ArrayRef<StringRef> elidedAttrs = {}) override {
- Impl::printOptionalAttrDict(attrs, elidedAttrs);
+ ArrayRef<StringRef> elidedAttrs = {},
+ function_ref<LogicalResult(NamedAttribute)>
+ printNamedAttrFn = nullptr) override {
+ Impl::printOptionalAttrDict(attrs, elidedAttrs, /*withKeyword=*/false,
+ printNamedAttrFn);
}
void printOptionalAttrDictWithKeyword(
ArrayRef<NamedAttribute> attrs,
>From 13e51d18ab35c5853a3ec2478d96b7e090d18136 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 14 Aug 2024 10:20:15 +0000
Subject: [PATCH 2/3] Add test
---
.../IR/custom-attr-syntax-in-attr-dict.mlir | 30 ++++++++++++++++++
...verride_attribute_syntax_in_attr_dict.mlir | 0
mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 31 +++++++++++++++++++
mlir/test/lib/Dialect/Test/TestOps.td | 10 ++++++
4 files changed, 71 insertions(+)
create mode 100644 mlir/test/IR/custom-attr-syntax-in-attr-dict.mlir
create mode 100644 mlir/test/IR/override_attribute_syntax_in_attr_dict.mlir
diff --git a/mlir/test/IR/custom-attr-syntax-in-attr-dict.mlir b/mlir/test/IR/custom-attr-syntax-in-attr-dict.mlir
new file mode 100644
index 00000000000000..5c62430edde223
--- /dev/null
+++ b/mlir/test/IR/custom-attr-syntax-in-attr-dict.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-opt %s | FileCheck %s --check-prefix=CHECK-ROUNDTRIP
+// RUN: mlir-opt %s -mlir-print-op-generic | FileCheck %s --check-prefix=CHECK-GENERIC-SYNTAX
+
+/// This file tests that "custom_dense_array" (which is a DenseArrayAttribute
+/// stored within the attr-dict) is parsed and printed with the "pretty" array
+/// syntax (i.e. `[1, 2, 3, 4]`), rather than with the generic dense array
+/// syntax (`array<i64: 1, 2, 3, 4>`).
+///
+/// This is done by injecting custom parsing and printing callbacks into
+/// parseOptionalAttrDict() and printOptionalAttrDict().
+
+func.func @custom_attr_dict_syntax() {
+ // CHECK-ROUNDTRIP: test.custom_attr_parse_and_print_in_attr_dict {custom_dense_array = [1, 2, 3, 4]}
+ // CHECK-GENERIC-SYNTAX: "test.custom_attr_parse_and_print_in_attr_dict"() <{custom_dense_array = array<i64: 1, 2, 3, 4>}> : () -> ()
+ test.custom_attr_parse_and_print_in_attr_dict {custom_dense_array = [1, 2, 3, 4]}
+
+ // CHECK-ROUNDTRIP: test.custom_attr_parse_and_print_in_attr_dict {another_attr = "foo", custom_dense_array = [1, 2, 3, 4]}
+ // CHECK-GENERIC-SYNTAX: "test.custom_attr_parse_and_print_in_attr_dict"() <{custom_dense_array = array<i64: 1, 2, 3, 4>}> {another_attr = "foo"} : () -> ()
+ test.custom_attr_parse_and_print_in_attr_dict {another_attr = "foo", custom_dense_array = [1, 2, 3, 4]}
+
+ // CHECK-ROUNDTRIP: test.custom_attr_parse_and_print_in_attr_dict {custom_dense_array = [1, 2, 3, 4], default_array = [1, 2, 3, 4]}
+ // CHECK-GENERIC-SYNTAX: "test.custom_attr_parse_and_print_in_attr_dict"() <{custom_dense_array = array<i64: 1, 2, 3, 4>}> {default_array = [1, 2, 3, 4]} : () -> ()
+ test.custom_attr_parse_and_print_in_attr_dict {custom_dense_array = [1, 2, 3, 4], default_array = [1, 2, 3, 4]}
+
+ // CHECK-ROUNDTRIP: test.custom_attr_parse_and_print_in_attr_dict {custom_dense_array = [1, 2, 3, 4], default_dense_array = array<i64: 1, 2, 3, 4>}
+ // CHECK-GENERIC-SYNTAX: "test.custom_attr_parse_and_print_in_attr_dict"() <{custom_dense_array = array<i64: 1, 2, 3, 4>}> {default_dense_array = array<i64: 1, 2, 3, 4>} : () -> ()
+ test.custom_attr_parse_and_print_in_attr_dict {default_dense_array = array<i64: 1, 2, 3, 4>, custom_dense_array = [1, 2, 3, 4]}
+
+ return
+}
diff --git a/mlir/test/IR/override_attribute_syntax_in_attr_dict.mlir b/mlir/test/IR/override_attribute_syntax_in_attr_dict.mlir
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index fbaa102d3e33cc..a6994c402bd3e6 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -141,6 +141,37 @@ void AffineScopeOp::print(OpAsmPrinter &p) {
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
}
+//===----------------------------------------------------------------------===//
+// CustomAttrParseAndPrintInAttrDict
+//===----------------------------------------------------------------------===//
+
+ParseResult CustomAttrParseAndPrintInAttrDict::parse(OpAsmParser &parser,
+ OperationState &result) {
+ return parser.parseOptionalAttrDict(
+ result.attributes, [&](StringRef name) -> FailureOr<Attribute> {
+ // Custom parse only "custom_dense_array": read [0, 1, 2, ...] instead of
+ // array<i64: 0, 1, 2, ...>. Any other name returns failure() without
+ // consuming tokens, so the generic attribute parser handles it.
+ if (name != getCustomDenseArrayAttrName(result.name))
+ return failure();
+ return DenseI64ArrayAttr::parse(parser, {});
+ });
+}
+
+void CustomAttrParseAndPrintInAttrDict::print(OpAsmPrinter &p) {
+ p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{},
+ [&](NamedAttribute attr) -> LogicalResult {
+ // Custom print only the "custom_dense_array"
+ // attribute, as [0, 1, 2, ...] instead of
+ // array<i64: 0, 1, 2, ...>. Returning failure()
+ // selects generic printing for all other attrs.
+ if (attr.getName() != getCustomDenseArrayAttrName())
+ return failure();
+ cast<DenseI64ArrayAttr>(attr.getValue()).print(p);
+ return success();
+ });
+}
+
//===----------------------------------------------------------------------===//
// TestRemoveOpWithInnerOps
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 2b55bff3538d39..1c8cdc3fef74ec 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2083,6 +2083,16 @@ def OptionalCustomAttrOp : TEST_Op<"optional_custom_attr"> {
}];
}
+//===----------------------------------------------------------------------===//
+// Test overriding attribute parsing/printing in the attr-dict via callbacks
+// on parseOptionalAttrDict() and printOptionalAttrDict().
+
+def CustomAttrParseAndPrintInAttrDict : TEST_Op<"custom_attr_parse_and_print_in_attr_dict">
+{
+ let arguments = (ins DenseI64ArrayAttr:$custom_dense_array);
+ let hasCustomAssemblyFormat = 1;
+}
+
//===----------------------------------------------------------------------===//
// Test OpAsmInterface.
>From e1e318dc63430205774ea9ec622fcde6fecb22db Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 14 Aug 2024 14:58:26 +0000
Subject: [PATCH 3/3] Fixup comment
---
mlir/lib/IR/AsmPrinter.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index cd9f70c8868b83..d931fb62cd7bdb 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2738,7 +2738,8 @@ void AsmPrinter::Impl::printNamedAttribute(
os << " = ";
if (printNamedAttrFn && succeeded(printNamedAttrFn(attr))) {
- /// If we print via the `printNamedAttrFn` callback skip printing.
+ /// If we print via the `printNamedAttrFn` callback, skip the generic
+ /// attribute printing (i.e. the call to `printAttribute`).
return;
}
printAttribute(attr.getValue());
More information about the Mlir-commits
mailing list