[Mlir-commits] [mlir] 4767e26 - [mlir][ods] Add support for custom directive in attr/type formats
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 15 00:15:20 PDT 2022
Author: Mogball
Date: 2022-03-15T07:15:15Z
New Revision: 4767e267757fa56d40248b759d7b17ac1c6fb2ef
URL: https://github.com/llvm/llvm-project/commit/4767e267757fa56d40248b759d7b17ac1c6fb2ef
DIFF: https://github.com/llvm/llvm-project/commit/4767e267757fa56d40248b759d7b17ac1c6fb2ef.diff
LOG: [mlir][ods] Add support for custom directive in attr/type formats
This patch adds support for custom directives in attribute and type formats. Custom directives dispatch calls to user-defined parser and printer functions.
For example, the assembly format "custom<Foo>($foo, ref($bar))" expects functions with the signatures
```
LogicalResult parseFoo(AsmParser &parser, FailureOr<FooT> &foo, BarT bar);
void printFoo(AsmPrinter &printer, FooT foo, BarT bar);
```
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D120944
Added:
Modified:
mlir/docs/Tutorials/DefiningAttributesAndTypes.md
mlir/test/lib/Dialect/Test/TestTypeDefs.td
mlir/test/lib/Dialect/Test/TestTypes.cpp
mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
mlir/test/mlir-tblgen/attr-or-type-format.td
mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
mlir/tools/mlir-tblgen/FormatGen.h
mlir/tools/mlir-tblgen/OpFormatGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
index 3260ac10896f4..929749f8b8155 100644
--- a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
+++ b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
@@ -558,6 +558,8 @@ Attribute and type assembly formats have the following directives:
mnemonic.
* `struct`: generate a "struct-like" parser and printer for a list of
key-value pairs.
+* `custom`: dispatch a call to user-defined parser and printer functions.
+* `ref`: in a custom directive, reference a previously bound variable.
#### `params` Directive
@@ -649,3 +651,44 @@ assembly format of `` `<` struct(params) `>` `` will result in:
The order in which the parameters are printed is the order in which they are
declared in the attribute's or type's `parameter` list.
+
+#### `custom` and `ref` directives
+
+The `custom` directive is used to dispatch calls to user-defined printer and
+parser functions. For example, suppose we had the following type:
+
+```tablegen
+let parameters = (ins "int":$foo, "int":$bar);
+let assemblyFormat = "custom<Foo>($foo) custom<Bar>($bar, ref($foo))";
+```
+
+The `custom` directive `custom<Foo>($foo)` will generate, in the parser and the
+printer respectively, calls to:
+
+```c++
+LogicalResult parseFoo(AsmParser &parser, FailureOr<int> &foo);
+void printFoo(AsmPrinter &printer, int foo);
+```
+
+A previously bound variable can be passed as a parameter to a `custom` directive
+by wrapping it in a `ref` directive. In the previous example, `$foo` is bound by
+the first directive. The second directive references it and expects the
+following printer and parser signatures:
+
+```c++
+LogicalResult parseBar(AsmParser &parser, FailureOr<int> &bar, int foo);
+void printBar(AsmPrinter &printer, int bar, int foo);
+```
+
+More complex C++ types can be used with the `custom` directive. The only caveat
+is that the parser's `FailureOr` argument must use the storage type of the
+parameter, while the printer receives the value returned by the parameter's
+accessor. For example, `StringRefParameter` expects the parser and printer
+signatures as:
+
+```c++
+LogicalResult parseStringParam(AsmParser &parser,
+ FailureOr<std::string> &value);
+void printStringParam(AsmPrinter &printer, StringRef value);
+```
+
+The custom parser is considered to have failed if it returns failure or if any
+bound, non-optional parameters still hold failure values afterwards.
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index de4969353ad01..75cccdaf8fd45 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -363,4 +363,18 @@ def TestTypeDefaultValuedType : Test_Type<"TestTypeDefaultValuedType"> {
let assemblyFormat = "`<` (`(` $type^ `)`)? `>`";
}
+// Type whose parameters are handled entirely by user-defined functions in
+// TestTypes.cpp: `a` via parseCustomTypeA/printCustomTypeA, and the optional
+// `b` via parseCustomTypeB/printCustomTypeB, which also receive the already
+// parsed `a` through a `ref` directive.
+def TestTypeCustom : Test_Type<"TestTypeCustom"> {
+  let parameters = (ins "int":$a, OptionalParameter<"mlir::Optional<int>">:$b);
+  let mnemonic = "custom_type";
+  let assemblyFormat = [{ `<` custom<CustomTypeA>($a)
+ custom<CustomTypeB>(ref($a), $b) `>` }];
+}
+
+// Type with a string parameter handled by custom functions: `foo` is bound by
+// parseFooString (a quoted string) and then passed by `ref` to parseBarString,
+// which expects the same text again as a keyword.
+def TestTypeCustomString : Test_Type<"TestTypeCustomString"> {
+  let parameters = (ins StringRefParameter<>:$foo);
+  let mnemonic = "custom_type_string";
+  let assemblyFormat = [{ `<` custom<FooString>($foo)
+ custom<BarString>(ref($foo)) `>` }];
+}
+
#endif // TEST_TYPEDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 9b65f2fc06a37..5bcf62a3fa272 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -208,6 +208,59 @@ unsigned TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params,
return 1;
}
+//===----------------------------------------------------------------------===//
+// TestCustomType
+//===----------------------------------------------------------------------===//
+
+/// Parses the integer parameter `a` of `TestTypeCustom` directly into the
+/// `FailureOr` result slot.
+static LogicalResult parseCustomTypeA(AsmParser &parser,
+                                      FailureOr<int> &a_result) {
+  // Give the slot a value first so the integer parser has storage to fill.
+  a_result.emplace(0);
+  return parser.parseInteger(*a_result);
+}
+
+/// Prints the integer parameter `a` of `TestTypeCustom`.
+static void printCustomTypeA(AsmPrinter &printer, int a) {
+  printer << a;
+}
+
+/// Parses the optional parameter `b` of `TestTypeCustom`, given the already
+/// parsed value of `a`. A negative `a` means `b` is absent: nothing is parsed
+/// and `b_result` is left untouched. Otherwise `a` filler integers are consumed
+/// and discarded, followed by the value of `b`.
+static LogicalResult parseCustomTypeB(AsmParser &parser, int a,
+                                      FailureOr<Optional<int>> &b_result) {
+  if (a < 0)
+    return success();
+  // The filler values are intentionally thrown away; the printer regenerates
+  // them as 0, 1, ..., a - 1.
+  for (int remaining = a; remaining > 0; --remaining) {
+    int ignored = 0;
+    if (failed(parser.parseInteger(ignored)))
+      return failure();
+  }
+  b_result.emplace(0);
+  return parser.parseInteger(**b_result);
+}
+
+/// Prints `b` preceded by the filler integers 0..a-1. Prints nothing at all
+/// when `a` is negative.
+static void printCustomTypeB(AsmPrinter &printer, int a, Optional<int> b) {
+  if (a >= 0) {
+    printer << ' ';
+    for (int filler = 0; filler < a; ++filler)
+      printer << filler << ' ';
+    printer << *b;
+  }
+}
+
+/// Parses a quoted string into the storage type (`std::string`) of `foo`.
+/// `foo` is only assigned when the string parses successfully.
+static LogicalResult parseFooString(AsmParser &parser,
+                                    FailureOr<std::string> &foo) {
+  std::string parsed;
+  if (failed(parser.parseString(&parsed)))
+    return failure();
+  foo = std::move(parsed);
+  return success();
+}
+
+/// Prints `foo` surrounded by double quotes; the contents are written as-is,
+/// without escaping.
+static void printFooString(AsmPrinter &printer, StringRef foo) {
+  printer << '"';
+  printer << foo;
+  printer << '"';
+}
+
+/// Expects a keyword spelled exactly like the previously parsed `foo`.
+static LogicalResult parseBarString(AsmParser &parser, StringRef foo) {
+  if (failed(parser.parseKeyword(foo)))
+    return failure();
+  return success();
+}
+
+/// Prints `foo` again, unquoted, after a separating space.
+static void printBarString(AsmPrinter &printer, StringRef foo) {
+  printer << ' ';
+  printer << foo;
+}
+
//===----------------------------------------------------------------------===//
// Tablegen Generated Definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
index d92ae1677100c..ac8e5974387f3 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
@@ -107,3 +107,27 @@ def InvalidTypeN : InvalidType<"InvalidTypeN", "invalid_n"> {
// CHECK: optional group anchor must be a parameter or directive
let assemblyFormat = "(`(` $a `)`^)?";
}
+
+// A `ref` directive used outside of a `custom` directive is rejected.
+def InvalidTypeO : InvalidType<"InvalidTypeO", "invalid_o"> {
+  let parameters = (ins "int":$a);
+  // CHECK: `ref` is only allowed inside custom directives
+  let assemblyFormat = "$a ref($a)";
+}
+
+// A parameter may only be referenced after an earlier element has bound it.
+def InvalidTypeP : InvalidType<"InvalidTypeP", "invalid_p"> {
+  let parameters = (ins "int":$a);
+  // CHECK: parameter 'a' must be bound before it is referenced
+  let assemblyFormat = "custom<Foo>(ref($a)) $a";
+}
+
+// `params` is not accepted as an argument of a `custom` directive.
+def InvalidTypeQ : InvalidType<"InvalidTypeQ", "invalid_q"> {
+  let parameters = (ins "int":$a);
+  // CHECK: `params` can only be used at the top-level context or within a `struct` directive
+  let assemblyFormat = "custom<Foo>(params)";
+}
+
+// `struct` is not accepted as an argument of a `custom` directive.
+def InvalidTypeR : InvalidType<"InvalidTypeR", "invalid_r"> {
+  let parameters = (ins "int":$a);
+  // CHECK: `struct` can only be used at the top-level context
+  let assemblyFormat = "custom<Foo>(struct(params))";
+}
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
index b626ca44f36a1..bb8fffd9134fb 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
@@ -48,6 +48,10 @@ attributes {
// CHECK: !test.ap_float<>
// CHECK: !test.default_valued_type<(i64)>
// CHECK: !test.default_valued_type<>
+// CHECK: !test.custom_type<-5>
+// CHECK: !test.custom_type<2 0 1 5>
+// CHECK: !test.custom_type_string<"foo" foo>
+// CHECK: !test.custom_type_string<"bar" bar>
func private @test_roundtrip_default_parsers_struct(
!test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>
@@ -79,5 +83,9 @@ func private @test_roundtrip_default_parsers_struct(
!test.ap_float<5.0>,
!test.ap_float<>,
!test.default_valued_type<(i64)>,
- !test.default_valued_type<>
+ !test.default_valued_type<>,
+ !test.custom_type<-5>,
+ !test.custom_type<2 9 9 5>,
+ !test.custom_type_string<"foo" foo>,
+ !test.custom_type_string<"bar" bar>
)
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
index f0c1aac4af446..ba8df2b593f9b 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -499,3 +499,27 @@ def TypeI : TestType<"TestK"> {
let mnemonic = "type_k";
let assemblyFormat = "$a";
}
+
+// TYPE: ::mlir::Type TestLType::parse
+// TYPE: auto odsCustomLoc = odsParser.getCurrentLocation()
+// TYPE: auto odsCustomResult = parseA(odsParser,
+// TYPE-NEXT: _result_a
+// TYPE: if (::mlir::failed(odsCustomResult)) return {}
+// TYPE: if (::mlir::failed(_result_a))
+// TYPE-NEXT: odsParser.emitError(odsCustomLoc,
+// TYPE: auto odsCustomResult = parseB(odsParser,
+// TYPE-NEXT: _result_b
+// TYPE-NEXT: *_result_a
+
+// TYPE: void TestLType::print
+// TYPE: printA(odsPrinter
+// TYPE-NEXT: getA()
+// TYPE: printB(odsPrinter
+// TYPE-NEXT: getB()
+// TYPE-NEXT: getA()
+
+// Exercises custom directives in a type format: `a` is bound by `custom<A>`,
+// then `custom<B>` binds the optional `b` and receives `a` through `ref`.
+// NOTE(review): the def name `TypeJ` and mnemonic `type_j` do not follow the
+// `TestL` class name matched by the checks above; consider renaming for
+// consistency.
+def TypeJ : TestType<"TestL"> {
+  let parameters = (ins "int":$a, OptionalParameter<"Attribute">:$b);
+  let mnemonic = "type_j";
+  let assemblyFormat = "custom<A>($a) custom<B>($b, ref($a))";
+}
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index f8b3b2b007a8d..0c314b33caf82 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -199,6 +199,8 @@ class DefFormat {
void genParamsParser(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the parser code for a `struct` directive.
void genStructParser(StructDirective *el, FmtContext &ctx, MethodBody &os);
+ /// Generate the parser code for a `custom` directive.
+ void genCustomParser(CustomDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the parser code for an optional group.
void genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
MethodBody &os);
@@ -218,6 +220,8 @@ class DefFormat {
void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for a `struct` directive.
void genStructPrinter(StructDirective *el, FmtContext &ctx, MethodBody &os);
+ /// Generate the printer code for a `custom` directive.
+ void genCustomPrinter(CustomDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for an optional group.
void genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
MethodBody &os);
@@ -313,6 +317,8 @@ void DefFormat::genElementParser(FormatElement *el, FmtContext &ctx,
return genParamsParser(params, ctx, os);
if (auto *strct = dyn_cast<StructDirective>(el))
return genStructParser(strct, ctx, os);
+ if (auto *custom = dyn_cast<CustomDirective>(el))
+ return genCustomParser(custom, ctx, os);
if (auto *optional = dyn_cast<OptionalElement>(el))
return genOptionalGroupParser(optional, ctx, os);
if (isa<WhitespaceElement>(el))
@@ -566,6 +572,47 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
os.unindent() << "}\n";
}
+/// Generate the parser code for a `custom` directive. The emitted code calls
+/// the user-provided `parse<Name>` function and then checks both its result and
+/// the bound parameters it was expected to populate.
+void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
+                                MethodBody &os) {
+  // Open a fresh scope so that the `odsCustomLoc`/`odsCustomResult` locals
+  // emitted for several custom directives in one format do not collide.
+  os << "{\n";
+  os.indent();
+
+  // Bound variables are passed directly to the parser as `FailureOr<T> &`.
+  // Referenced variables are passed as `T`. The custom parser fails if it
+  // returns failure or if any of the required parameters failed.
+  os << tgfmt("auto odsCustomLoc = $_parser.getCurrentLocation();\n", &ctx);
+  // `odsCustomLoc` is only read by the per-parameter error checks below, which
+  // may not be emitted at all (e.g. every argument is a `ref` or optional); the
+  // `(void)` use keeps the generated code free of unused-variable warnings.
+  os << "(void)odsCustomLoc;\n";
+  os << tgfmt("auto odsCustomResult = parse$0($_parser", &ctx, el->getName());
+  os.indent();
+  for (FormatElement *arg : el->getArguments()) {
+    os << ",\n";
+    FormatElement *param;
+    if (auto *ref = dyn_cast<RefDirective>(arg)) {
+      // A `ref` argument passes the already-parsed value, so the generated
+      // code dereferences the `_result_<name>` FailureOr.
+      // NOTE(review): nothing here checks that the referenced result actually
+      // holds a value; an optional parameter left unset by an earlier directive
+      // would be dereferenced while empty -- confirm this cannot happen.
+      os << "*";
+      param = ref->getArg();
+    } else {
+      param = arg;
+    }
+    // `cast` assumes every argument (or `ref` child) is a parameter element.
+    os << "_result_" << cast<ParameterElement>(param)->getName();
+  }
+  os.unindent() << ");\n";
+  os << "if (::mlir::failed(odsCustomResult)) return {};\n";
+  // Each directly bound, non-optional parameter must have been set by the
+  // custom parser; otherwise report an error at the location recorded before
+  // the call. Optional parameters may legitimately remain unset.
+  for (FormatElement *arg : el->getArguments()) {
+    if (auto *param = dyn_cast<ParameterElement>(arg)) {
+      if (param->isOptional())
+        continue;
+      os << formatv("if (::mlir::failed(_result_{0})) {{\n", param->getName());
+      os.indent() << tgfmt("$_parser.emitError(odsCustomLoc, ", &ctx)
+                  << "\"custom parser failed to parse parameter '"
+                  << param->getName() << "'\");\n";
+      os << "return {};\n";
+      os.unindent() << "}\n";
+    }
+  }
+
+  os.unindent() << "}\n";
+}
+
void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
MethodBody &os) {
ArrayRef<FormatElement *> elements =
@@ -634,6 +681,8 @@ void DefFormat::genElementPrinter(FormatElement *el, FmtContext &ctx,
return genParamsPrinter(params, ctx, os);
if (auto *strct = dyn_cast<StructDirective>(el))
return genStructPrinter(strct, ctx, os);
+ if (auto *custom = dyn_cast<CustomDirective>(el))
+ return genCustomPrinter(custom, ctx, os);
if (auto *var = dyn_cast<ParameterElement>(el))
return genVariablePrinter(var, ctx, os);
if (auto *optional = dyn_cast<OptionalElement>(el))
@@ -746,6 +795,21 @@ void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
});
}
+/// Generate the printer code for a `custom` directive: a single call
+/// `print<Name>(printer, getX(), ...)`. Bound arguments and `ref` arguments are
+/// treated alike -- each is passed as the value of its parameter accessor.
+void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
+                                 MethodBody &os) {
+  os << tgfmt("print$0($_printer", &ctx, el->getName());
+  os.indent();
+  for (FormatElement *arg : el->getArguments()) {
+    // Unwrap `ref(...)` so both forms resolve to the underlying parameter.
+    auto *refDirective = dyn_cast<RefDirective>(arg);
+    FormatElement *target = refDirective ? refDirective->getArg() : arg;
+    auto paramName = cast<ParameterElement>(target)->getName();
+    os << ",\n" << getParameterAccessorName(paramName) << "()";
+  }
+  os.unindent() << ");\n";
+}
+
void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
MethodBody &os) {
FormatElement *anchor = el->getAnchor();
@@ -805,9 +869,7 @@ class DefFormatParser : public FormatParser {
/// Verify the elements of a custom directive.
LogicalResult
verifyCustomDirectiveArguments(SMLoc loc,
- ArrayRef<FormatElement *> arguments) override {
- return emitError(loc, "'custom' not supported (yet)");
- }
+ ArrayRef<FormatElement *> arguments) override;
/// Verify the elements of an optional group.
LogicalResult
verifyOptionalGroupElements(SMLoc loc, ArrayRef<FormatElement *> elements,
@@ -822,11 +884,13 @@ class DefFormatParser : public FormatParser {
private:
/// Parse a `params` directive.
- FailureOr<FormatElement *> parseParamsDirective(SMLoc loc);
+ FailureOr<FormatElement *> parseParamsDirective(SMLoc loc, Context ctx);
/// Parse a `qualified` directive.
FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc, Context ctx);
/// Parse a `struct` directive.
- FailureOr<FormatElement *> parseStructDirective(SMLoc loc);
+ FailureOr<FormatElement *> parseStructDirective(SMLoc loc, Context ctx);
+ /// Parse a `ref` directive.
+ FailureOr<FormatElement *> parseRefDirective(SMLoc loc, Context ctx);
/// Attribute or type tablegen def.
const AttrOrTypeDef &def;
@@ -862,6 +926,12 @@ LogicalResult DefFormatParser::verify(SMLoc loc,
return success();
}
+LogicalResult DefFormatParser::verifyCustomDirectiveArguments(
+    SMLoc /*loc*/, ArrayRef<FormatElement *> /*arguments*/) {
+  // Nothing left to check: the element parsers already reject invalid custom
+  // directive arguments as they are parsed, based on the parser context (e.g.
+  // `ref` placement and binding order, `params`/`struct` restrictions).
+  return success();
+}
+
LogicalResult
DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
ArrayRef<FormatElement *> elements,
@@ -915,9 +985,18 @@ DefFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
def.getName() + " has no parameter named '" + name + "'");
}
auto idx = std::distance(params.begin(), it);
- if (seenParams.test(idx))
- return emitError(loc, "duplicate parameter '" + name + "'");
- seenParams.set(idx);
+
+ if (ctx != RefDirectiveContext) {
+ // Check that the variable has not already been bound.
+ if (seenParams.test(idx))
+ return emitError(loc, "duplicate parameter '" + name + "'");
+ seenParams.set(idx);
+
+ // Otherwise, to be referenced, a variable must have been bound.
+ } else if (!seenParams.test(idx)) {
+ return emitError(loc, "parameter '" + name +
+ "' must be bound before it is referenced");
+ }
return create<ParameterElement>(*it);
}
@@ -930,14 +1009,13 @@ DefFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
case FormatToken::kw_qualified:
return parseQualifiedDirective(loc, ctx);
case FormatToken::kw_params:
- return parseParamsDirective(loc);
+ return parseParamsDirective(loc, ctx);
case FormatToken::kw_struct:
- if (ctx != TopLevelContext) {
- return emitError(
- loc,
- "`struct` may only be used in the top-level section of the format");
- }
- return parseStructDirective(loc);
+ return parseStructDirective(loc, ctx);
+ case FormatToken::kw_ref:
+ return parseRefDirective(loc, ctx);
+ case FormatToken::kw_custom:
+ return parseCustomDirective(loc, ctx);
default:
return emitError(loc, "unsupported directive kind");
@@ -961,10 +1039,18 @@ DefFormatParser::parseQualifiedDirective(SMLoc loc, Context ctx) {
return var;
}
-FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc) {
- // Collect all of the attribute's or type's parameters.
+FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc,
+ Context ctx) {
+ // It doesn't make sense to allow references to all parameters in a custom
+ // directive because parameters are the only things that can be bound.
+ if (ctx != TopLevelContext && ctx != StructDirectiveContext) {
+ return emitError(loc, "`params` can only be used at the top-level context "
+ "or within a `struct` directive");
+ }
+
+ // Collect all of the attribute's or type's parameters and ensure that none of
+ // the parameters have already been captured.
std::vector<ParameterElement *> vars;
- // Ensure that none of the parameters have already been captured.
for (const auto &it : llvm::enumerate(def.getParameters())) {
if (seenParams.test(it.index())) {
return emitError(loc, "`params` captures duplicate parameter: " +
@@ -976,7 +1062,11 @@ FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc) {
return create<ParamsDirective>(std::move(vars));
}
-FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc) {
+FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc,
+ Context ctx) {
+ if (ctx != TopLevelContext)
+ return emitError(loc, "`struct` can only be used at the top-level context");
+
if (failed(parseToken(FormatToken::l_paren,
"expected '(' before `struct` argument list")))
return failure();
@@ -1012,6 +1102,22 @@ FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc) {
return create<StructDirective>(std::move(vars));
}
+/// Parse a `ref(<parameter>)` directive. It is only valid as an argument of a
+/// `custom` directive, and its child must be a parameter that an earlier
+/// element has already bound.
+FailureOr<FormatElement *> DefFormatParser::parseRefDirective(SMLoc loc,
+                                                              Context ctx) {
+  if (ctx != CustomDirectiveContext)
+    return emitError(loc, "`ref` is only allowed inside custom directives");
+
+  // Parse the child parameter element.
+  FailureOr<FormatElement *> child;
+  if (failed(parseToken(FormatToken::l_paren, "expected '('")) ||
+      failed(child = parseElement(RefDirectiveContext)) ||
+      failed(parseToken(FormatToken::r_paren, "expected ')'")))
+    return failure();
+
+  // Only parameter elements are allowed to be parsed under a `ref` directive.
+  // The custom directive code generators `cast` the child to a parameter, so
+  // reject anything else here with a proper diagnostic instead of asserting.
+  if (!isa<ParameterElement>(*child))
+    return emitError(loc, "`ref` argument must be a parameter");
+  return create<RefDirective>(*child);
+}
+
//===----------------------------------------------------------------------===//
// Interface
//===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h
index 741e2716f0388..f180f2da48e8d 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.h
+++ b/mlir/tools/mlir-tblgen/FormatGen.h
@@ -338,6 +338,22 @@ class CustomDirective : public DirectiveElementBase<DirectiveElement::Custom> {
std::vector<FormatElement *> arguments;
};
+/// This class represents a reference directive. This directive can be used to
+/// reference, but not bind, a previously bound variable or format object.
+/// Currently its only use is to pass variables as arguments to the custom
+/// directive.
+class RefDirective : public DirectiveElementBase<DirectiveElement::Ref> {
+public:
+  /// Create a reference directive with the single referenced child. Explicit
+  /// so that a bare `FormatElement *` never converts to a `RefDirective`
+  /// implicitly.
+  explicit RefDirective(FormatElement *arg) : arg(arg) {}
+
+  /// Get the reference argument.
+  FormatElement *getArg() const { return arg; }
+
+private:
+  /// The referenced argument.
+  FormatElement *arg;
+};
+
/// This class represents a group of elements that are optionally emitted based
/// on an optional variable "anchor" and a group of elements that are emitted
/// when the anchor element is not present.
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 0d970d82aa3f3..044f1b01c3bd3 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -153,18 +153,6 @@ class FunctionalTypeDirective
FormatElement *inputs, *results;
};
-/// This class represents the `ref` directive.
-class RefDirective : public DirectiveElementBase<DirectiveElement::Ref> {
-public:
- RefDirective(FormatElement *arg) : arg(arg) {}
-
- FormatElement *getArg() const { return arg; }
-
-private:
- /// The argument that is used to format the directive.
- FormatElement *arg;
-};
-
/// This class represents the `type` directive.
class TypeDirective : public DirectiveElementBase<DirectiveElement::Type> {
public:
More information about the Mlir-commits
mailing list