[Mlir-commits] [mlir] [mlir][tblgen] Add custom parsing and printing within struct (PR #133939)
Jorn Tuyls
llvmlistbot at llvm.org
Wed Apr 9 05:38:50 PDT 2025
https://github.com/jtuyls updated https://github.com/llvm/llvm-project/pull/133939
>From 57da6a99e618bf15f20f0952e21f43ed48826218 Mon Sep 17 00:00:00 2001
From: Jorn Tuyls <jorn.tuyls at gmail.com>
Date: Tue, 1 Apr 2025 06:52:34 -0500
Subject: [PATCH] [mlir][tblgen] Add `custom` parsing and printing within
`struct`
---
.../test/IR/custom-struct-attr-roundtrip.mlir | 66 +++++
mlir/test/lib/Dialect/Test/TestAttrDefs.td | 10 +
mlir/test/lib/Dialect/Test/TestAttributes.cpp | 21 ++
.../attr-or-type-format-invalid.td | 25 +-
mlir/test/mlir-tblgen/attr-or-type-format.td | 78 ++++++
.../tools/mlir-tblgen/AttrOrTypeFormatGen.cpp | 253 +++++++++++-------
mlir/tools/mlir-tblgen/FormatGen.cpp | 6 +-
mlir/tools/mlir-tblgen/FormatGen.h | 43 ++-
mlir/tools/mlir-tblgen/OpFormatGen.cpp | 16 +-
9 files changed, 404 insertions(+), 114 deletions(-)
create mode 100644 mlir/test/IR/custom-struct-attr-roundtrip.mlir
diff --git a/mlir/test/IR/custom-struct-attr-roundtrip.mlir b/mlir/test/IR/custom-struct-attr-roundtrip.mlir
new file mode 100644
index 0000000000000..4d07e896f5b1d
--- /dev/null
+++ b/mlir/test/IR/custom-struct-attr-roundtrip.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: @test_struct_attr_roundtrip
+func.func @test_struct_attr_roundtrip() -> () {
+ // CHECK: attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>
+ "test.op"() {attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>} : () -> ()
+ // CHECK: attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>
+ "test.op"() {attr = #test.custom_struct<value = 2, type_str = "struct", opt_value = [3, 3]>} : () -> ()
+ // CHECK: attr = #test.custom_struct<type_str = "struct", value = 2>
+ "test.op"() {attr = #test.custom_struct<type_str = "struct", value = 2>} : () -> ()
+ // CHECK: attr = #test.custom_struct<type_str = "struct", value = 2>
+ "test.op"() {attr = #test.custom_struct<value = 2, type_str = "struct">} : () -> ()
+ return
+}
+
+// -----
+
+// Verify all required parameters must be provided. `value` is missing.
+
+// expected-error @below {{struct is missing required parameter: value}}
+"test.op"() {attr = #test.custom_struct<type_str = "struct">} : () -> ()
+
+// -----
+
+// Verify all keywords must be provided. All missing.
+
+// expected-error @below {{expected valid keyword}}
+// expected-error @below {{expected a parameter name in struct}}
+"test.op"() {attr = #test.custom_struct<"struct", 2>} : () -> ()
+
+// -----
+
+// Verify all keywords must be provided. `type_str` missing.
+
+// expected-error @below {{expected valid keyword}}
+// expected-error @below {{expected a parameter name in struct}}
+"test.op"() {attr = #test.custom_struct<"struct", value = 2, opt_value = [3, 3]>} : () -> ()
+
+// -----
+
+// Verify all keywords must be provided. `value` missing.
+
+// expected-error @below {{expected valid keyword}}
+// expected-error @below {{expected a parameter name in struct}}
+"test.op"() {attr = #test.custom_struct<type_str = "struct", 2>} : () -> ()
+
+// -----
+
+// Verify invalid keyword provided.
+
+// expected-error @below {{duplicate or unknown struct parameter name: type_str2}}
+"test.op"() {attr = #test.custom_struct<type_str2 = "struct", value = 2>} : () -> ()
+
+// -----
+
+// Verify duplicated keyword provided.
+
+// expected-error @below {{duplicate or unknown struct parameter name: type_str}}
+"test.op"() {attr = #test.custom_struct<type_str = "struct", type_str = "struct2", value = 2>} : () -> ()
+
+// -----
+
+// Verify equals missing.
+
+// expected-error @below {{expected '='}}
+"test.op"() {attr = #test.custom_struct<type_str "struct", value = 2>} : () -> ()
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index fc2d77af29f12..6441a82d87eba 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -369,6 +369,16 @@ def TestCustomFloatAttr : Test_Attr<"TestCustomFloat"> {
}];
}
+// Test `struct` with nested `custom` assembly format.
+def TestCustomStructAttr : Test_Attr<"TestCustomStruct"> {
+ let mnemonic = "custom_struct";
+ let parameters = (ins "mlir::StringAttr":$type_str, "int64_t":$value,
+ OptionalParameter<"mlir::ArrayAttr">:$opt_value);
+ let assemblyFormat = [{
+ `<` struct($type_str, custom<CustomStructAttr>($value), custom<CustomOptStructFieldAttr>($opt_value)) `>`
+ }];
+}
+
def NestedPolynomialAttr : Test_Attr<"NestedPolynomialAttr"> {
let mnemonic = "nested_polynomial";
let parameters = (ins Polynomial_IntPolynomialAttr:$poly);
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 057d9fb4a215f..988c6ecc2afb6 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -316,6 +316,27 @@ static ParseResult parseCustomFloatAttr(AsmParser &p, StringAttr &typeStrAttr,
return success();
}
+//===----------------------------------------------------------------------===//
+// TestCustomStructAttr
+//===----------------------------------------------------------------------===//
+
+static void printCustomStructAttr(AsmPrinter &p, int64_t value) {
+ p.printStrippedAttrOrType(value);
+}
+
+static ParseResult parseCustomStructAttr(AsmParser &p, int64_t &value) {
+ return p.parseInteger(value);
+}
+
+static void printCustomOptStructFieldAttr(AsmPrinter &p, ArrayAttr attr) {
+ p.printStrippedAttrOrType(attr);
+}
+
+static ParseResult parseCustomOptStructFieldAttr(AsmParser &p,
+ ArrayAttr &attr) {
+ return p.parseAttribute(attr);
+}
+
//===----------------------------------------------------------------------===//
// TestOpAsmAttrInterfaceAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
index 3a57cbca4d7bb..9a521a5053c0f 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
@@ -37,14 +37,14 @@ def InvalidTypeC : InvalidType<"InvalidTypeC", "invalid_c"> {
def InvalidTypeD : InvalidType<"InvalidTypeD", "invalid_d"> {
let parameters = (ins "int":$v0);
// CHECK: literals may only be used in the top-level section of the format
- // CHECK: expected a variable in `struct` argument list
+ // CHECK: expected a parameter or `custom` directive in `struct` argument list
let assemblyFormat = "`<` struct($v0, `,`) `>`";
}
// Test struct directive cannot capture zero parameters.
def InvalidTypeE : InvalidType<"InvalidTypeE", "invalid_e"> {
let parameters = (ins "int":$v0);
- // CHECK: `struct` argument list expected a variable or directive
+ // CHECK: `struct` argument list expected a parameter or directive
let assemblyFormat = "`<` struct() $v0 `>`";
}
@@ -144,3 +144,24 @@ def InvalidTypeT : InvalidType<"InvalidTypeT", "invalid_t"> {
// CHECK: `custom` directive with no bound parameters cannot be used as optional group anchor
let assemblyFormat = "$a (`(` custom<Foo>(ref($a))^ `)`)?";
}
+
+// Test `struct` with nested `custom` directive with multiple fields.
+def InvalidTypeU : InvalidType<"InvalidTypeU", "invalid_u"> {
+ let parameters = (ins OptionalParameter<"int">:$a, "int":$b);
+ // CHECK: `struct` can only contain `custom` directives with a single argument
+ let assemblyFormat = "struct(custom<Foo>($a, $b))";
+}
+
+// Test `struct` with nested `custom` directive invalid parameter.
+def InvalidTypeV : InvalidType<"InvalidTypeV", "invalid_v"> {
+ let parameters = (ins OptionalParameter<"int">:$a);
+ // CHECK: a `custom` directive nested within a `struct` must be passed a parameter
+ let assemblyFormat = "struct($a, custom<Foo>(ref($a)))";
+}
+
+// Test `custom` directive cannot be nested within another `custom` directive.
+def InvalidTypeW : InvalidType<"InvalidTypeW", "invalid_w"> {
+  let parameters = (ins OptionalParameter<"int">:$a, "int":$b);
+  // CHECK: `custom` can only be used at the top-level context or within a `struct` directive
+  let assemblyFormat = "custom<Foo>($a, custom<Bar>($b))";
+}
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
index c5348409e8e44..0f6b0c401a4e6 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -736,6 +736,84 @@ def TypeS : TestType<"TestS"> {
let assemblyFormat = "$a";
}
+/// Test that a `struct` with nested `custom` parser and printer are generated correctly.
+
+// ATTR: ::mlir::Attribute TestTAttr::parse(::mlir::AsmParser &odsParser,
+// ATTR: ::mlir::Type odsType) {
+// ATTR: bool _seen_v0 = false;
+// ATTR: bool _seen_v1 = false;
+// ATTR: bool _seen_v2 = false;
+// ATTR: const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {
+// ATTR: if (odsParser.parseEqual())
+// ATTR: return {};
+// ATTR: if (!_seen_v0 && _paramKey == "v0") {
+// ATTR: _seen_v0 = true;
+// ATTR: _result_v0 = ::parseAttrParamA(odsParser, odsType);
+// ATTR: if (::mlir::failed(_result_v0))
+// ATTR: return {};
+// ATTR: } else if (!_seen_v1 && _paramKey == "v1") {
+// ATTR: _seen_v1 = true;
+// ATTR: {
+// ATTR: auto odsCustomResult = parseNestedCustom(odsParser,
+// ATTR-NEXT: ::mlir::detail::unwrapForCustomParse(_result_v1));
+// ATTR: if (::mlir::failed(odsCustomResult)) return {};
+// ATTR: if (::mlir::failed(_result_v1)) {
+// ATTR: odsParser.emitError(odsCustomLoc, "custom parser failed to parse parameter 'v1'");
+// ATTR: return {};
+// ATTR: }
+// ATTR: }
+// ATTR: } else if (!_seen_v2 && _paramKey == "v2") {
+// ATTR: _seen_v2 = true;
+// ATTR: _result_v2 = ::mlir::FieldParser<AttrParamB>::parse(odsParser);
+// ATTR: if (::mlir::failed(_result_v2)) {
+// ATTR: odsParser.emitError(odsParser.getCurrentLocation(), "failed to parse AttrT parameter 'v2' which is to be a `AttrParamB`");
+// ATTR: return {};
+// ATTR: }
+// ATTR: } else {
+// ATTR: return {};
+// ATTR: }
+// ATTR: return true;
+// ATTR: }
+// ATTR: do {
+// ATTR: ::llvm::StringRef _paramKey;
+// ATTR: if (odsParser.parseKeyword(&_paramKey)) {
+// ATTR: odsParser.emitError(odsParser.getCurrentLocation(),
+// ATTR-NEXT: "expected a parameter name in struct");
+// ATTR: return {};
+// ATTR: }
+// ATTR: if (!_loop_body(_paramKey)) return {};
+// ATTR: } while(!odsParser.parseOptionalComma());
+// ATTR: if (!_seen_v0)
+// ATTR: if (!_seen_v1)
+// ATTR: return TestTAttr::get(odsParser.getContext(),
+// ATTR: TestParamA((*_result_v0)),
+// ATTR: TestParamB((*_result_v1)),
+// ATTR: AttrParamB((_result_v2.value_or(AttrParamB()))));
+// ATTR: }
+
+// ATTR: void TestTAttr::print(::mlir::AsmPrinter &odsPrinter) const {
+// ATTR: odsPrinter << "v0 = ";
+// ATTR: ::printAttrParamA(odsPrinter, getV0());
+// ATTR: odsPrinter << ", ";
+// ATTR: odsPrinter << "v1 = ";
+// ATTR: printNestedCustom(odsPrinter,
+// ATTR-NEXT: getV1());
+// ATTR: if (!(getV2() == AttrParamB())) {
+// ATTR: odsPrinter << "v2 = ";
+// ATTR: odsPrinter.printStrippedAttrOrType(getV2());
+// ATTR: }
+
+def AttrT : TestAttr<"TestT"> {
+ let parameters = (ins
+ AttrParamA:$v0,
+ AttrParamB:$v1,
+ OptionalParameter<"AttrParamB">:$v2
+ );
+
+ let mnemonic = "attr_t";
+ let assemblyFormat = "`{` struct($v0, custom<NestedCustom>($v1), $v2) `}`";
+}
+
// DEFAULT_TYPE_PARSER: TestDialect::parseType(::mlir::DialectAsmParser &parser)
// DEFAULT_TYPE_PARSER: auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType);
// DEFAULT_TYPE_PARSER: if (parseResult.has_value()) {
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index a4ae271edb6bd..d768a9af9421e 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -13,6 +13,7 @@
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -75,37 +76,33 @@ class ParameterElement
AttrOrTypeParameter param;
};
+/// Utility to return a parameter element for the provided `struct` format
+/// element. This parameter can originate from either a `ParameterElement` or a
+/// `CustomDirective` whose first argument is a parameter. Any other element
+/// (e.g. a `ref` argument of a `custom` directive) yields null.
+static ParameterElement *getStructParameterElement(FormatElement *el) {
+  return TypeSwitch<FormatElement *, ParameterElement *>(el)
+      .Case<ParameterElement>([&](auto param) { return param; })
+      .Case<CustomDirective>([&](auto custom) -> ParameterElement * {
+        FailureOr<ParameterElement *> maybeParam =
+            custom->template getFrontAs<ParameterElement>();
+        // Do not dereference a failed result; callers filter out null.
+        return succeeded(maybeParam) ? *maybeParam : nullptr;
+      })
+      // Reached for `custom` arguments via `guardOnAnyOptional`; not an error.
+      .Default([&](auto) -> ParameterElement * { return nullptr; });
+}
+
/// Shorthand functions that can be used with ranged-based conditions.
static bool paramIsOptional(ParameterElement *el) { return el->isOptional(); }
+static bool formatIsOptional(FormatElement *el) {
+ ParameterElement *param = getStructParameterElement(el);
+ return param != nullptr && param->isOptional();
+}
static bool paramNotOptional(ParameterElement *el) { return !el->isOptional(); }
-
-/// Base class for a directive that contains references to multiple variables.
-template <DirectiveElement::Kind DirectiveKind>
-class ParamsDirectiveBase : public DirectiveElementBase<DirectiveKind> {
-public:
-  using Base = ParamsDirectiveBase<DirectiveKind>;
-
-  ParamsDirectiveBase(std::vector<ParameterElement *> &&params)
-      : params(std::move(params)) {}
-
-  /// Get the parameters contained in this directive.
-  ArrayRef<ParameterElement *> getParams() const { return params; }
-
-  /// Get the number of parameters.
-  unsigned getNumParams() const { return params.size(); }
-
-  /// Take all of the parameters from this directive.
-  std::vector<ParameterElement *> takeParams() { return std::move(params); }
-
-  /// Returns true if there are optional parameters present.
-  bool hasOptionalParams() const {
-    return llvm::any_of(getParams(), paramIsOptional);
-  }
-
-private:
-  /// The parameters captured by this directive.
-  std::vector<ParameterElement *> params;
-};
+static bool formatNotOptional(FormatElement *el) {
+ return !formatIsOptional(el);
+}
/// This class represents a `params` directive that refers to all parameters
/// of an attribute or type. When used as a top-level directive, it generates
@@ -116,9 +113,15 @@ class ParamsDirectiveBase : public DirectiveElementBase<DirectiveKind> {
/// When used as an argument to another directive that accepts variables,
/// `params` can be used in place of manually listing all parameters of an
/// attribute or type.
-class ParamsDirective : public ParamsDirectiveBase<DirectiveElement::Params> {
+class ParamsDirective
+ : public VectorDirectiveBase<DirectiveElement::Params, ParameterElement *> {
public:
using Base::Base;
+
+ /// Returns true if there are optional parameters present.
+ bool hasOptionalElements() const {
+ return llvm::any_of(getElements(), paramIsOptional);
+ }
};
/// This class represents a `struct` directive that generates a struct format
@@ -126,9 +129,15 @@ class ParamsDirective : public ParamsDirectiveBase<DirectiveElement::Params> {
///
/// `{` param-name `=` param-value (`,` param-name `=` param-value)* `}`
///
-class StructDirective : public ParamsDirectiveBase<DirectiveElement::Struct> {
+class StructDirective
+ : public VectorDirectiveBase<DirectiveElement::Struct, FormatElement *> {
public:
using Base::Base;
+
+ /// Returns true if there are optional format elements present.
+ bool hasOptionalElements() const {
+ return llvm::any_of(getElements(), formatIsOptional);
+ }
};
} // namespace
@@ -214,10 +223,10 @@ class DefFormat {
/// Generate the printer code for a variable.
void genVariablePrinter(ParameterElement *el, FmtContext &ctx, MethodBody &os,
bool skipGuard = false);
- /// Generate a printer for comma-separated parameters.
- void genCommaSeparatedPrinter(ArrayRef<ParameterElement *> params,
+ /// Generate a printer for comma-separated format elements.
+ void genCommaSeparatedPrinter(ArrayRef<FormatElement *> params,
FmtContext &ctx, MethodBody &os,
- function_ref<void(ParameterElement *)> extra);
+ function_ref<void(FormatElement *)> extra);
/// Generate the printer code for a `params` directive.
void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for a `struct` directive.
@@ -443,14 +452,14 @@ void DefFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,
// If there are optional parameters, we need to switch to `parseOptionalComma`
// if there are no more required parameters after a certain point.
- bool hasOptional = el->hasOptionalParams();
+ bool hasOptional = el->hasOptionalElements();
if (hasOptional) {
// Wrap everything in a do-while so that we can `break`.
os << "do {\n";
os.indent();
}
- ArrayRef<ParameterElement *> params = el->getParams();
+ ArrayRef<ParameterElement *> params = el->getElements();
using IteratorT = ParameterElement *const *;
IteratorT it = params.begin();
@@ -551,22 +560,31 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
while (!$_parser.parseOptionalComma()) {
)";
+ const char *const checkParamKey = R"(
+ if (!_seen_{0} && _paramKey == "{0}") {
+ _seen_{0} = true;
+)";
+
os << "// Parse parameter struct\n";
// Declare a "seen" variable for each key.
- for (ParameterElement *param : el->getParams())
+ for (FormatElement *arg : el->getElements()) {
+ ParameterElement *param = getStructParameterElement(arg);
os << formatv("bool _seen_{0} = false;\n", param->getName());
+ }
// Generate the body of the parsing loop inside a lambda.
os << "{\n";
os.indent()
<< "const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {\n";
genLiteralParser("=", ctx, os.indent());
- for (ParameterElement *param : el->getParams()) {
- os << formatv("if (!_seen_{0} && _paramKey == \"{0}\") {\n"
- " _seen_{0} = true;\n",
- param->getName());
- genVariableParser(param, ctx, os.indent());
+ for (FormatElement *arg : el->getElements()) {
+ ParameterElement *param = getStructParameterElement(arg);
+ os.getStream().printReindented(strfmt(checkParamKey, param->getName()));
+ if (auto realParam = dyn_cast<ParameterElement>(arg))
+ genVariableParser(param, ctx, os.indent());
+ else if (auto custom = dyn_cast<CustomDirective>(arg))
+ genCustomParser(custom, ctx, os.indent());
os.unindent() << "} else ";
// Print the check for duplicate or unknown parameter.
}
@@ -576,10 +594,10 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
// Generate the parsing loop. If optional parameters are present, then the
// parse loop is guarded by commas.
- unsigned numOptional = llvm::count_if(el->getParams(), paramIsOptional);
+ unsigned numOptional = llvm::count_if(el->getElements(), formatIsOptional);
if (numOptional) {
// If the struct itself is optional, pull out the first iteration.
- if (numOptional == el->getNumParams()) {
+ if (numOptional == el->getNumElements()) {
os.getStream().printReindented(tgfmt(optionalStructFirst, &ctx).str());
os.indent();
} else {
@@ -587,7 +605,7 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
}
} else {
os.getStream().printReindented(
- tgfmt(loopHeader, &ctx, el->getNumParams()).str());
+ tgfmt(loopHeader, &ctx, el->getNumElements()).str());
}
os.indent();
os.getStream().printReindented(tgfmt(loopStart, &ctx).str());
@@ -597,12 +615,13 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
// all mandatory parameters have been parsed.
// The whole struct is optional if all its parameters are optional.
if (numOptional) {
- if (numOptional == el->getNumParams()) {
+ if (numOptional == el->getNumElements()) {
os << "}\n";
os.unindent() << "}\n";
} else {
os << tgfmt("} while(!$_parser.parseOptionalComma());\n", &ctx);
- for (ParameterElement *param : el->getParams()) {
+ for (FormatElement *arg : el->getElements()) {
+ ParameterElement *param = getStructParameterElement(arg);
if (param->isOptional())
continue;
os.getStream().printReindented(
@@ -614,7 +633,8 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
// N flags, successfully exiting the loop means that all parameters have
// been seen. `parseOptionalComma` would cause issues with any formats that
// use "struct(...) `,`" beacuse structs aren't sounded by braces.
- os.getStream().printReindented(strfmt(loopTerminator, el->getNumParams()));
+ os.getStream().printReindented(
+ strfmt(loopTerminator, el->getNumElements()));
}
os.unindent() << "}\n";
}
@@ -631,7 +651,7 @@ void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
os << "(void)odsCustomLoc;\n";
os << tgfmt("auto odsCustomResult = parse$0($_parser", &ctx, el->getName());
os.indent();
- for (FormatElement *arg : el->getArguments()) {
+ for (FormatElement *arg : el->getElements()) {
os << ",\n";
if (auto *param = dyn_cast<ParameterElement>(arg))
os << "::mlir::detail::unwrapForCustomParse(_result_" << param->getName()
@@ -648,7 +668,7 @@ void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
} else {
os << "if (::mlir::failed(odsCustomResult)) return {};\n";
}
- for (FormatElement *arg : el->getArguments()) {
+ for (FormatElement *arg : el->getElements()) {
if (auto *param = dyn_cast<ParameterElement>(arg)) {
if (param->isOptional())
continue;
@@ -689,7 +709,7 @@ void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
guardOn(llvm::ArrayRef(param));
} else if (auto *params = dyn_cast<ParamsDirective>(first)) {
genParamsParser(params, ctx, os);
- guardOn(params->getParams());
+ guardOn(params->getElements());
} else if (auto *custom = dyn_cast<CustomDirective>(first)) {
os << "if (auto result = [&]() -> ::mlir::OptionalParseResult {\n";
os.indent();
@@ -704,7 +724,7 @@ void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
} else {
auto *strct = cast<StructDirective>(first);
genStructParser(strct, ctx, os);
- guardOn(params->getParams());
+    guardOn(map_to_vector(strct->getElements(), getStructParameterElement));
}
os.indent();
@@ -816,14 +836,26 @@ static void guardOnAny(FmtContext &ctx, MethodBody &os, ParameterRange &&params,
os.indent();
}
+/// Generate code to guard printing on the presence of any optional format
+/// elements.
+template <typename FormatElemRange>
+static void guardOnAnyOptional(FmtContext &ctx, MethodBody &os,
+ FormatElemRange &&args, bool inverted = false) {
+ guardOnAny(
+ ctx, os,
+ llvm::make_filter_range(llvm::map_range(args, getStructParameterElement),
+ [](ParameterElement *param) { return !!param; }),
+ inverted);
+}
+
void DefFormat::genCommaSeparatedPrinter(
- ArrayRef<ParameterElement *> params, FmtContext &ctx, MethodBody &os,
- function_ref<void(ParameterElement *)> extra) {
+ ArrayRef<FormatElement *> args, FmtContext &ctx, MethodBody &os,
+ function_ref<void(FormatElement *)> extra) {
// Emit a space if necessary, but only if the struct is present.
if (shouldEmitSpace || !lastWasPunctuation) {
- bool allOptional = llvm::all_of(params, paramIsOptional);
+ bool allOptional = llvm::all_of(args, formatIsOptional);
if (allOptional)
- guardOnAny(ctx, os, params);
+ guardOnAnyOptional(ctx, os, args);
os << tgfmt("$_printer << ' ';\n", &ctx);
if (allOptional)
os.unindent() << "}\n";
@@ -832,17 +864,21 @@ void DefFormat::genCommaSeparatedPrinter(
// The first printed element does not need to emit a comma.
os << "{\n";
os.indent() << "bool _firstPrinted = true;\n";
- for (ParameterElement *param : params) {
+ for (FormatElement *arg : args) {
+ ParameterElement *param = getStructParameterElement(arg);
if (param->isOptional()) {
param->genPrintGuard(ctx, os << "if (") << ") {\n";
os.indent();
}
os << tgfmt("if (!_firstPrinted) $_printer << \", \";\n", &ctx);
os << "_firstPrinted = false;\n";
- extra(param);
+ extra(arg);
shouldEmitSpace = false;
lastWasPunctuation = true;
- genVariablePrinter(param, ctx, os);
+ if (auto realParam = dyn_cast<ParameterElement>(arg))
+ genVariablePrinter(realParam, ctx, os);
+ else if (auto custom = dyn_cast<CustomDirective>(arg))
+ genCustomPrinter(custom, ctx, os);
if (param->isOptional())
os.unindent() << "}\n";
}
@@ -851,16 +887,19 @@ void DefFormat::genCommaSeparatedPrinter(
void DefFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
MethodBody &os) {
- genCommaSeparatedPrinter(llvm::to_vector(el->getParams()), ctx, os,
- [&](ParameterElement *param) {});
+ SmallVector<FormatElement *> args = llvm::map_to_vector(
+ el->getElements(), [](ParameterElement *param) -> FormatElement * {
+ return static_cast<FormatElement *>(param);
+ });
+ genCommaSeparatedPrinter(args, ctx, os, [&](FormatElement *param) {});
}
void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
MethodBody &os) {
- genCommaSeparatedPrinter(
- llvm::to_vector(el->getParams()), ctx, os, [&](ParameterElement *param) {
- os << tgfmt("$_printer << \"$0 = \";\n", &ctx, param->getName());
- });
+ genCommaSeparatedPrinter(el->getElements(), ctx, os, [&](FormatElement *arg) {
+ ParameterElement *param = getStructParameterElement(arg);
+ os << tgfmt("$_printer << \"$0 = \";\n", &ctx, param->getName());
+ });
}
void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
@@ -873,7 +912,7 @@ void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
os << tgfmt("print$0($_printer", &ctx, el->getName());
os.indent();
- for (FormatElement *arg : el->getArguments()) {
+ for (FormatElement *arg : el->getElements()) {
os << ",\n";
if (auto *param = dyn_cast<ParameterElement>(arg)) {
os << param->getParam().getAccessorName() << "()";
@@ -893,19 +932,12 @@ void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
if (auto *param = dyn_cast<ParameterElement>(anchor)) {
guardOnAny(ctx, os, llvm::ArrayRef(param), el->isInverted());
} else if (auto *params = dyn_cast<ParamsDirective>(anchor)) {
- guardOnAny(ctx, os, params->getParams(), el->isInverted());
+ guardOnAny(ctx, os, params->getElements(), el->isInverted());
} else if (auto *strct = dyn_cast<StructDirective>(anchor)) {
- guardOnAny(ctx, os, strct->getParams(), el->isInverted());
+ guardOnAnyOptional(ctx, os, strct->getElements(), el->isInverted());
} else {
auto *custom = cast<CustomDirective>(anchor);
- guardOnAny(ctx, os,
- llvm::make_filter_range(
- llvm::map_range(custom->getArguments(),
- [](FormatElement *el) {
- return dyn_cast<ParameterElement>(el);
- }),
- [](ParameterElement *param) { return !!param; }),
- el->isInverted());
+ guardOnAnyOptional(ctx, os, custom->getElements(), el->isInverted());
}
// Generate the printer for the contained elements.
{
@@ -960,6 +992,9 @@ class DefFormatParser : public FormatParser {
LogicalResult verifyOptionalGroupElements(SMLoc loc,
ArrayRef<FormatElement *> elements,
FormatElement *anchor) override;
+ /// Verify the arguments to a struct directive.
+ LogicalResult verifyStructArguments(SMLoc loc,
+ ArrayRef<FormatElement *> arguments);
LogicalResult markQualified(SMLoc loc, FormatElement *element) override;
@@ -1010,7 +1045,7 @@ LogicalResult DefFormatParser::verify(SMLoc loc,
auto *literalEl = dyn_cast<LiteralElement>(std::get<1>(it));
if (!structEl || !literalEl)
continue;
- if (literalEl->getSpelling() == "," && structEl->hasOptionalParams()) {
+ if (literalEl->getSpelling() == "," && structEl->hasOptionalElements()) {
return emitError(loc, "`struct` directive with optional parameters "
"cannot be followed by a comma literal");
}
@@ -1037,17 +1072,17 @@ DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
"parameters in an optional group must be optional");
}
} else if (auto *params = dyn_cast<ParamsDirective>(el)) {
- if (llvm::any_of(params->getParams(), paramNotOptional)) {
+ if (llvm::any_of(params->getElements(), paramNotOptional)) {
return emitError(loc, "`params` directive allowed in optional group "
"only if all parameters are optional");
}
} else if (auto *strct = dyn_cast<StructDirective>(el)) {
- if (llvm::any_of(strct->getParams(), paramNotOptional)) {
+ if (llvm::any_of(strct->getElements(), formatNotOptional)) {
return emitError(loc, "`struct` is only allowed in an optional group "
"if all captured parameters are optional");
}
} else if (auto *custom = dyn_cast<CustomDirective>(el)) {
- for (FormatElement *el : custom->getArguments()) {
+ for (FormatElement *el : custom->getElements()) {
// If the custom argument is a variable, then it must be optional.
if (auto *param = dyn_cast<ParameterElement>(el))
if (!param->isOptional())
@@ -1068,10 +1103,10 @@ DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
// arguments is a bound parameter.
if (auto *custom = dyn_cast<CustomDirective>(anchor)) {
const auto *bound =
- llvm::find_if(custom->getArguments(), [](FormatElement *el) {
+ llvm::find_if(custom->getElements(), [](FormatElement *el) {
return isa<ParameterElement>(el);
});
- if (bound == custom->getArguments().end())
+ if (bound == custom->getElements().end())
return emitError(loc, "`custom` directive with no bound parameters "
"cannot be used as optional group anchor");
}
@@ -1079,6 +1114,28 @@ DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
return success();
}
+LogicalResult
+DefFormatParser::verifyStructArguments(SMLoc loc,
+                                       ArrayRef<FormatElement *> arguments) {
+  for (FormatElement *el : arguments) {
+    // `params` has already been expanded into its parameters by the caller.
+    if (!isa<ParameterElement, CustomDirective>(el))
+      return emitError(loc, "expected a parameter or `custom` directive in "
+                            "`struct` argument list");
+    if (auto *custom = dyn_cast<CustomDirective>(el)) {
+      if (custom->getNumElements() != 1) {
+        return emitError(loc, "`struct` can only contain `custom` directives "
+                              "with a single argument");
+      }
+      if (failed(custom->getFrontAs<ParameterElement>())) {
+        return emitError(loc, "a `custom` directive nested within a `struct` "
+                              "must be passed a parameter");
+      }
+    }
+  }
+  return success();
+}
+
LogicalResult DefFormatParser::markQualified(SMLoc loc,
FormatElement *element) {
if (!isa<ParameterElement>(element))
@@ -1172,37 +1229,45 @@ FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc,
return emitError(loc, "`struct` can only be used at the top-level context");
if (failed(parseToken(FormatToken::l_paren,
- "expected '(' before `struct` argument list")))
+ "expected '(' before `struct` argument list"))) {
return failure();
+ }
// Parse variables captured by `struct`.
- std::vector<ParameterElement *> vars;
+ std::vector<FormatElement *> vars;
// Parse first captured parameter or a `params` directive.
FailureOr<FormatElement *> var = parseElement(StructDirectiveContext);
- if (failed(var) || !isa<VariableElement, ParamsDirective>(*var)) {
- return emitError(loc,
- "`struct` argument list expected a variable or directive");
+ if (failed(var) ||
+ !isa<ParameterElement, ParamsDirective, CustomDirective>(*var)) {
+ return emitError(
+ loc, "`struct` argument list expected a parameter or directive");
}
- if (isa<VariableElement>(*var)) {
+ if (isa<ParameterElement, CustomDirective>(*var)) {
// Parse any other parameters.
- vars.push_back(cast<ParameterElement>(*var));
+ vars.push_back(*var);
while (peekToken().is(FormatToken::comma)) {
consumeToken();
var = parseElement(StructDirectiveContext);
- if (failed(var) || !isa<VariableElement>(*var))
- return emitError(loc, "expected a variable in `struct` argument list");
- vars.push_back(cast<ParameterElement>(*var));
+ if (failed(var) || !isa<ParameterElement, CustomDirective>(*var))
+ return emitError(loc, "expected a parameter or `custom` directive in "
+ "`struct` argument list");
+ vars.push_back(*var);
}
} else {
// `struct(params)` captures all parameters in the attribute or type.
- vars = cast<ParamsDirective>(*var)->takeParams();
+ ParamsDirective *params = cast<ParamsDirective>(*var);
+ vars.reserve(params->getNumElements());
+ for (ParameterElement *el : params->takeElements())
+ vars.push_back(cast<FormatElement>(el));
}
if (failed(parseToken(FormatToken::r_paren,
- "expected ')' at the end of an argument list")))
+ "expected ')' at the end of an argument list"))) {
+ return failure();
+ }
+ if (failed(verifyStructArguments(loc, vars)))
return failure();
-
return create<StructDirective>(std::move(vars));
}
diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp
index dd9b41bc90aef..4dfdde2146679 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/FormatGen.cpp
@@ -400,8 +400,10 @@ FailureOr<FormatElement *> FormatParser::parseOptionalGroup(Context ctx) {
FailureOr<FormatElement *> FormatParser::parseCustomDirective(SMLoc loc,
Context ctx) {
- if (ctx != TopLevelContext)
- return emitError(loc, "'custom' is only valid as a top-level directive");
+ if (ctx != TopLevelContext && ctx != StructDirectiveContext) {
+ return emitError(loc, "`custom` can only be used at the top-level context "
+ "or within a `struct` directive");
+ }
FailureOr<FormatToken> nameTok;
if (failed(parseToken(FormatToken::less,
diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h
index 1dc2cb3eaa88a..8e7d49bb37e71 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.h
+++ b/mlir/tools/mlir-tblgen/FormatGen.h
@@ -338,29 +338,56 @@ class DirectiveElementBase : public DirectiveElement {
}
};
+/// Base class for a directive that contains references to elements of type `T`
+/// in a vector.
+template <DirectiveElement::Kind DirectiveKind, typename T>
+class VectorDirectiveBase : public DirectiveElementBase<DirectiveKind> {
+public:
+  using Base = VectorDirectiveBase<DirectiveKind, T>;
+
+  /// Store the given list of element references. This is `explicit` because
+  /// derived directives inherit it (`using Base::Base;`); without it a bare
+  /// `std::vector<T>` would implicitly convert into such a directive.
+  explicit VectorDirectiveBase(std::vector<T> &&elements)
+      : elems(std::move(elements)) {}
+
+  /// Get the elements contained in this directive.
+  ArrayRef<T> getElements() const { return elems; }
+
+  /// Get the number of elements.
+  unsigned getNumElements() const { return elems.size(); }
+
+  /// Take all of the elements from this directive. Swapping (rather than
+  /// relying on the moved-from state) guarantees the directive is left with
+  /// an empty element list.
+  std::vector<T> takeElements() {
+    std::vector<T> taken;
+    taken.swap(elems);
+    return taken;
+  }
+
+protected:
+  /// The elements captured by this directive.
+  std::vector<T> elems;
+};
+
/// This class represents a custom format directive that is implemented by the
/// user in C++. The directive accepts a list of arguments that is passed to the
/// C++ function.
-class CustomDirective : public DirectiveElementBase<DirectiveElement::Custom> {
+class CustomDirective
+    : public VectorDirectiveBase<DirectiveElement::Custom, FormatElement *> {
public:
+  // The arguments with which to call the custom functions are stored in the
+  // vector base class and exposed through getElements(). They are either
+  // variables (for which the functions are responsible for populating) or
+  // references to variables.
+
+  /// Inherit the elements-only constructor from the vector base. A directive
+  /// created through it leaves `name` default-constructed.
+  using Base::Base;
/// Create a custom directive with a name and list of arguments.
CustomDirective(StringRef name, std::vector<FormatElement *> &&arguments)
- : name(name), arguments(std::move(arguments)) {}
+      : Base(std::move(arguments)), name(name) {}
/// Get the custom directive name.
StringRef getName() const { return name; }
- /// Get the arguments to the custom directive.
- ArrayRef<FormatElement *> getArguments() const { return arguments; }
+  /// Return the sole argument of this directive as a `T`. Fails if the
+  /// directive does not have exactly one argument, or if that argument is not
+  /// of kind `T` (checked with `dyn_cast`).
+  template <typename T>
+  FailureOr<T *> getFrontAs() const {
+    if (getNumElements() != 1)
+      return failure();
+    if (T *elem = dyn_cast<T>(getElements()[0]))
+      return elem;
+    return failure();
+  }
private:
/// The name of the custom directive. The name is used to call two C++
/// methods: `parse{name}` and `print{name}` with the given arguments.
StringRef name;
- /// The arguments with which to call the custom functions. These are either
- /// variables (for which the functions are responsible for populating) or
- /// references to variables.
- std::vector<FormatElement *> arguments;
};
/// This class represents a reference directive. This directive can be used to
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index ca2c1d4a8ad04..a0d947fe8a0df 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -882,7 +882,7 @@ static void genElementParserStorage(FormatElement *element, const Operator &op,
}
} else if (auto *custom = dyn_cast<CustomDirective>(element)) {
- for (FormatElement *paramElement : custom->getArguments())
+ for (FormatElement *paramElement : custom->getElements())
genElementParserStorage(paramElement, op, body);
} else if (isa<OperandsDirective>(element)) {
@@ -1037,7 +1037,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
// * Add a local variable for optional operands and types. This provides a
// better API to the user defined parser methods.
// * Set the location of operand variables.
- for (FormatElement *param : dir->getArguments()) {
+ for (FormatElement *param : dir->getElements()) {
if (auto *operand = dyn_cast<OperandVariable>(param)) {
auto *var = operand->getVar();
body << " " << var->name
@@ -1089,7 +1089,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
}
body << " auto odsResult = parse" << dir->getName() << "(parser";
- for (FormatElement *param : dir->getArguments()) {
+ for (FormatElement *param : dir->getElements()) {
body << ", ";
genCustomParameterParser(param, body);
}
@@ -1103,7 +1103,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
}
// After parsing, add handling for any of the optional constructs.
- for (FormatElement *param : dir->getArguments()) {
+ for (FormatElement *param : dir->getElements()) {
if (auto *attr = dyn_cast<AttributeVariable>(param)) {
const NamedAttribute *var = attr->getVar();
if (var->attr.isOptional() || var->attr.hasDefaultValue())
@@ -2215,7 +2215,7 @@ static void genCustomDirectiveParameterPrinter(FormatElement *element,
static void genCustomDirectivePrinter(CustomDirective *customDir,
const Operator &op, MethodBody &body) {
body << " print" << customDir->getName() << "(_odsPrinter, *this";
- for (FormatElement *param : customDir->getArguments()) {
+ for (FormatElement *param : customDir->getElements()) {
body << ", ";
genCustomDirectiveParameterPrinter(param, op, body);
}
@@ -2359,7 +2359,7 @@ static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
.Case([&](CustomDirective *ele) {
body << '(';
llvm::interleave(
- ele->getArguments(), body,
+ ele->getElements(), body,
[&](FormatElement *child) {
body << '(';
genOptionalGroupPrinterAnchor(child, op, body);
@@ -2375,7 +2375,7 @@ void collect(FormatElement *element,
TypeSwitch<FormatElement *>(element)
.Case([&](VariableElement *var) { variables.emplace_back(var); })
.Case([&](CustomDirective *ele) {
- for (FormatElement *arg : ele->getArguments())
+ for (FormatElement *arg : ele->getElements())
collect(arg, variables);
})
.Case([&](OptionalElement *ele) {
@@ -3774,7 +3774,7 @@ LogicalResult OpFormatParser::verifyOptionalGroupElement(SMLoc loc,
return success();
// Verify each child as being valid in an optional group. They are all
// potential anchors if the custom directive was marked as one.
- for (FormatElement *child : ele->getArguments()) {
+ for (FormatElement *child : ele->getElements()) {
if (isa<RefDirective>(child))
continue;
if (failed(verifyOptionalGroupElement(loc, child, /*isAnchor=*/true)))
More information about the Mlir-commits
mailing list