[Mlir-commits] [mlir] [mlir] Add the ability to override attribute parsing/printing in attr-dicts (PR #103304)
Benjamin Maxwell
llvmlistbot at llvm.org
Tue Aug 13 09:10:44 PDT 2024
https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/103304
This adds a `parseNamedAttrFn` callback to `AsmParser::parseOptionalAttrDict()`.
If `parseNamedAttrFn` is provided, the default parsing can be overridden for a named attribute. `parseNamedAttrFn` is passed the name of an attribute; if it can parse the attribute, it returns the parsed attribute. Otherwise, it returns `failure()` (without consuming any tokens), which indicates that generic parsing should be used. Note: returning a null Attribute from `parseNamedAttrFn` indicates a parser error.
It also adds `printNamedAttrFn` to `AsmPrinter::printOptionalAttrDict()`.
If `printNamedAttrFn` is provided, the default printing can be overridden for a named attribute. `printNamedAttrFn` is passed a NamedAttribute; if it prints the attribute, it returns `success()`. Otherwise, it returns `failure()`, which indicates that generic printing should be used.
>From 06aefb23c1feb4ec7d48697d2852d2a3b33de0aa Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 13 Aug 2024 12:08:52 +0000
Subject: [PATCH] [mlir] Add the ability to override attribute parsing/printing
in attr-dicts
This adds a `parseNamedAttrFn` callback to
`AsmParser::parseOptionalAttrDict()`.
If `parseNamedAttrFn` is provided, the default parsing can be overridden
for a named attribute. `parseNamedAttrFn` is passed the name of an
attribute; if it can parse the attribute, it returns the parsed
attribute. Otherwise, it returns `failure()` (without consuming any
tokens), which indicates that generic parsing should be used. Note:
returning a null Attribute from `parseNamedAttrFn` indicates a parser
error.
It also adds `printNamedAttrFn` to
`AsmPrinter::printOptionalAttrDict()`.
If `printNamedAttrFn` is provided, the default printing can be
overridden for a named attribute. `printNamedAttrFn` is passed a
NamedAttribute; if it prints the attribute, it returns `success()`.
Otherwise, it returns `failure()`, which indicates that generic printing
should be used.
---
mlir/include/mlir/IR/OpImplementation.h | 23 ++++++++++---
mlir/lib/AsmParser/AsmParserImpl.h | 7 ++--
mlir/lib/AsmParser/AttributeParser.cpp | 16 +++++++--
mlir/lib/AsmParser/Parser.h | 4 ++-
mlir/lib/IR/AsmPrinter.cpp | 45 ++++++++++++++++---------
5 files changed, 71 insertions(+), 24 deletions(-)
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index ae412c7227f8ea..5891cbffc9542d 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -463,9 +463,15 @@ class OpAsmPrinter : public AsmPrinter {
/// If the specified operation has attributes, print out an attribute
/// dictionary with their values. elidedAttrs allows the client to ignore
/// specific well known attributes, commonly used if the attribute value is
- /// printed some other way (like as a fixed operand).
+ /// printed some other way (like as a fixed operand). If `printNamedAttrFn`
+ /// is provided, the default printing can be overridden for a named
+ /// attribute. `printNamedAttrFn` is passed a NamedAttribute; if it prints
+ /// the attribute, it returns `success()`. Otherwise, it returns `failure()`,
+ /// which indicates that generic printing should be used.
virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
- ArrayRef<StringRef> elidedAttrs = {}) = 0;
+ ArrayRef<StringRef> elidedAttrs = {},
+ function_ref<LogicalResult(NamedAttribute)>
+ printNamedAttrFn = nullptr) = 0;
/// If the specified operation has attributes, print out an attribute
/// dictionary prefixed with 'attributes'.
@@ -1116,8 +1122,17 @@ class AsmParser {
return parseResult;
}
- /// Parse a named dictionary into 'result' if it is present.
- virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0;
+ /// Parse a named dictionary into 'result' if it is present. If
+ /// `parseNamedAttrFn` is provided, the default parsing can be overridden for
+ /// a named attribute. `parseNamedAttrFn` is passed the attribute name; if it
+ /// can parse the attribute, it returns the parsed attribute, otherwise it
+ /// returns `failure()` (without consuming tokens) to request generic parsing.
+ /// Note: returning a null Attribute from `parseNamedAttrFn` indicates a
+ /// parser error.
+ virtual ParseResult parseOptionalAttrDict(
+ NamedAttrList &result,
+ function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn =
+ nullptr) = 0;
/// Parse a named dictionary into 'result' if the `attributes` keyword is
/// present.
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index b12687833e3fde..808b2ca282f64b 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -458,10 +458,13 @@ class AsmParserImpl : public BaseT {
}
/// Parse a named dictionary into 'result' if it is present.
- ParseResult parseOptionalAttrDict(NamedAttrList &result) override {
+ ParseResult parseOptionalAttrDict(
+ NamedAttrList &result,
+ function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn =
+ nullptr) override {
if (parser.getToken().isNot(Token::l_brace))
return success();
- return parser.parseAttributeDict(result);
+ return parser.parseAttributeDict(result, parseNamedAttrFn);
}
/// Parse a named dictionary into 'result' if the `attributes` keyword is
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index efa65e49abc33b..b687d822e7cb7d 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -296,7 +296,9 @@ OptionalParseResult Parser::parseOptionalAttribute(SymbolRefAttr &result,
/// | `{` attribute-entry (`,` attribute-entry)* `}`
/// attribute-entry ::= (bare-id | string-literal) `=` attribute-value
///
-ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
+ParseResult Parser::parseAttributeDict(
+ NamedAttrList &attributes,
+ function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn) {
llvm::SmallDenseSet<StringAttr> seenKeys;
auto parseElt = [&]() -> ParseResult {
// The name of an attribute can either be a bare identifier, or a string.
@@ -329,7 +331,17 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
return success();
}
- auto attr = parseAttribute();
+ Attribute attr = nullptr;
+ FailureOr<Attribute> customParsedAttribute;
+ // Try to parse with the `parseNamedAttrFn` callback.
+ if (parseNamedAttrFn &&
+ succeeded(customParsedAttribute = parseNamedAttrFn(*nameId))) {
+ attr = *customParsedAttribute;
+ } else {
+ // Otherwise, use generic attribute parser.
+ attr = parseAttribute();
+ }
+
if (!attr)
return failure();
attributes.push_back({*nameId, attr});
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index 4caab499e1a0e4..d5d90f391fd391 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -256,7 +256,9 @@ class Parser {
}
/// Parse an attribute dictionary.
- ParseResult parseAttributeDict(NamedAttrList &attributes);
+ ParseResult parseAttributeDict(
+ NamedAttrList &attributes,
+ function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn = nullptr);
/// Parse a distinct attribute.
Attribute parseDistinctAttr(Type type);
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 02acc8c3f4659e..cd9f70c8868b83 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -452,10 +452,13 @@ class AsmPrinter::Impl {
void printDimensionList(ArrayRef<int64_t> shape);
protected:
- void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
- ArrayRef<StringRef> elidedAttrs = {},
- bool withKeyword = false);
- void printNamedAttribute(NamedAttribute attr);
+ void printOptionalAttrDict(
+ ArrayRef<NamedAttribute> attrs, ArrayRef<StringRef> elidedAttrs = {},
+ bool withKeyword = false,
+ function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn = nullptr);
+ void printNamedAttribute(
+ NamedAttribute attr,
+ function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn = nullptr);
void printTrailingLocation(Location loc, bool allowAlias = true);
void printLocationInternal(LocationAttr loc, bool pretty = false,
bool isTopLevel = false);
@@ -780,9 +783,10 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
/// Print the given set of attributes with names not included within
/// 'elidedAttrs'.
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
- ArrayRef<StringRef> elidedAttrs = {}) override {
- if (attrs.empty())
- return;
+ ArrayRef<StringRef> elidedAttrs = {},
+ function_ref<LogicalResult(NamedAttribute)>
+ printNamedAttrFn = nullptr) override {
+ (void)printNamedAttrFn;
if (elidedAttrs.empty()) {
for (const NamedAttribute &attr : attrs)
printAttribute(attr.getValue());
@@ -2687,9 +2691,10 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
.Default([&](Type type) { return printDialectType(type); });
}
-void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
- ArrayRef<StringRef> elidedAttrs,
- bool withKeyword) {
+void AsmPrinter::Impl::printOptionalAttrDict(
+ ArrayRef<NamedAttribute> attrs, ArrayRef<StringRef> elidedAttrs,
+ bool withKeyword,
+ function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn) {
// If there are no attributes, then there is nothing to be done.
if (attrs.empty())
return;
@@ -2702,8 +2707,9 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
// Otherwise, print them all out in braces.
os << " {";
- interleaveComma(filteredAttrs,
- [&](NamedAttribute attr) { printNamedAttribute(attr); });
+ interleaveComma(filteredAttrs, [&](NamedAttribute attr) {
+ printNamedAttribute(attr, printNamedAttrFn);
+ });
os << '}';
};
@@ -2720,7 +2726,9 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
if (!filteredAttrs.empty())
printFilteredAttributesFn(filteredAttrs);
}
-void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
+void AsmPrinter::Impl::printNamedAttribute(
+ NamedAttribute attr,
+ function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn) {
// Print the name without quotes if possible.
::printKeywordOrString(attr.getName().strref(), os);
@@ -2729,6 +2737,10 @@ void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
return;
os << " = ";
+ if (printNamedAttrFn && succeeded(printNamedAttrFn(attr))) {
+ // The `printNamedAttrFn` callback printed the value; skip generic printing.
+ return;
+ }
printAttribute(attr.getValue());
}
@@ -3149,8 +3161,11 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
/// Print an optional attribute dictionary with a given set of elided values.
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
- ArrayRef<StringRef> elidedAttrs = {}) override {
- Impl::printOptionalAttrDict(attrs, elidedAttrs);
+ ArrayRef<StringRef> elidedAttrs = {},
+ function_ref<LogicalResult(NamedAttribute)>
+ printNamedAttrFn = nullptr) override {
+ Impl::printOptionalAttrDict(attrs, elidedAttrs, /*withKeyword=*/false,
+ printNamedAttrFn);
}
void printOptionalAttrDictWithKeyword(
ArrayRef<NamedAttribute> attrs,
More information about the Mlir-commits
mailing list