[Mlir-commits] [mlir] [mlir] Add struct parsing and printing utilities (PR #133939)

Jorn Tuyls llvmlistbot at llvm.org
Thu Apr 3 09:01:05 PDT 2025


https://github.com/jtuyls updated https://github.com/llvm/llvm-project/pull/133939

>From 8e20e50970648e821be662ee43862d2601fea5bc Mon Sep 17 00:00:00 2001
From: Jorn Tuyls <jorn.tuyls at gmail.com>
Date: Tue, 1 Apr 2025 06:52:34 -0500
Subject: [PATCH] [mlir] Add struct parsing and printing utilities

---
 mlir/include/mlir/IR/OpImplementation.h       | 43 +++++++++++++
 mlir/lib/AsmParser/AsmParserImpl.h            | 45 ++++++++++++++
 mlir/lib/IR/AsmPrinter.cpp                    | 22 +++++++
 .../test/IR/custom-struct-attr-roundtrip.mlir | 62 +++++++++++++++++++
 mlir/test/lib/Dialect/Test/TestAttrDefs.td    |  9 +++
 mlir/test/lib/Dialect/Test/TestAttributes.cpp | 37 +++++++++++
 6 files changed, 218 insertions(+)
 create mode 100644 mlir/test/IR/custom-struct-attr-roundtrip.mlir

diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 25c7d15eb8ed5..d77f4147744e1 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -238,6 +238,29 @@ class AsmPrinter {
 
   void printDimensionList(ArrayRef<int64_t> shape);
 
+  //===----------------------------------------------------------------------===//
+  // Struct Printing
+  //===----------------------------------------------------------------------===//
+
+  /// Print a comma-separated list of `keyword = value` pairs using the
+  /// provided `keywords` and corresponding printing functions, where
+  /// `printFuncs[i]` prints the value of `keywords[i]` (both lists must have
+  /// the same size). This performs similar printing as the assembly format's
+  /// `struct` directive printer, but allows custom printers for fields.
+  ///
+  /// Pairs are printed in the order given, separated by `, `. No enclosing
+  /// delimiters are printed; the caller is responsible for them.
+  ///
+  /// Example (the caller prints the surrounding `<` and `>`):
+  ///   <foo = foo_value, bar = bar_value, ...>
+  virtual void
+  printStruct(ArrayRef<StringRef> keywords,
+              ArrayRef<llvm::function_ref<void(AsmPrinter &p)>> printFuncs);
+
+  //===----------------------------------------------------------------------===//
+  // Cyclic Printing
+  //===----------------------------------------------------------------------===//
+
   /// Class used to automatically end a cyclic region on destruction.
   class CyclicPrintReset {
   public:
@@ -1409,6 +1432,26 @@ class AsmParser {
     return CyclicParseReset(this);
   }
 
+  //===----------------------------------------------------------------------===//
+  // Struct Parsing
+  //===----------------------------------------------------------------------===//
+
+  /// Parse a comma-separated list of `keyword = value` pairs enclosed by
+  /// `delimiter`, similar to the assembly format `struct` directive parser but
+  /// with custom field parsing: after `keywords[i] =`, `parseFuncs[i]` parses
+  /// the value (both lists must have the same size). Entries may appear in any
+  /// order, but each keyword at most once. Absent keywords are not diagnosed,
+  /// so callers must verify that required fields were actually parsed.
+  ///
+  /// Example (with `Delimiter::LessGreater`):
+  /// <
+  ///   foo = something_parsed_by_a_custom_parser,
+  ///   bar = something_parsed_by_a_different_custom_parser, ...
+  /// >
+  virtual ParseResult
+  parseStruct(Delimiter delimiter, ArrayRef<StringRef> keywords,
+              ArrayRef<llvm::function_ref<ParseResult()>> parseFuncs) = 0;
+
 protected:
   /// Parse a handle to a resource within the assembly format for the given
   /// dialect.
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index 1f8fbfdd93568..cff3f5402dd79 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -570,6 +570,51 @@ class AsmParserImpl : public BaseT {
     parser.getState().cyclicParsingStack.pop_back();
   }
 
+  //===----------------------------------------------------------------------===//
+  // Struct Parsing
+  //===----------------------------------------------------------------------===//
+
+  /// Parse a comma-separated list of `keyword = value` pairs enclosed by
+  /// `delimiter`, dispatching each value to the matching `parseFuncs` entry.
+  ParseResult
+  parseStruct(Delimiter delimiter, ArrayRef<StringRef> keywords,
+              ArrayRef<llvm::function_ref<ParseResult()>> parseFuncs) override {
+    assert(keywords.size() == parseFuncs.size() && "one parser per keyword");
+    // Diagnose an unexpected key at `loc`, listing every accepted keyword.
+    auto keyError = [&](SMLoc loc) -> ParseResult {
+      InFlightDiagnostic parseError = emitError(loc, "expected one of: ");
+      llvm::interleaveComma(keywords, parseError, [&](StringRef kw) {
+        parseError << '`' << kw << '`';
+      });
+      return parseError;
+    };
+    SmallVector<bool> seen(keywords.size(), false);
+    DenseMap<StringRef, size_t> keywordToIndex;
+    for (auto &&[idx, keyword] : llvm::enumerate(keywords))
+      keywordToIndex[keyword] = idx;
+    return parseCommaSeparatedList(
+        delimiter,
+        [&]() -> ParseResult {
+          // Diagnostics about the key point at the key, not the next token.
+          SMLoc keyLoc = getCurrentLocation();
+          StringRef keyword;
+          if (failed(parseOptionalKeyword(&keyword)))
+            return keyError(keyLoc);
+          auto it = keywordToIndex.find(keyword);
+          if (it == keywordToIndex.end())
+            return keyError(keyLoc);
+          size_t idx = it->second;
+          if (seen[idx])
+            return emitError(keyLoc, "duplicated `") << keyword << "` entry";
+          if (failed(parseEqual()) || failed(parseFuncs[idx]()))
+            return failure();
+          seen[idx] = true;
+          return success();
+        },
+        // Absent keywords are not diagnosed; callers check required fields.
+        "parse struct");
+  }
+
   //===--------------------------------------------------------------------===//
   // Code Completion
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 5b5ec841917e7..7814d8f2cab18 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -3002,6 +3002,28 @@ void AsmPrinter::printDimensionList(ArrayRef<int64_t> shape) {
   detail::printDimensionList(getStream(), shape);
 }
 
+//===----------------------------------------------------------------------===//
+// Struct Printing
+//===----------------------------------------------------------------------===//
+
+/// Print comma-separated pairs `keywords[i] = <printFuncs[i] output>`.
+void AsmPrinter::printStruct(
+    ArrayRef<StringRef> keywords,
+    ArrayRef<llvm::function_ref<void(AsmPrinter &p)>> printFuncs) {
+  assert(keywords.size() == printFuncs.size() && "one printer per keyword");
+  raw_ostream &os = getStream();
+  for (auto &&[idx, keyword] : llvm::enumerate(keywords)) {
+    if (idx != 0)
+      os << ", ";
+    os << keyword << " = ";
+    printFuncs[idx](*this);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Cyclic Printing
+//===----------------------------------------------------------------------===//
+
 LogicalResult AsmPrinter::pushCyclicPrinting(const void *opaquePointer) {
   return impl->pushCyclicPrinting(opaquePointer);
 }
diff --git a/mlir/test/IR/custom-struct-attr-roundtrip.mlir b/mlir/test/IR/custom-struct-attr-roundtrip.mlir
new file mode 100644
index 0000000000000..68f69f99b86a3
--- /dev/null
+++ b/mlir/test/IR/custom-struct-attr-roundtrip.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics| FileCheck %s
+
+// CHECK-LABEL: @test_struct_attr_roundtrip
+func.func @test_struct_attr_roundtrip() -> () {
+  // CHECK: attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>
+  "test.op"() {attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>} : () -> ()
+  // CHECK: attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>
+  "test.op"() {attr = #test.custom_struct<value = 2, type_str = "struct", opt_value = [3, 3]>} : () -> ()
+  // CHECK: attr = #test.custom_struct<type_str = "struct", value = 2>
+  "test.op"() {attr = #test.custom_struct<type_str = "struct", value = 2>} : () -> ()
+  // CHECK: attr = #test.custom_struct<type_str = "struct", value = 2>
+  "test.op"() {attr = #test.custom_struct<value = 2, type_str = "struct">} : () -> ()
+  return
+}
+
+// -----
+
+// Verify every value must be preceded by its keyword. All keywords missing.
+
+// expected-error @below {{failed parsing `TestCustomStructAttr`}}
+// expected-error @below {{expected one of: `type_str`, `value`, `opt_value`}}
+"test.op"() {attr = #test.custom_struct<"struct", 2>} : () -> ()
+
+// -----
+
+// Verify every value must be preceded by its keyword. `type_str` missing.
+
+// expected-error @below {{failed parsing `TestCustomStructAttr`}}
+// expected-error @below {{expected one of: `type_str`, `value`, `opt_value`}}
+"test.op"() {attr = #test.custom_struct<"struct", value = 2, opt_value = [3, 3]>} : () -> ()
+
+// -----
+
+// Verify every value must be preceded by its keyword. `value` missing.
+
+// expected-error @below {{failed parsing `TestCustomStructAttr`}}
+// expected-error @below {{expected one of: `type_str`, `value`, `opt_value`}}
+"test.op"() {attr = #test.custom_struct<type_str = "struct", 2>} : () -> ()
+
+// -----
+
+// Verify invalid keyword provided.
+
+// expected-error @below {{failed parsing `TestCustomStructAttr`}}
+// expected-error @below {{expected one of: `type_str`, `value`, `opt_value`}}
+"test.op"() {attr = #test.custom_struct<type_str2 = "struct", value = 2>} : () -> ()
+
+// -----
+
+// Verify duplicated keyword provided.
+
+// expected-error @below {{failed parsing `TestCustomStructAttr`}}
+// expected-error @below {{duplicated `type_str` entry}}
+"test.op"() {attr = #test.custom_struct<type_str = "struct", type_str = "struct2", value = 2>} : () -> ()
+
+// -----
+
+// Verify equals missing.
+
+// expected-error @below {{failed parsing `TestCustomStructAttr`}}
+// expected-error @below {{expected '='}}
+"test.op"() {attr = #test.custom_struct<type_str "struct", value = 2>} : () -> ()
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index fc2d77af29f12..2dae52ab7449c 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -369,6 +369,15 @@ def TestCustomFloatAttr : Test_Attr<"TestCustomFloat"> {
   }];
 }
 
+// Tests the AsmParser::parseStruct and AsmPrinter::printStruct APIs through a
+// custom parser/printer; the optional `opt_value` is only printed when set.
+def TestCustomStructAttr : Test_Attr<"TestCustomStruct"> {
+  let mnemonic = "custom_struct";
+  let parameters = (ins "mlir::StringAttr":$type_str, "int64_t":$value,
+                        OptionalParameter<"mlir::ArrayAttr">:$opt_value);
+  let hasCustomAssemblyFormat = 1;
+}
+
 def NestedPolynomialAttr : Test_Attr<"NestedPolynomialAttr"> {
   let mnemonic = "nested_polynomial";
   let parameters = (ins Polynomial_IntPolynomialAttr:$poly);
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 057d9fb4a215f..89c7c527a2247 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -316,6 +316,43 @@ static ParseResult parseCustomFloatAttr(AsmParser &p, StringAttr &typeStrAttr,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// TestCustomStructAttr
+//===----------------------------------------------------------------------===//
+
+Attribute TestCustomStructAttr::parse(AsmParser &p, Type type) {
+  std::optional<std::string> typeStr; // Required; parseStruct allows absent
+  std::optional<int64_t> value;       // keys, so presence is checked below.
+  FailureOr<ArrayAttr> optValue;
+  if (failed(p.parseStruct(AsmParser::Delimiter::LessGreater,
+                           {"type_str", "value", "opt_value"},
+                           {[&]() { return p.parseString(&typeStr.emplace()); },
+                            [&]() { return p.parseInteger(value.emplace()); },
+                            [&]() {
+                              optValue = mlir::FieldParser<ArrayAttr>::parse(p);
+                              return success(succeeded(optValue));
+                            }})) || !typeStr || !value) {
+    p.emitError(p.getCurrentLocation())
+        << "failed parsing `TestCustomStructAttr`";
+    return {};
+  }
+  return get(p.getContext(), StringAttr::get(p.getContext(), *typeStr), *value,
+             optValue.value_or(ArrayAttr()));
+}
+
+void TestCustomStructAttr::print(AsmPrinter &p) const {
+  p << "<";
+  p.printStruct(
+      {"type_str", "value"},
+      {[&](AsmPrinter &ap) { ap.printStrippedAttrOrType(getTypeStr()); },
+       [&](AsmPrinter &ap) { ap.printStrippedAttrOrType(getValue()); }});
+  if (ArrayAttr optValue = getOptValue()) { // `opt_value` only when set.
+    p << ", opt_value = ";
+    p.printStrippedAttrOrType(optValue);
+  }
+  p << ">";
+}
+
 //===----------------------------------------------------------------------===//
 // TestOpAsmAttrInterfaceAttr
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list