[Mlir-commits] [mlir] [mlir][tblgen] Add custom parsing and printing within struct (PR #133939)

Jorn Tuyls llvmlistbot at llvm.org
Wed Apr 9 05:27:30 PDT 2025


https://github.com/jtuyls updated https://github.com/llvm/llvm-project/pull/133939

>From 547e013352461afae67085173fd3fbecf8e63a1d Mon Sep 17 00:00:00 2001
From: Jorn Tuyls <jorn.tuyls at gmail.com>
Date: Tue, 1 Apr 2025 06:52:34 -0500
Subject: [PATCH] [mlir][tblgen] Add `custom` parsing and printing within
 `struct`

---
 .../test/IR/custom-struct-attr-roundtrip.mlir |  66 +++++
 mlir/test/lib/Dialect/Test/TestAttrDefs.td    |  10 +
 mlir/test/lib/Dialect/Test/TestAttributes.cpp |  21 ++
 .../attr-or-type-format-invalid.td            |  23 +-
 mlir/test/mlir-tblgen/attr-or-type-format.td  |  78 ++++++
 .../tools/mlir-tblgen/AttrOrTypeFormatGen.cpp | 249 +++++++++++-------
 mlir/tools/mlir-tblgen/FormatGen.cpp          |   6 +-
 mlir/tools/mlir-tblgen/FormatGen.h            |  43 ++-
 mlir/tools/mlir-tblgen/OpFormatGen.cpp        |  16 +-
 9 files changed, 401 insertions(+), 111 deletions(-)
 create mode 100644 mlir/test/IR/custom-struct-attr-roundtrip.mlir

diff --git a/mlir/test/IR/custom-struct-attr-roundtrip.mlir b/mlir/test/IR/custom-struct-attr-roundtrip.mlir
new file mode 100644
index 0000000000000..4d07e896f5b1d
--- /dev/null
+++ b/mlir/test/IR/custom-struct-attr-roundtrip.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics| FileCheck %s
+
+// CHECK-LABEL: @test_struct_attr_roundtrip
+func.func @test_struct_attr_roundtrip() -> () {
+  // CHECK: attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>
+  "test.op"() {attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>} : () -> ()
+  // CHECK: attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>
+  "test.op"() {attr = #test.custom_struct<value = 2, type_str = "struct", opt_value = [3, 3]>} : () -> ()
+  // CHECK: attr = #test.custom_struct<type_str = "struct", value = 2>
+  "test.op"() {attr = #test.custom_struct<type_str = "struct", value = 2>} : () -> ()
+  // CHECK: attr = #test.custom_struct<type_str = "struct", value = 2>
+  "test.op"() {attr = #test.custom_struct<value = 2, type_str = "struct">} : () -> ()
+  return
+}
+
+// -----
+
+// Verify all required parameters must be provided. `value` is missing.
+
+// expected-error @below {{struct is missing required parameter: value}}
+"test.op"() {attr = #test.custom_struct<type_str = "struct">} : () -> ()
+
+// -----
+
+// Verify all keywords must be provided. All missing.
+
+// expected-error @below {{expected valid keyword}}
+// expected-error @below {{expected a parameter name in struct}}
+"test.op"() {attr = #test.custom_struct<"struct", 2>} : () -> ()
+
+// -----
+
+// Verify all keywords must be provided. `type_str` missing.
+
+// expected-error @below {{expected valid keyword}}
+// expected-error @below {{expected a parameter name in struct}}
+"test.op"() {attr = #test.custom_struct<"struct", value = 2, opt_value = [3, 3]>} : () -> ()
+
+// -----
+
+// Verify all keywords must be provided. `value` missing.
+
+// expected-error @below {{expected valid keyword}}
+// expected-error @below {{expected a parameter name in struct}}
+"test.op"() {attr = #test.custom_struct<type_str = "struct", 2>} : () -> ()
+
+// -----
+
+// Verify invalid keyword provided.
+
+// expected-error @below {{duplicate or unknown struct parameter name: type_str2}}
+"test.op"() {attr = #test.custom_struct<type_str2 = "struct", value = 2>} : () -> ()
+
+// -----
+
+// Verify duplicated keyword provided.
+
+// expected-error @below {{duplicate or unknown struct parameter name: type_str}}
+"test.op"() {attr = #test.custom_struct<type_str = "struct", type_str = "struct2", value = 2>} : () -> ()
+
+// -----
+
+// Verify equals missing.
+
+// expected-error @below {{expected '='}}
+"test.op"() {attr = #test.custom_struct<type_str "struct", value = 2>} : () -> ()
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index fc2d77af29f12..6441a82d87eba 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -369,6 +369,16 @@ def TestCustomFloatAttr : Test_Attr<"TestCustomFloat"> {
   }];
 }
 
+// Test `struct` with nested `custom` assembly format.
+def TestCustomStructAttr : Test_Attr<"TestCustomStruct"> {
+  let mnemonic = "custom_struct";
+  let parameters = (ins "mlir::StringAttr":$type_str, "int64_t":$value,
+                        OptionalParameter<"mlir::ArrayAttr">:$opt_value);
+  let assemblyFormat = [{
+    `<` struct($type_str, custom<CustomStructAttr>($value), custom<CustomOptStructFieldAttr>($opt_value)) `>`
+  }];
+}
+
 def NestedPolynomialAttr : Test_Attr<"NestedPolynomialAttr"> {
   let mnemonic = "nested_polynomial";
   let parameters = (ins Polynomial_IntPolynomialAttr:$poly);
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 057d9fb4a215f..988c6ecc2afb6 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -316,6 +316,27 @@ static ParseResult parseCustomFloatAttr(AsmParser &p, StringAttr &typeStrAttr,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// TestCustomStructAttr
+//===----------------------------------------------------------------------===//
+
+static void printCustomStructAttr(AsmPrinter &p, int64_t value) {
+  p << value; // Plain integer; mirrors `parseInteger` in the parser below.
+}
+
+static ParseResult parseCustomStructAttr(AsmParser &p, int64_t &value) {
+  return p.parseInteger(value);
+}
+
+static void printCustomOptStructFieldAttr(AsmPrinter &p, ArrayAttr attr) {
+  p.printStrippedAttrOrType(attr);
+}
+
+static ParseResult parseCustomOptStructFieldAttr(AsmParser &p,
+                                                 ArrayAttr &attr) {
+  return p.parseAttribute(attr);
+}
+
 //===----------------------------------------------------------------------===//
 // TestOpAsmAttrInterfaceAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
index 3a57cbca4d7bb..6cb9e8b4bb45e 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
@@ -37,7 +37,7 @@ def InvalidTypeC : InvalidType<"InvalidTypeC", "invalid_c"> {
 def InvalidTypeD : InvalidType<"InvalidTypeD", "invalid_d"> {
   let parameters = (ins "int":$v0);
   // CHECK: literals may only be used in the top-level section of the format
-  // CHECK: expected a variable in `struct` argument list
+  // CHECK: expected a variable or `custom` directive in `struct` argument list
   let assemblyFormat = "`<` struct($v0, `,`) `>`";
 }
 
@@ -144,3 +144,24 @@ def InvalidTypeT : InvalidType<"InvalidTypeT", "invalid_t"> {
   // CHECK: `custom` directive with no bound parameters cannot be used as optional group anchor
   let assemblyFormat = "$a (`(` custom<Foo>(ref($a))^ `)`)?";
 }
+
+// Test `struct` with nested `custom` directive with multiple fields.
+def InvalidTypeU : InvalidType<"InvalidTypeU", "invalid_u"> {
+  let parameters = (ins OptionalParameter<"int">:$a, "int":$b);
+  // CHECK: `struct` can only contain `custom` directives with a single argument
+  let assemblyFormat = "struct(custom<Foo>($a, $b))";
+}
+
+// Test `struct` with nested `custom` directive invalid parameter.
+def InvalidTypeV : InvalidType<"InvalidTypeV", "invalid_v"> {
+  let parameters = (ins OptionalParameter<"int">:$a);
+  // CHECK: a `custom` directive nested within a `struct` must be passed a parameter
+  let assemblyFormat = "struct($a, custom<Foo>(ref($a)))";
+}
+
+// Test that a `custom` directive cannot be nested in another `custom`.
+def InvalidTypeW : InvalidType<"InvalidTypeW", "invalid_w"> {
+  let parameters = (ins OptionalParameter<"int">:$a, "int":$b);
+  // CHECK: `custom` can only be used at the top-level context or within a `struct` directive
+  let assemblyFormat = "custom<Foo>($a, custom<Bar>($b))";
+}
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
index c5348409e8e44..0f6b0c401a4e6 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -736,6 +736,84 @@ def TypeS : TestType<"TestS"> {
   let assemblyFormat = "$a";
 }
 
+/// Test that a `struct` with nested `custom` parser and printer are generated correctly.
+
+// ATTR:    ::mlir::Attribute TestTAttr::parse(::mlir::AsmParser &odsParser,
+// ATTR:                                       ::mlir::Type odsType) {
+// ATTR:      bool _seen_v0 = false;
+// ATTR:      bool _seen_v1 = false;
+// ATTR:      bool _seen_v2 = false;
+// ATTR:      const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {
+// ATTR:        if (odsParser.parseEqual())
+// ATTR:          return {};
+// ATTR:        if (!_seen_v0 && _paramKey == "v0") {
+// ATTR:          _seen_v0 = true;
+// ATTR:          _result_v0 = ::parseAttrParamA(odsParser, odsType);
+// ATTR:          if (::mlir::failed(_result_v0))
+// ATTR:            return {};
+// ATTR:        } else if (!_seen_v1 && _paramKey == "v1") {
+// ATTR:          _seen_v1 = true;
+// ATTR:          {
+// ATTR:            auto odsCustomResult = parseNestedCustom(odsParser,
+// ATTR-NEXT:         ::mlir::detail::unwrapForCustomParse(_result_v1));
+// ATTR:            if (::mlir::failed(odsCustomResult)) return {};
+// ATTR:            if (::mlir::failed(_result_v1)) {
+// ATTR:              odsParser.emitError(odsCustomLoc, "custom parser failed to parse parameter 'v1'");
+// ATTR:              return {};
+// ATTR:            }
+// ATTR:          }
+// ATTR:        } else if (!_seen_v2 && _paramKey == "v2") {
+// ATTR:          _seen_v2 = true;
+// ATTR:          _result_v2 = ::mlir::FieldParser<AttrParamB>::parse(odsParser);
+// ATTR:          if (::mlir::failed(_result_v2)) {
+// ATTR:            odsParser.emitError(odsParser.getCurrentLocation(), "failed to parse AttrT parameter 'v2' which is to be a `AttrParamB`");
+// ATTR:            return {};
+// ATTR:          }
+// ATTR:        } else {
+// ATTR:          return {};
+// ATTR:        }
+// ATTR:        return true;
+// ATTR:      }
+// ATTR:      do {
+// ATTR:        ::llvm::StringRef _paramKey;
+// ATTR:        if (odsParser.parseKeyword(&_paramKey)) {
+// ATTR:          odsParser.emitError(odsParser.getCurrentLocation(),
+// ATTR-NEXT:                         "expected a parameter name in struct");
+// ATTR:          return {};
+// ATTR:        }
+// ATTR:        if (!_loop_body(_paramKey)) return {};
+// ATTR:      } while(!odsParser.parseOptionalComma());
+// ATTR:      if (!_seen_v0)
+// ATTR:      if (!_seen_v1)
+// ATTR:      return TestTAttr::get(odsParser.getContext(),
+// ATTR:          TestParamA((*_result_v0)),
+// ATTR:          TestParamB((*_result_v1)),
+// ATTR:           AttrParamB((_result_v2.value_or(AttrParamB()))));
+// ATTR:    }
+
+// ATTR:      void TestTAttr::print(::mlir::AsmPrinter &odsPrinter) const {
+// ATTR:        odsPrinter << "v0 = ";
+// ATTR:        ::printAttrParamA(odsPrinter, getV0());
+// ATTR:        odsPrinter << ", ";
+// ATTR:        odsPrinter << "v1 = ";
+// ATTR:        printNestedCustom(odsPrinter,
+// ATTR-NEXT:     getV1());
+// ATTR:        if (!(getV2() == AttrParamB())) {
+// ATTR:          odsPrinter << "v2 = ";
+// ATTR:          odsPrinter.printStrippedAttrOrType(getV2());
+// ATTR:      }
+
+def AttrT : TestAttr<"TestT"> {
+  let parameters = (ins
+      AttrParamA:$v0,
+      AttrParamB:$v1,
+      OptionalParameter<"AttrParamB">:$v2
+  );
+
+  let mnemonic = "attr_t";
+  let assemblyFormat = "`{` struct($v0, custom<NestedCustom>($v1), $v2) `}`";
+}
+
 // DEFAULT_TYPE_PARSER: TestDialect::parseType(::mlir::DialectAsmParser &parser)
 // DEFAULT_TYPE_PARSER: auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType);
 // DEFAULT_TYPE_PARSER: if (parseResult.has_value()) {
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index a4ae271edb6bd..675dc5e8e551f 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -13,6 +13,7 @@
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -75,37 +76,33 @@ class ParameterElement
   AttrOrTypeParameter param;
 };
 
+/// Utility to return a parameter element for the provided `struct` format
+/// element. This parameter can originate from either a `ParameterElement` or a
+/// `CustomDirective` with a single parameter argument.
+static ParameterElement *getStructParameterElement(FormatElement *el) {
+  return TypeSwitch<FormatElement *, ParameterElement *>(el)
+      .Case<ParameterElement>([&](auto param) { return param; })
+      .Case<CustomDirective>([&](auto custom) {
+        FailureOr<ParameterElement *> maybeParam =
+            custom->template getFrontAs<ParameterElement>();
+        return *maybeParam;
+      })
+      .Default([&](auto el) {
+        assert(false && "unexpected struct element type");
+        return nullptr;
+      });
+}
+
 /// Shorthand functions that can be used with ranged-based conditions.
 static bool paramIsOptional(ParameterElement *el) { return el->isOptional(); }
+static bool formatIsOptional(FormatElement *el) {
+  ParameterElement *param = getStructParameterElement(el);
+  return param != nullptr && param->isOptional();
+}
 static bool paramNotOptional(ParameterElement *el) { return !el->isOptional(); }
-
-/// Base class for a directive that contains references to multiple variables.
-template <DirectiveElement::Kind DirectiveKind>
-class ParamsDirectiveBase : public DirectiveElementBase<DirectiveKind> {
-public:
-  using Base = ParamsDirectiveBase<DirectiveKind>;
-
-  ParamsDirectiveBase(std::vector<ParameterElement *> &&params)
-      : params(std::move(params)) {}
-
-  /// Get the parameters contained in this directive.
-  ArrayRef<ParameterElement *> getParams() const { return params; }
-
-  /// Get the number of parameters.
-  unsigned getNumParams() const { return params.size(); }
-
-  /// Take all of the parameters from this directive.
-  std::vector<ParameterElement *> takeParams() { return std::move(params); }
-
-  /// Returns true if there are optional parameters present.
-  bool hasOptionalParams() const {
-    return llvm::any_of(getParams(), paramIsOptional);
-  }
-
-private:
-  /// The parameters captured by this directive.
-  std::vector<ParameterElement *> params;
-};
+static bool formatNotOptional(FormatElement *el) {
+  return !formatIsOptional(el);
+}
 
 /// This class represents a `params` directive that refers to all parameters
 /// of an attribute or type. When used as a top-level directive, it generates
@@ -116,9 +113,15 @@ class ParamsDirectiveBase : public DirectiveElementBase<DirectiveKind> {
 /// When used as an argument to another directive that accepts variables,
 /// `params` can be used in place of manually listing all parameters of an
 /// attribute or type.
-class ParamsDirective : public ParamsDirectiveBase<DirectiveElement::Params> {
+class ParamsDirective
+    : public VectorDirectiveBase<DirectiveElement::Params, ParameterElement *> {
 public:
   using Base::Base;
+
+  /// Returns true if there are optional parameters present.
+  bool hasOptionalElements() const {
+    return llvm::any_of(getElements(), paramIsOptional);
+  }
 };
 
 /// This class represents a `struct` directive that generates a struct format
@@ -126,9 +129,15 @@ class ParamsDirective : public ParamsDirectiveBase<DirectiveElement::Params> {
 ///
 ///   `{` param-name `=` param-value (`,` param-name `=` param-value)* `}`
 ///
-class StructDirective : public ParamsDirectiveBase<DirectiveElement::Struct> {
+class StructDirective
+    : public VectorDirectiveBase<DirectiveElement::Struct, FormatElement *> {
 public:
   using Base::Base;
+
+  /// Returns true if there are optional format elements present.
+  bool hasOptionalElements() const {
+    return llvm::any_of(getElements(), formatIsOptional);
+  }
 };
 
 } // namespace
@@ -214,10 +223,10 @@ class DefFormat {
   /// Generate the printer code for a variable.
   void genVariablePrinter(ParameterElement *el, FmtContext &ctx, MethodBody &os,
                           bool skipGuard = false);
-  /// Generate a printer for comma-separated parameters.
-  void genCommaSeparatedPrinter(ArrayRef<ParameterElement *> params,
+  /// Generate a printer for comma-separated format elements.
+  void genCommaSeparatedPrinter(ArrayRef<FormatElement *> params,
                                 FmtContext &ctx, MethodBody &os,
-                                function_ref<void(ParameterElement *)> extra);
+                                function_ref<void(FormatElement *)> extra);
   /// Generate the printer code for a `params` directive.
   void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
   /// Generate the printer code for a `struct` directive.
@@ -443,14 +452,14 @@ void DefFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,
 
   // If there are optional parameters, we need to switch to `parseOptionalComma`
   // if there are no more required parameters after a certain point.
-  bool hasOptional = el->hasOptionalParams();
+  bool hasOptional = el->hasOptionalElements();
   if (hasOptional) {
     // Wrap everything in a do-while so that we can `break`.
     os << "do {\n";
     os.indent();
   }
 
-  ArrayRef<ParameterElement *> params = el->getParams();
+  ArrayRef<ParameterElement *> params = el->getElements();
   using IteratorT = ParameterElement *const *;
   IteratorT it = params.begin();
 
@@ -551,22 +560,31 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
     while (!$_parser.parseOptionalComma()) {
 )";
 
+  const char *const checkParamKey = R"(
+  if (!_seen_{0} && _paramKey == "{0}") {
+    _seen_{0} = true;
+)";
+
   os << "// Parse parameter struct\n";
 
   // Declare a "seen" variable for each key.
-  for (ParameterElement *param : el->getParams())
+  for (FormatElement *arg : el->getElements()) {
+    ParameterElement *param = getStructParameterElement(arg);
     os << formatv("bool _seen_{0} = false;\n", param->getName());
+  }
 
   // Generate the body of the parsing loop inside a lambda.
   os << "{\n";
   os.indent()
       << "const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {\n";
   genLiteralParser("=", ctx, os.indent());
-  for (ParameterElement *param : el->getParams()) {
-    os << formatv("if (!_seen_{0} && _paramKey == \"{0}\") {\n"
-                  "  _seen_{0} = true;\n",
-                  param->getName());
-    genVariableParser(param, ctx, os.indent());
+  for (FormatElement *arg : el->getElements()) {
+    ParameterElement *param = getStructParameterElement(arg);
+    os.getStream().printReindented(strfmt(checkParamKey, param->getName()));
+    if (auto *realParam = dyn_cast<ParameterElement>(arg))
+      genVariableParser(realParam, ctx, os.indent());
+    else if (auto *custom = dyn_cast<CustomDirective>(arg))
+      genCustomParser(custom, ctx, os.indent());
     os.unindent() << "} else ";
     // Print the check for duplicate or unknown parameter.
   }
@@ -576,10 +594,10 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
 
   // Generate the parsing loop. If optional parameters are present, then the
   // parse loop is guarded by commas.
-  unsigned numOptional = llvm::count_if(el->getParams(), paramIsOptional);
+  unsigned numOptional = llvm::count_if(el->getElements(), formatIsOptional);
   if (numOptional) {
     // If the struct itself is optional, pull out the first iteration.
-    if (numOptional == el->getNumParams()) {
+    if (numOptional == el->getNumElements()) {
       os.getStream().printReindented(tgfmt(optionalStructFirst, &ctx).str());
       os.indent();
     } else {
@@ -587,7 +605,7 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
     }
   } else {
     os.getStream().printReindented(
-        tgfmt(loopHeader, &ctx, el->getNumParams()).str());
+        tgfmt(loopHeader, &ctx, el->getNumElements()).str());
   }
   os.indent();
   os.getStream().printReindented(tgfmt(loopStart, &ctx).str());
@@ -597,12 +615,13 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
   // all mandatory parameters have been parsed.
   // The whole struct is optional if all its parameters are optional.
   if (numOptional) {
-    if (numOptional == el->getNumParams()) {
+    if (numOptional == el->getNumElements()) {
       os << "}\n";
       os.unindent() << "}\n";
     } else {
       os << tgfmt("} while(!$_parser.parseOptionalComma());\n", &ctx);
-      for (ParameterElement *param : el->getParams()) {
+      for (FormatElement *arg : el->getElements()) {
+        ParameterElement *param = getStructParameterElement(arg);
         if (param->isOptional())
           continue;
         os.getStream().printReindented(
@@ -614,7 +633,8 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
     // N flags, successfully exiting the loop means that all parameters have
     // been seen. `parseOptionalComma` would cause issues with any formats that
     // use "struct(...) `,`" beacuse structs aren't sounded by braces.
-    os.getStream().printReindented(strfmt(loopTerminator, el->getNumParams()));
+    os.getStream().printReindented(
+        strfmt(loopTerminator, el->getNumElements()));
   }
   os.unindent() << "}\n";
 }
@@ -631,7 +651,7 @@ void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
   os << "(void)odsCustomLoc;\n";
   os << tgfmt("auto odsCustomResult = parse$0($_parser", &ctx, el->getName());
   os.indent();
-  for (FormatElement *arg : el->getArguments()) {
+  for (FormatElement *arg : el->getElements()) {
     os << ",\n";
     if (auto *param = dyn_cast<ParameterElement>(arg))
       os << "::mlir::detail::unwrapForCustomParse(_result_" << param->getName()
@@ -648,7 +668,7 @@ void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
   } else {
     os << "if (::mlir::failed(odsCustomResult)) return {};\n";
   }
-  for (FormatElement *arg : el->getArguments()) {
+  for (FormatElement *arg : el->getElements()) {
     if (auto *param = dyn_cast<ParameterElement>(arg)) {
       if (param->isOptional())
         continue;
@@ -689,7 +709,7 @@ void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
     guardOn(llvm::ArrayRef(param));
   } else if (auto *params = dyn_cast<ParamsDirective>(first)) {
     genParamsParser(params, ctx, os);
-    guardOn(params->getParams());
+    guardOn(params->getElements());
   } else if (auto *custom = dyn_cast<CustomDirective>(first)) {
     os << "if (auto result = [&]() -> ::mlir::OptionalParseResult {\n";
     os.indent();
@@ -704,7 +724,10 @@ void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
   } else {
     auto *strct = cast<StructDirective>(first);
     genStructParser(strct, ctx, os);
-    guardOn(params->getParams());
+    // Guard on the struct's own parameters. `params` is the failed `dyn_cast`
+    // from the branch above and is null here, so it must not be used.
+    guardOn(llvm::map_to_vector(strct->getElements(),
+                                getStructParameterElement));
   }
   os.indent();
 
@@ -816,14 +836,26 @@ static void guardOnAny(FmtContext &ctx, MethodBody &os, ParameterRange &&params,
   os.indent();
 }
 
+/// Generate code to guard printing on the parameters bound by `args`. Elements
+/// that bind no parameter (e.g. `ref(...)` arguments of `custom`) are skipped.
+template <typename FormatElemRange>
+static void guardOnAnyOptional(FmtContext &ctx, MethodBody &os,
+                               FormatElemRange &&args, bool inverted = false) {
+  auto bound = llvm::make_filter_range(args, [](FormatElement *el) {
+    return isa<ParameterElement, CustomDirective>(el);
+  });
+  guardOnAny(ctx, os, llvm::map_range(bound, getStructParameterElement),
+             inverted);
+}
+
 void DefFormat::genCommaSeparatedPrinter(
-    ArrayRef<ParameterElement *> params, FmtContext &ctx, MethodBody &os,
-    function_ref<void(ParameterElement *)> extra) {
+    ArrayRef<FormatElement *> args, FmtContext &ctx, MethodBody &os,
+    function_ref<void(FormatElement *)> extra) {
   // Emit a space if necessary, but only if the struct is present.
   if (shouldEmitSpace || !lastWasPunctuation) {
-    bool allOptional = llvm::all_of(params, paramIsOptional);
+    bool allOptional = llvm::all_of(args, formatIsOptional);
     if (allOptional)
-      guardOnAny(ctx, os, params);
+      guardOnAnyOptional(ctx, os, args);
     os << tgfmt("$_printer << ' ';\n", &ctx);
     if (allOptional)
       os.unindent() << "}\n";
@@ -832,17 +864,21 @@ void DefFormat::genCommaSeparatedPrinter(
   // The first printed element does not need to emit a comma.
   os << "{\n";
   os.indent() << "bool _firstPrinted = true;\n";
-  for (ParameterElement *param : params) {
+  for (FormatElement *arg : args) {
+    ParameterElement *param = getStructParameterElement(arg);
     if (param->isOptional()) {
       param->genPrintGuard(ctx, os << "if (") << ") {\n";
       os.indent();
     }
     os << tgfmt("if (!_firstPrinted) $_printer << \", \";\n", &ctx);
     os << "_firstPrinted = false;\n";
-    extra(param);
+    extra(arg);
     shouldEmitSpace = false;
     lastWasPunctuation = true;
-    genVariablePrinter(param, ctx, os);
+    if (auto realParam = dyn_cast<ParameterElement>(arg))
+      genVariablePrinter(realParam, ctx, os);
+    else if (auto custom = dyn_cast<CustomDirective>(arg))
+      genCustomPrinter(custom, ctx, os);
     if (param->isOptional())
       os.unindent() << "}\n";
   }
@@ -851,16 +887,19 @@ void DefFormat::genCommaSeparatedPrinter(
 
 void DefFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
                                  MethodBody &os) {
-  genCommaSeparatedPrinter(llvm::to_vector(el->getParams()), ctx, os,
-                           [&](ParameterElement *param) {});
+  // `genCommaSeparatedPrinter` takes generic format elements, so widen the
+  // parameter pointers. `params` prints no per-element prefix.
+  SmallVector<FormatElement *> args =
+      llvm::to_vector_of<FormatElement *>(el->getElements());
+  genCommaSeparatedPrinter(args, ctx, os, [](FormatElement *) {});
 }
 
 void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
                                  MethodBody &os) {
-  genCommaSeparatedPrinter(
-      llvm::to_vector(el->getParams()), ctx, os, [&](ParameterElement *param) {
-        os << tgfmt("$_printer << \"$0 = \";\n", &ctx, param->getName());
-      });
+  genCommaSeparatedPrinter(el->getElements(), ctx, os, [&](FormatElement *arg) {
+    ParameterElement *param = getStructParameterElement(arg);
+    os << tgfmt("$_printer << \"$0 = \";\n", &ctx, param->getName());
+  });
 }
 
 void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
@@ -873,7 +912,7 @@ void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
 
   os << tgfmt("print$0($_printer", &ctx, el->getName());
   os.indent();
-  for (FormatElement *arg : el->getArguments()) {
+  for (FormatElement *arg : el->getElements()) {
     os << ",\n";
     if (auto *param = dyn_cast<ParameterElement>(arg)) {
       os << param->getParam().getAccessorName() << "()";
@@ -893,19 +932,12 @@ void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
   if (auto *param = dyn_cast<ParameterElement>(anchor)) {
     guardOnAny(ctx, os, llvm::ArrayRef(param), el->isInverted());
   } else if (auto *params = dyn_cast<ParamsDirective>(anchor)) {
-    guardOnAny(ctx, os, params->getParams(), el->isInverted());
+    guardOnAny(ctx, os, params->getElements(), el->isInverted());
   } else if (auto *strct = dyn_cast<StructDirective>(anchor)) {
-    guardOnAny(ctx, os, strct->getParams(), el->isInverted());
+    guardOnAnyOptional(ctx, os, strct->getElements(), el->isInverted());
   } else {
     auto *custom = cast<CustomDirective>(anchor);
-    guardOnAny(ctx, os,
-               llvm::make_filter_range(
-                   llvm::map_range(custom->getArguments(),
-                                   [](FormatElement *el) {
-                                     return dyn_cast<ParameterElement>(el);
-                                   }),
-                   [](ParameterElement *param) { return !!param; }),
-               el->isInverted());
+    guardOnAnyOptional(ctx, os, custom->getElements(), el->isInverted());
   }
   // Generate the printer for the contained elements.
   {
@@ -960,6 +992,9 @@ class DefFormatParser : public FormatParser {
   LogicalResult verifyOptionalGroupElements(SMLoc loc,
                                             ArrayRef<FormatElement *> elements,
                                             FormatElement *anchor) override;
+  /// Verify the arguments to a struct directive.
+  LogicalResult verifyStructArguments(SMLoc loc,
+                                      ArrayRef<FormatElement *> arguments);
 
   LogicalResult markQualified(SMLoc loc, FormatElement *element) override;
 
@@ -1010,7 +1045,7 @@ LogicalResult DefFormatParser::verify(SMLoc loc,
     auto *literalEl = dyn_cast<LiteralElement>(std::get<1>(it));
     if (!structEl || !literalEl)
       continue;
-    if (literalEl->getSpelling() == "," && structEl->hasOptionalParams()) {
+    if (literalEl->getSpelling() == "," && structEl->hasOptionalElements()) {
       return emitError(loc, "`struct` directive with optional parameters "
                             "cannot be followed by a comma literal");
     }
@@ -1037,17 +1072,17 @@ DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
                          "parameters in an optional group must be optional");
       }
     } else if (auto *params = dyn_cast<ParamsDirective>(el)) {
-      if (llvm::any_of(params->getParams(), paramNotOptional)) {
+      if (llvm::any_of(params->getElements(), paramNotOptional)) {
         return emitError(loc, "`params` directive allowed in optional group "
                               "only if all parameters are optional");
       }
     } else if (auto *strct = dyn_cast<StructDirective>(el)) {
-      if (llvm::any_of(strct->getParams(), paramNotOptional)) {
+      if (llvm::any_of(strct->getElements(), formatNotOptional)) {
         return emitError(loc, "`struct` is only allowed in an optional group "
                               "if all captured parameters are optional");
       }
     } else if (auto *custom = dyn_cast<CustomDirective>(el)) {
-      for (FormatElement *el : custom->getArguments()) {
+      for (FormatElement *el : custom->getElements()) {
         // If the custom argument is a variable, then it must be optional.
         if (auto *param = dyn_cast<ParameterElement>(el))
           if (!param->isOptional())
@@ -1068,10 +1103,10 @@ DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
     // arguments is a bound parameter.
     if (auto *custom = dyn_cast<CustomDirective>(anchor)) {
       const auto *bound =
-          llvm::find_if(custom->getArguments(), [](FormatElement *el) {
+          llvm::find_if(custom->getElements(), [](FormatElement *el) {
             return isa<ParameterElement>(el);
           });
-      if (bound == custom->getArguments().end())
+      if (bound == custom->getElements().end())
         return emitError(loc, "`custom` directive with no bound parameters "
                               "cannot be used as optional group anchor");
     }
@@ -1079,6 +1114,28 @@ DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
   return success();
 }
 
+LogicalResult
+DefFormatParser::verifyStructArguments(SMLoc loc,
+                                       ArrayRef<FormatElement *> arguments) {
+  for (FormatElement *el : arguments) {
+    if (!isa<VariableElement, CustomDirective, ParamsDirective>(el)) {
+      return emitError(loc, "expected a variable, custom directive or params "
+                            "directive in `struct` arguments list");
+    }
+    if (auto *custom = dyn_cast<CustomDirective>(el)) {
+      if (custom->getNumElements() != 1) {
+        return emitError(loc, "`struct` can only contain `custom` directives "
+                              "with a single argument");
+      }
+      if (failed(custom->getFrontAs<ParameterElement>())) {
+        return emitError(loc, "a `custom` directive nested within a `struct` "
+                              "must be passed a parameter");
+      }
+    }
+  }
+  return success();
+}
+
 LogicalResult DefFormatParser::markQualified(SMLoc loc,
                                              FormatElement *element) {
   if (!isa<ParameterElement>(element))
@@ -1172,37 +1229,45 @@ FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc,
     return emitError(loc, "`struct` can only be used at the top-level context");
 
   if (failed(parseToken(FormatToken::l_paren,
-                        "expected '(' before `struct` argument list")))
+                        "expected '(' before `struct` argument list"))) {
     return failure();
+  }
 
   // Parse variables captured by `struct`.
-  std::vector<ParameterElement *> vars;
+  std::vector<FormatElement *> vars;
 
   // Parse first captured parameter or a `params` directive.
   FailureOr<FormatElement *> var = parseElement(StructDirectiveContext);
-  if (failed(var) || !isa<VariableElement, ParamsDirective>(*var)) {
+  if (failed(var) ||
+      !isa<VariableElement, ParamsDirective, CustomDirective>(*var)) {
     return emitError(loc,
                      "`struct` argument list expected a variable or directive");
   }
-  if (isa<VariableElement>(*var)) {
+  if (isa<VariableElement, CustomDirective>(*var)) {
     // Parse any other parameters.
-    vars.push_back(cast<ParameterElement>(*var));
+    vars.push_back(*var);
     while (peekToken().is(FormatToken::comma)) {
       consumeToken();
       var = parseElement(StructDirectiveContext);
-      if (failed(var) || !isa<VariableElement>(*var))
-        return emitError(loc, "expected a variable in `struct` argument list");
-      vars.push_back(cast<ParameterElement>(*var));
+      if (failed(var) || !isa<VariableElement, CustomDirective>(*var))
+        return emitError(loc, "expected a variable or `custom` directive in "
+                              "`struct` argument list");
+      vars.push_back(*var);
     }
   } else {
     // `struct(params)` captures all parameters in the attribute or type.
-    vars = cast<ParamsDirective>(*var)->takeParams();
+    ParamsDirective *params = cast<ParamsDirective>(*var);
+    vars.reserve(params->getNumElements());
+    for (ParameterElement *el : params->takeElements())
+      vars.push_back(cast<FormatElement>(el));
   }
 
   if (failed(parseToken(FormatToken::r_paren,
-                        "expected ')' at the end of an argument list")))
+                        "expected ')' at the end of an argument list"))) {
+    return failure();
+  }
+  if (failed(verifyStructArguments(loc, vars)))
     return failure();
-
   return create<StructDirective>(std::move(vars));
 }
 
diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp
index dd9b41bc90aef..4dfdde2146679 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/FormatGen.cpp
@@ -400,8 +400,10 @@ FailureOr<FormatElement *> FormatParser::parseOptionalGroup(Context ctx) {
 
 FailureOr<FormatElement *> FormatParser::parseCustomDirective(SMLoc loc,
                                                               Context ctx) {
-  if (ctx != TopLevelContext)
-    return emitError(loc, "'custom' is only valid as a top-level directive");
+  if (ctx != TopLevelContext && ctx != StructDirectiveContext) {
+    return emitError(loc, "`custom` can only be used at the top-level context "
+                          "or within a `struct` directive");
+  }
 
   FailureOr<FormatToken> nameTok;
   if (failed(parseToken(FormatToken::less,
diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h
index 1dc2cb3eaa88a..8e7d49bb37e71 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.h
+++ b/mlir/tools/mlir-tblgen/FormatGen.h
@@ -338,29 +338,56 @@ class DirectiveElementBase : public DirectiveElement {
   }
 };
 
+/// Base class for a directive that stores its elements of type `T` in a vector
+/// (e.g. `FormatElement *` for `custom`, `ParameterElement *` for `params`).
+template <DirectiveElement::Kind DirectiveKind, typename T>
+class VectorDirectiveBase : public DirectiveElementBase<DirectiveKind> {
+public:
+  using Base = VectorDirectiveBase<DirectiveKind, T>;
+
+  VectorDirectiveBase(std::vector<T> &&elems) : elems(std::move(elems)) {}
+
+  /// Get a read-only view of the elements contained in this directive.
+  ArrayRef<T> getElements() const { return elems; }
+
+  /// Get the number of elements contained in this directive.
+  unsigned getNumElements() const { return elems.size(); }
+
+  /// Move the elements out of this directive (`elems` is left moved-from).
+  std::vector<T> takeElements() { return std::move(elems); }
+
+protected:
+  /// The elements captured by this directive.
+  std::vector<T> elems;
+};
+
 /// This class represents a custom format directive that is implemented by the
 /// user in C++. The directive accepts a list of arguments that is passed to the
 /// C++ function.
-class CustomDirective : public DirectiveElementBase<DirectiveElement::Custom> {
+class CustomDirective
+    : public VectorDirectiveBase<DirectiveElement::Custom, FormatElement *> {
 public:
+  using Base::Base; // NOTE(review): inherited ctor leaves `name` empty.
   /// Create a custom directive with a name and list of arguments.
   CustomDirective(StringRef name, std::vector<FormatElement *> &&arguments)
-      : name(name), arguments(std::move(arguments)) {}
+      : Base(std::move(arguments)), name(name) {}

   /// Get the custom directive name.
   StringRef getName() const { return name; }

-  /// Get the arguments to the custom directive.
-  ArrayRef<FormatElement *> getArguments() const { return arguments; }
+  template <typename T>
+  FailureOr<T *> getFrontAs() const { // Sole argument as `T`, else failure.
+    if (getNumElements() != 1)
+      return failure();
+    if (T *elem = dyn_cast<T>(getElements()[0]))
+      return elem;
+    return failure();
+  }

 private:
   /// The name of the custom directive. The name is used to call two C++
   /// methods: `parse{name}` and `print{name}` with the given arguments.
   StringRef name;
-  /// The arguments with which to call the custom functions. These are either
-  /// variables (for which the functions are responsible for populating) or
-  /// references to variables.
-  std::vector<FormatElement *> arguments;
 };
 
 /// This class represents a reference directive. This directive can be used to
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index ca2c1d4a8ad04..a0d947fe8a0df 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -882,7 +882,7 @@ static void genElementParserStorage(FormatElement *element, const Operator &op,
     }
 
   } else if (auto *custom = dyn_cast<CustomDirective>(element)) {
-    for (FormatElement *paramElement : custom->getArguments())
+    for (FormatElement *paramElement : custom->getElements())
       genElementParserStorage(paramElement, op, body);
 
   } else if (isa<OperandsDirective>(element)) {
@@ -1037,7 +1037,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
   // * Add a local variable for optional operands and types. This provides a
   //   better API to the user defined parser methods.
   // * Set the location of operand variables.
-  for (FormatElement *param : dir->getArguments()) {
+  for (FormatElement *param : dir->getElements()) {
     if (auto *operand = dyn_cast<OperandVariable>(param)) {
       auto *var = operand->getVar();
       body << "    " << var->name
@@ -1089,7 +1089,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
   }
 
   body << "    auto odsResult = parse" << dir->getName() << "(parser";
-  for (FormatElement *param : dir->getArguments()) {
+  for (FormatElement *param : dir->getElements()) {
     body << ", ";
     genCustomParameterParser(param, body);
   }
@@ -1103,7 +1103,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
   }
 
   // After parsing, add handling for any of the optional constructs.
-  for (FormatElement *param : dir->getArguments()) {
+  for (FormatElement *param : dir->getElements()) {
     if (auto *attr = dyn_cast<AttributeVariable>(param)) {
       const NamedAttribute *var = attr->getVar();
       if (var->attr.isOptional() || var->attr.hasDefaultValue())
@@ -2215,7 +2215,7 @@ static void genCustomDirectiveParameterPrinter(FormatElement *element,
 static void genCustomDirectivePrinter(CustomDirective *customDir,
                                       const Operator &op, MethodBody &body) {
   body << "  print" << customDir->getName() << "(_odsPrinter, *this";
-  for (FormatElement *param : customDir->getArguments()) {
+  for (FormatElement *param : customDir->getElements()) {
     body << ", ";
     genCustomDirectiveParameterPrinter(param, op, body);
   }
@@ -2359,7 +2359,7 @@ static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
       .Case([&](CustomDirective *ele) {
         body << '(';
         llvm::interleave(
-            ele->getArguments(), body,
+            ele->getElements(), body,
             [&](FormatElement *child) {
               body << '(';
               genOptionalGroupPrinterAnchor(child, op, body);
@@ -2375,7 +2375,7 @@ void collect(FormatElement *element,
   TypeSwitch<FormatElement *>(element)
       .Case([&](VariableElement *var) { variables.emplace_back(var); })
       .Case([&](CustomDirective *ele) {
-        for (FormatElement *arg : ele->getArguments())
+        for (FormatElement *arg : ele->getElements())
           collect(arg, variables);
       })
       .Case([&](OptionalElement *ele) {
@@ -3774,7 +3774,7 @@ LogicalResult OpFormatParser::verifyOptionalGroupElement(SMLoc loc,
           return success();
         // Verify each child as being valid in an optional group. They are all
         // potential anchors if the custom directive was marked as one.
-        for (FormatElement *child : ele->getArguments()) {
+        for (FormatElement *child : ele->getElements()) {
           if (isa<RefDirective>(child))
             continue;
           if (failed(verifyOptionalGroupElement(loc, child, /*isAnchor=*/true)))



More information about the Mlir-commits mailing list