[Mlir-commits] [mlir] [mlir] Add struct parsing and printing utilities (PR #133939)
Jorn Tuyls
llvmlistbot at llvm.org
Thu Apr 3 09:01:05 PDT 2025
https://github.com/jtuyls updated https://github.com/llvm/llvm-project/pull/133939
>From 8e20e50970648e821be662ee43862d2601fea5bc Mon Sep 17 00:00:00 2001
From: Jorn Tuyls <jorn.tuyls at gmail.com>
Date: Tue, 1 Apr 2025 06:52:34 -0500
Subject: [PATCH] [mlir] Add struct parsing and printing utilities
---
mlir/include/mlir/IR/OpImplementation.h | 43 +++++++++++++
mlir/lib/AsmParser/AsmParserImpl.h | 45 ++++++++++++++
mlir/lib/IR/AsmPrinter.cpp | 22 +++++++
.../test/IR/custom-struct-attr-roundtrip.mlir | 62 +++++++++++++++++++
mlir/test/lib/Dialect/Test/TestAttrDefs.td | 9 +++
mlir/test/lib/Dialect/Test/TestAttributes.cpp | 37 +++++++++++
6 files changed, 218 insertions(+)
create mode 100644 mlir/test/IR/custom-struct-attr-roundtrip.mlir
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 25c7d15eb8ed5..d77f4147744e1 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -238,6 +238,29 @@ class AsmPrinter {
void printDimensionList(ArrayRef<int64_t> shape);
+ //===----------------------------------------------------------------------===//
+ // Struct Printing
+ //===----------------------------------------------------------------------===//
+
+ /// Print a comma-separated list of key-value pairs using the provided
+ /// `keywords` and corresponding printing functions. This performs similar
+ /// printing as the assembly format's `struct` directive printer, but
+ /// allows bringing in custom printers for fields.
+ ///
+ /// Example:
+ /// <
+ /// foo = foo_value,
+ /// bar = bar_value,
+ /// ...
+ /// >
+ virtual void
+ printStruct(ArrayRef<StringRef> keywords,
+ ArrayRef<llvm::function_ref<void(AsmPrinter &p)>> printFuncs);
+
+ //===----------------------------------------------------------------------===//
+ // Cyclic Printing
+ //===----------------------------------------------------------------------===//
+
/// Class used to automatically end a cyclic region on destruction.
class CyclicPrintReset {
public:
@@ -1409,6 +1432,26 @@ class AsmParser {
return CyclicParseReset(this);
}
+ //===----------------------------------------------------------------------===//
+ // Struct Parsing
+ //===----------------------------------------------------------------------===//
+
+ /// Parse a comma-separated list of key-value pairs with a specified
+ /// delimiter. This performs similar parsing as the assembly format
+ /// `struct` directive parser with custom delimiter and/or field parsing. The
+ /// variables are printed in the order they are specified in the argument list
+ /// but can be parsed in any order.
+ ///
+ /// Example:
+ /// <
+ /// foo = something_parsed_by_a_custom_parser,
+ /// bar = something_parsed_by_a_different_custom_parser,
+ /// ...
+ /// >
+ virtual ParseResult
+ parseStruct(Delimiter delimiter, ArrayRef<StringRef> keywords,
+ ArrayRef<llvm::function_ref<ParseResult()>> parseFuncs) = 0;
+
protected:
/// Parse a handle to a resource within the assembly format for the given
/// dialect.
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index 1f8fbfdd93568..cff3f5402dd79 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -570,6 +570,51 @@ class AsmParserImpl : public BaseT {
parser.getState().cyclicParsingStack.pop_back();
}
+ //===----------------------------------------------------------------------===//
+ // Struct Parsing
+ //===----------------------------------------------------------------------===//
+
+ /// Parse a comma-separated list of key-value pairs with a specified
+ /// delimiter.
+ ParseResult
+ parseStruct(Delimiter delimiter, ArrayRef<StringRef> keywords,
+ ArrayRef<llvm::function_ref<ParseResult()>> parseFuncs) override {
+ assert(keywords.size() == parseFuncs.size());
+ auto keyError = [&]() -> ParseResult {
+ InFlightDiagnostic parseError =
+ emitError(getCurrentLocation(), "expected one of: ");
+ llvm::interleaveComma(keywords, parseError, [&](StringRef kw) {
+ parseError << '`' << kw << '`';
+ });
+ return parseError;
+ };
+ SmallVector<bool> seen(keywords.size(), false);
+ DenseMap<StringRef, size_t> keywordToIndex;
+ for (auto &&[idx, keyword] : llvm::enumerate(keywords))
+ keywordToIndex[keyword] = idx;
+ return parseCommaSeparatedList(
+ delimiter,
+ [&]() -> ParseResult {
+ StringRef keyword;
+ if (failed(parseOptionalKeyword(&keyword)))
+ return keyError();
+ if (!keywordToIndex.contains(keyword))
+ return keyError();
+ size_t idx = keywordToIndex[keyword];
+ if (seen[idx]) {
+ return emitError(getCurrentLocation(), "duplicated `")
+ << keyword << "` entry";
+ }
+ if (failed(parseEqual()))
+ return failure();
+ if (failed(parseFuncs[idx]()))
+ return failure();
+ seen[idx] = true;
+ return success();
+ },
+ "parse struct");
+ }
+
//===--------------------------------------------------------------------===//
// Code Completion
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 5b5ec841917e7..7814d8f2cab18 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -3002,6 +3002,28 @@ void AsmPrinter::printDimensionList(ArrayRef<int64_t> shape) {
detail::printDimensionList(getStream(), shape);
}
+//===----------------------------------------------------------------------===//
+// Struct Printing
+//===----------------------------------------------------------------------===//
+
+/// Print a comma-separated list of key-value pairs.
+void AsmPrinter::printStruct(
+ ArrayRef<StringRef> keywords,
+ ArrayRef<llvm::function_ref<void(AsmPrinter &p)>> printFuncs) {
+ DenseMap<StringRef, llvm::function_ref<void(AsmPrinter & p)>> keywordToFunc;
+ for (auto &&[kw, printFunc] : llvm::zip(keywords, printFuncs))
+ keywordToFunc[kw] = printFunc;
+ auto &os = getStream();
+ llvm::interleaveComma(keywords, os, [&](StringRef kw) {
+ os << kw << " = ";
+ keywordToFunc[kw](*this);
+ });
+}
+
+//===----------------------------------------------------------------------===//
+// Cyclic Printing
+//===----------------------------------------------------------------------===//
+
LogicalResult AsmPrinter::pushCyclicPrinting(const void *opaquePointer) {
return impl->pushCyclicPrinting(opaquePointer);
}
diff --git a/mlir/test/IR/custom-struct-attr-roundtrip.mlir b/mlir/test/IR/custom-struct-attr-roundtrip.mlir
new file mode 100644
index 0000000000000..68f69f99b86a3
--- /dev/null
+++ b/mlir/test/IR/custom-struct-attr-roundtrip.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics| FileCheck %s
+
+// CHECK-LABEL: @test_struct_attr_roundtrip
+func.func @test_struct_attr_roundtrip() -> () {
+ // CHECK: attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>
+ "test.op"() {attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>} : () -> ()
+ // CHECK: attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>
+ "test.op"() {attr = #test.custom_struct<value = 2, type_str = "struct", opt_value = [3, 3]>} : () -> ()
+ // CHECK: attr = #test.custom_struct<type_str = "struct", value = 2>
+ "test.op"() {attr = #test.custom_struct<type_str = "struct", value = 2>} : () -> ()
+ // CHECK: attr = #test.custom_struct<type_str = "struct", value = 2>
+ "test.op"() {attr = #test.custom_struct<value = 2, type_str = "struct">} : () -> ()
+ return
+}
+
+// -----
+
+// Verify all keywords must be provided. All missing.
+
+// expected-error @below {{failed parsing `TestCustomStructAttr`}}
+// expected-error @below {{expected one of: `type_str`, `value`, `opt_value`}}
+"test.op"() {attr = #test.custom_struct<"struct", 2>} : () -> ()
+
+// -----
+
+// Verify all keywords must be provided. `type_str` missing.
+
+// expected-error @below {{failed parsing `TestCustomStructAttr`}}
+// expected-error @below {{expected one of: `type_str`, `value`, `opt_value`}}
+"test.op"() {attr = #test.custom_struct<"struct", value = 2, opt_value = [3, 3]>} : () -> ()
+
+// -----
+
+// Verify all keywords must be provided. `value` missing.
+
+// expected-error @below {{failed parsing `TestCustomStructAttr`}}
+// expected-error @below {{expected one of: `type_str`, `value`, `opt_value`}}
+"test.op"() {attr = #test.custom_struct<type_str = "struct", 2>} : () -> ()
+
+// -----
+
+// Verify invalid keyword provided.
+
+// expected-error @below {{failed parsing `TestCustomStructAttr`}}
+// expected-error @below {{expected one of: `type_str`, `value`, `opt_value`}}
+"test.op"() {attr = #test.custom_struct<type_str2 = "struct", value = 2>} : () -> ()
+
+// -----
+
+// Verify duplicated keyword provided.
+
+// expected-error @below {{failed parsing `TestCustomStructAttr`}}
+// expected-error @below {{duplicated `type_str` entry}}
+"test.op"() {attr = #test.custom_struct<type_str = "struct", type_str = "struct2", value = 2>} : () -> ()
+
+// -----
+
+// Verify equals missing.
+
+// expected-error @below {{failed parsing `TestCustomStructAttr`}}
+// expected-error @below {{expected '='}}
+"test.op"() {attr = #test.custom_struct<type_str "struct", value = 2>} : () -> ()
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index fc2d77af29f12..2dae52ab7449c 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -369,6 +369,15 @@ def TestCustomFloatAttr : Test_Attr<"TestCustomFloat"> {
}];
}
+// Test AsmParser::parseStruct and AsmPrinter::printStruct APIs through the custom
+// parser and printer.
+def TestCustomStructAttr : Test_Attr<"TestCustomStruct"> {
+ let mnemonic = "custom_struct";
+ let parameters = (ins "mlir::StringAttr":$type_str, "int64_t":$value,
+ OptionalParameter<"mlir::ArrayAttr">:$opt_value);
+ let hasCustomAssemblyFormat = 1;
+}
+
def NestedPolynomialAttr : Test_Attr<"NestedPolynomialAttr"> {
let mnemonic = "nested_polynomial";
let parameters = (ins Polynomial_IntPolynomialAttr:$poly);
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 057d9fb4a215f..89c7c527a2247 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -316,6 +316,43 @@ static ParseResult parseCustomFloatAttr(AsmParser &p, StringAttr &typeStrAttr,
return success();
}
+//===----------------------------------------------------------------------===//
+// TestCustomStructAttr
+//===----------------------------------------------------------------------===//
+
+Attribute TestCustomStructAttr::parse(AsmParser &p, Type type) {
+ std::string typeStr;
+ int64_t value;
+ FailureOr<ArrayAttr> optValue;
+ if (failed(p.parseStruct(AsmParser::Delimiter::LessGreater,
+ {"type_str", "value", "opt_value"},
+ {[&]() { return p.parseString(&typeStr); },
+ [&]() { return p.parseInteger(value); },
+ [&]() {
+ optValue = mlir::FieldParser<ArrayAttr>::parse(p);
+ return success(succeeded(optValue));
+ }}))) {
+ p.emitError(p.getCurrentLocation())
+ << "failed parsing `TestCustomStructAttr`";
+ return {};
+ }
+ return get(p.getContext(), StringAttr::get(p.getContext(), typeStr), value,
+ optValue.value_or(ArrayAttr()));
+}
+
+void TestCustomStructAttr::print(AsmPrinter &p) const {
+ p << "<";
+ p.printStruct(
+ {"type_str", "value"},
+ {[&](AsmPrinter &p) { p.printStrippedAttrOrType(getTypeStr()); },
+ [&](AsmPrinter &p) { p.printStrippedAttrOrType(getValue()); }});
+ if (getOptValue() != ArrayAttr()) {
+ p << ", opt_value = ";
+ p.printStrippedAttrOrType(getOptValue());
+ }
+ p << ">";
+}
+
//===----------------------------------------------------------------------===//
// TestOpAsmAttrInterfaceAttr
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list