[Mlir-commits] [mlir] [mlir] Add struct parsing and printing utilities (PR #133939)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 3 09:01:40 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Jorn Tuyls (jtuyls)
<details>
<summary>Changes</summary>
This PR implements utilities to parse and print a comma-separated list of key-value pairs, similar to the `struct` directive in tablegen.
From the docs:
> `struct` Directive
> The `struct` directive accepts a list of variables to capture and will generate a parser and printer for a comma-separated list of key-value pairs. If an optional parameter is included in the struct, it can be elided. The variables are printed in the order they are specified in the argument list but can be parsed in any order.
This enables defining custom struct parsing and printing functions when the `struct` directive doesn't suffice. There is existing downstream code that could use it: https://github.com/openxla/stablehlo/blob/a3c7de92425e8035437dae67ab2318a82eca79a1/stablehlo/dialect/StablehloOps.cpp#L3102
---
Full diff: https://github.com/llvm/llvm-project/pull/133939.diff
6 Files Affected:
- (modified) mlir/include/mlir/IR/OpImplementation.h (+43)
- (modified) mlir/lib/AsmParser/AsmParserImpl.h (+45)
- (modified) mlir/lib/IR/AsmPrinter.cpp (+22)
- (added) mlir/test/IR/custom-struct-attr-roundtrip.mlir (+62)
- (modified) mlir/test/lib/Dialect/Test/TestAttrDefs.td (+9)
- (modified) mlir/test/lib/Dialect/Test/TestAttributes.cpp (+37)
``````````diff
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 25c7d15eb8ed5..d77f4147744e1 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -238,6 +238,29 @@ class AsmPrinter {
void printDimensionList(ArrayRef<int64_t> shape);
+  //===--------------------------------------------------------------------===//
+  // Struct Printing
+  //===--------------------------------------------------------------------===//
+
+  /// Print a comma-separated list of key-value pairs using the provided
+  /// `keywords` and corresponding printing functions. Each entry is printed
+  /// as `keyword = ` followed by the output of the print function at the same
+  /// index, and entries are printed in the order the keywords are given. This
+  /// performs printing similar to the assembly format's `struct` directive
+  /// printer, but allows bringing in custom printers for fields.
+  ///
+  /// `keywords` and `printFuncs` must have the same size. No enclosing
+  /// delimiters are printed; the caller is responsible for emitting them.
+  ///
+  /// Example (with the caller printing the surrounding `<` and `>`):
+  ///   <foo = foo_value, bar = bar_value, ...>
+  virtual void
+  printStruct(ArrayRef<StringRef> keywords,
+              ArrayRef<llvm::function_ref<void(AsmPrinter &p)>> printFuncs);
+
+  //===--------------------------------------------------------------------===//
+  // Cyclic Printing
+  //===--------------------------------------------------------------------===//
+
/// Class used to automatically end a cyclic region on destruction.
class CyclicPrintReset {
public:
@@ -1409,6 +1432,26 @@ class AsmParser {
return CyclicParseReset(this);
}
+  //===--------------------------------------------------------------------===//
+  // Struct Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// Parse a comma-separated list of `keyword = value` pairs, surrounded by
+  /// the given `delimiter`. This performs parsing similar to the assembly
+  /// format's `struct` directive parser, but with a custom delimiter and/or
+  /// custom field parsers. Entries may appear in any order: once a keyword
+  /// and `=` have been parsed, the parse function at the same index as that
+  /// keyword in `keywords` is invoked to parse the value.
+  ///
+  /// `keywords` and `parseFuncs` must have the same size. An error is emitted
+  /// for an unknown keyword or for a keyword that appears more than once.
+  /// Entries are not required to be present; callers must check themselves
+  /// that every required field was parsed.
+  ///
+  /// Example:
+  ///   <foo = something_parsed_by_a_custom_parser,
+  ///    bar = something_parsed_by_a_different_custom_parser, ...>
+  virtual ParseResult
+  parseStruct(Delimiter delimiter, ArrayRef<StringRef> keywords,
+              ArrayRef<llvm::function_ref<ParseResult()>> parseFuncs) = 0;
+
protected:
/// Parse a handle to a resource within the assembly format for the given
/// dialect.
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index 1f8fbfdd93568..cff3f5402dd79 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -570,6 +570,51 @@ class AsmParserImpl : public BaseT {
parser.getState().cyclicParsingStack.pop_back();
}
+  //===--------------------------------------------------------------------===//
+  // Struct Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// Parse a comma-separated list of `keyword = value` pairs with the given
+  /// delimiter. Entries may appear in any order and each keyword may appear
+  /// at most once. Presence of entries is not checked here.
+  ParseResult
+  parseStruct(Delimiter delimiter, ArrayRef<StringRef> keywords,
+              ArrayRef<llvm::function_ref<ParseResult()>> parseFuncs) override {
+    assert(keywords.size() == parseFuncs.size() &&
+           "expected one parse function per keyword");
+
+    // Map each keyword to the index of its parse function.
+    DenseMap<StringRef, size_t> keywordToIndex;
+    for (auto &&[idx, keyword] : llvm::enumerate(keywords)) {
+      bool inserted = keywordToIndex.try_emplace(keyword, idx).second;
+      assert(inserted && "duplicate keyword passed to parseStruct");
+      (void)inserted;
+    }
+
+    // Report that one of the known keywords was expected at `loc`.
+    auto emitKeywordError = [&](SMLoc loc) -> ParseResult {
+      InFlightDiagnostic diag = emitError(loc, "expected one of: ");
+      llvm::interleaveComma(keywords, diag, [&](StringRef kw) {
+        diag << '`' << kw << '`';
+      });
+      return diag;
+    };
+
+    SmallVector<bool> seen(keywords.size(), false);
+    return parseCommaSeparatedList(
+        delimiter,
+        [&]() -> ParseResult {
+          // Capture the location before consuming the keyword so that
+          // diagnostics point at the keyword rather than past it.
+          SMLoc keywordLoc = getCurrentLocation();
+          StringRef keyword;
+          if (failed(parseOptionalKeyword(&keyword)))
+            return emitKeywordError(keywordLoc);
+          auto it = keywordToIndex.find(keyword);
+          if (it == keywordToIndex.end())
+            return emitKeywordError(keywordLoc);
+          size_t idx = it->second;
+          if (seen[idx])
+            return emitError(keywordLoc, "duplicated `")
+                   << keyword << "` entry";
+          if (failed(parseEqual()) || failed(parseFuncs[idx]()))
+            return failure();
+          seen[idx] = true;
+          return success();
+        },
+        "parse struct");
+  }
+
//===--------------------------------------------------------------------===//
// Code Completion
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 5b5ec841917e7..7814d8f2cab18 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -3002,6 +3002,28 @@ void AsmPrinter::printDimensionList(ArrayRef<int64_t> shape) {
detail::printDimensionList(getStream(), shape);
}
+//===----------------------------------------------------------------------===//
+// Struct Printing
+//===----------------------------------------------------------------------===//
+
+/// Print a comma-separated list of `keyword = value` pairs in the order of
+/// `keywords`, invoking the print function at the same index to print each
+/// value. No enclosing delimiters are printed.
+void AsmPrinter::printStruct(
+    ArrayRef<StringRef> keywords,
+    ArrayRef<llvm::function_ref<void(AsmPrinter &p)>> printFuncs) {
+  assert(keywords.size() == printFuncs.size() &&
+         "expected one print function per keyword");
+  auto &os = getStream();
+  // Pair keywords with printers by position rather than through a map, so a
+  // repeated keyword uses its own printer and no lookup can yield a null one.
+  for (size_t i = 0, e = keywords.size(); i != e; ++i) {
+    if (i != 0)
+      os << ", ";
+    os << keywords[i] << " = ";
+    printFuncs[i](*this);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Cyclic Printing
+//===----------------------------------------------------------------------===//
+
LogicalResult AsmPrinter::pushCyclicPrinting(const void *opaquePointer) {
return impl->pushCyclicPrinting(opaquePointer);
}
diff --git a/mlir/test/IR/custom-struct-attr-roundtrip.mlir b/mlir/test/IR/custom-struct-attr-roundtrip.mlir
new file mode 100644
index 0000000000000..68f69f99b86a3
--- /dev/null
+++ b/mlir/test/IR/custom-struct-attr-roundtrip.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
+
+// Entries may be parsed in any order but are printed in declaration order;
+// `opt_value` is elided when absent.
+
+// CHECK-LABEL: @test_struct_attr_roundtrip
+func.func @test_struct_attr_roundtrip() -> () {
+  // CHECK: attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>
+  "test.op"() {attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>} : () -> ()
+  // CHECK: attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>
+  "test.op"() {attr = #test.custom_struct<value = 2, type_str = "struct", opt_value = [3, 3]>} : () -> ()
+  // CHECK: attr = #test.custom_struct<type_str = "struct", value = 2>
+  "test.op"() {attr = #test.custom_struct<type_str = "struct", value = 2>} : () -> ()
+  // CHECK: attr = #test.custom_struct<type_str = "struct", value = 2>
+  "test.op"() {attr = #test.custom_struct<value = 2, type_str = "struct">} : () -> ()
+  return
+}
+
+// -----
+
+// Verify that each entry must begin with a keyword: no keywords given.
+
+// expected-error @below {{failed parsing `TestCustomStructAttr`}}
+// expected-error @below {{expected one of: `type_str`, `value`, `opt_value`}}
+"test.op"() {attr = #test.custom_struct<"struct", 2>} : () -> ()
+
+// -----
+
+// Verify that each entry must begin with a keyword: `type_str` keyword omitted.
+
+// expected-error @below {{failed parsing `TestCustomStructAttr`}}
+// expected-error @below {{expected one of: `type_str`, `value`, `opt_value`}}
+"test.op"() {attr = #test.custom_struct<"struct", value = 2, opt_value = [3, 3]>} : () -> ()
+
+// -----
+
+// Verify that each entry must begin with a keyword: `value` keyword omitted.
+
+// expected-error @below {{failed parsing `TestCustomStructAttr`}}
+// expected-error @below {{expected one of: `type_str`, `value`, `opt_value`}}
+"test.op"() {attr = #test.custom_struct<type_str = "struct", 2>} : () -> ()
+
+// -----
+
+// Verify that an unknown keyword is rejected.
+
+// expected-error @below {{failed parsing `TestCustomStructAttr`}}
+// expected-error @below {{expected one of: `type_str`, `value`, `opt_value`}}
+"test.op"() {attr = #test.custom_struct<type_str2 = "struct", value = 2>} : () -> ()
+
+// -----
+
+// Verify that a repeated keyword is rejected.
+
+// expected-error @below {{failed parsing `TestCustomStructAttr`}}
+// expected-error @below {{duplicated `type_str` entry}}
+"test.op"() {attr = #test.custom_struct<type_str = "struct", type_str = "struct2", value = 2>} : () -> ()
+
+// -----
+
+// Verify that a missing `=` after a keyword is rejected.
+
+// expected-error @below {{failed parsing `TestCustomStructAttr`}}
+// expected-error @below {{expected '='}}
+"test.op"() {attr = #test.custom_struct<type_str "struct", value = 2>} : () -> ()
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index fc2d77af29f12..2dae52ab7449c 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -369,6 +369,15 @@ def TestCustomFloatAttr : Test_Attr<"TestCustomFloat"> {
}];
}
+// Exercises the AsmParser::parseStruct and AsmPrinter::printStruct APIs from a
+// hand-written parser and printer (see hasCustomAssemblyFormat).
+def TestCustomStructAttr : Test_Attr<"TestCustomStruct"> {
+  let mnemonic = "custom_struct";
+  let parameters = (ins
+    "mlir::StringAttr":$type_str,
+    "int64_t":$value,
+    OptionalParameter<"mlir::ArrayAttr">:$opt_value
+  );
+  let hasCustomAssemblyFormat = 1;
+}
+
def NestedPolynomialAttr : Test_Attr<"NestedPolynomialAttr"> {
let mnemonic = "nested_polynomial";
let parameters = (ins Polynomial_IntPolynomialAttr:$poly);
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 057d9fb4a215f..89c7c527a2247 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -316,6 +316,43 @@ static ParseResult parseCustomFloatAttr(AsmParser &p, StringAttr &typeStrAttr,
return success();
}
+//===----------------------------------------------------------------------===//
+// TestCustomStructAttr
+//===----------------------------------------------------------------------===//
+
+/// Parse `<type_str = "...", value = N[, opt_value = [...]]>` with entries in
+/// any order. `type_str` and `value` are required; `opt_value` may be omitted.
+Attribute TestCustomStructAttr::parse(AsmParser &p, Type type) {
+  auto structLoc = p.getCurrentLocation();
+  std::string typeStr;
+  int64_t value = 0;
+  // parseStruct does not enforce presence, so track the required entries here
+  // instead of reading uninitialized state when one is omitted.
+  bool hasTypeStr = false;
+  bool hasValue = false;
+  FailureOr<ArrayAttr> optValue;
+  if (failed(p.parseStruct(
+          AsmParser::Delimiter::LessGreater, {"type_str", "value", "opt_value"},
+          {[&]() -> ParseResult {
+             hasTypeStr = true;
+             return p.parseString(&typeStr);
+           },
+           [&]() -> ParseResult {
+             hasValue = true;
+             return p.parseInteger(value);
+           },
+           [&]() -> ParseResult {
+             optValue = mlir::FieldParser<ArrayAttr>::parse(p);
+             return success(succeeded(optValue));
+           }}))) {
+    p.emitError(p.getCurrentLocation())
+        << "failed parsing `TestCustomStructAttr`";
+    return {};
+  }
+  if (!hasTypeStr || !hasValue) {
+    p.emitError(structLoc) << "expected `" << (hasTypeStr ? "value" : "type_str")
+                           << "` entry in `TestCustomStructAttr`";
+    return {};
+  }
+  return get(p.getContext(), StringAttr::get(p.getContext(), typeStr), value,
+             optValue.value_or(ArrayAttr()));
+}
+
+/// Print as `<type_str = ..., value = ...[, opt_value = ...]>`, eliding
+/// `opt_value` when it is null. Every entry goes through printStruct so the
+/// API is also exercised for the optional field.
+void TestCustomStructAttr::print(AsmPrinter &p) const {
+  auto printTypeStr = [&](AsmPrinter &printer) {
+    printer.printStrippedAttrOrType(getTypeStr());
+  };
+  auto printValue = [&](AsmPrinter &printer) {
+    printer.printStrippedAttrOrType(getValue());
+  };
+  p << "<";
+  if (ArrayAttr optValue = getOptValue()) {
+    p.printStruct({"type_str", "value", "opt_value"},
+                  {printTypeStr, printValue, [&](AsmPrinter &printer) {
+                     printer.printStrippedAttrOrType(optValue);
+                   }});
+  } else {
+    p.printStruct({"type_str", "value"}, {printTypeStr, printValue});
+  }
+  p << ">";
+}
+
//===----------------------------------------------------------------------===//
// TestOpAsmAttrInterfaceAttr
//===----------------------------------------------------------------------===//
``````````
</details>
https://github.com/llvm/llvm-project/pull/133939
More information about the Mlir-commits
mailing list