[Mlir-commits] [mlir] 4767e26 - [mlir][ods] Add support for custom directive in attr/type formats

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 15 00:15:20 PDT 2022


Author: Mogball
Date: 2022-03-15T07:15:15Z
New Revision: 4767e267757fa56d40248b759d7b17ac1c6fb2ef

URL: https://github.com/llvm/llvm-project/commit/4767e267757fa56d40248b759d7b17ac1c6fb2ef
DIFF: https://github.com/llvm/llvm-project/commit/4767e267757fa56d40248b759d7b17ac1c6fb2ef.diff

LOG: [mlir][ods] Add support for custom directive in attr/type formats

This patch adds support for custom directives in attribute and type formats. Custom directives dispatch calls to user-defined parser and printer functions.

For example, the assembly format "custom<Foo>($foo, ref($bar))" expects a parser and a printer function with the signatures

```
LogicalResult parseFoo(AsmParser &parser, FailureOr<FooT> &foo, BarT bar);
void printFoo(AsmPrinter &printer, FooT foo, BarT bar);
```

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D120944

Added: 
    

Modified: 
    mlir/docs/Tutorials/DefiningAttributesAndTypes.md
    mlir/test/lib/Dialect/Test/TestTypeDefs.td
    mlir/test/lib/Dialect/Test/TestTypes.cpp
    mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
    mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
    mlir/test/mlir-tblgen/attr-or-type-format.td
    mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
    mlir/tools/mlir-tblgen/FormatGen.h
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
index 3260ac10896f4..929749f8b8155 100644
--- a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
+++ b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
@@ -558,6 +558,8 @@ Attribute and type assembly formats have the following directives:
     mnemonic.
 *   `struct`: generate a "struct-like" parser and printer for a list of
     key-value pairs.
+*   `custom`: dispatch a call to user-defined parser and printer functions.
+*   `ref`: in a custom directive, reference a previously bound variable.
 
 #### `params` Directive
 
@@ -649,3 +651,44 @@ assembly format of `` `<` struct(params) `>` `` will result in:
 
 The order in which the parameters are printed is the order in which they are
 declared in the attribute's or type's `parameter` list.
+
+#### `custom` and `ref` directives
+
+The `custom` directive is used to dispatch calls to user-defined printer and
+parser functions. For example, suppose we had the following type:
+
+```tablegen
+let parameters = (ins "int":$foo, "int":$bar);
+let assemblyFormat = "custom<Foo>($foo) custom<Bar>($bar, ref($foo))";
+```
+
+For the directive `custom<Foo>($foo)`, the generated parser and printer
+respectively call:
+
+```c++
+LogicalResult parseFoo(AsmParser &parser, FailureOr<int> &foo);
+void printFoo(AsmPrinter &printer, int foo);
+```
+
+A previously bound variable can be passed as a parameter to a `custom` directive
+by wrapping it in a `ref` directive. In the previous example, `$foo` is bound by
+the first directive. The second directive references it and expects the
+following printer and parser signatures:
+
+```c++
+LogicalResult parseBar(AsmParser &parser, FailureOr<int> &bar, int foo);
+void printBar(AsmPrinter &printer, int bar, int foo);
+```
+
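+More complex C++ types can be used with the `custom` directive. The only caveat
+is that the parser's output argument must use the storage type of the parameter,
+whereas the printer receives the parameter's value as returned by its accessor.
+For example, `StringRefParameter` (whose storage type is `std::string`) expects
+the following parser and printer signatures: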
+
+```c++
+LogicalResult parseStringParam(AsmParser &parser,
+                               FailureOr<std::string> &value);
+void printStringParam(AsmPrinter &printer, StringRef value);
+```
+
+The custom parser is considered to have failed if it returns failure or if any
+bound, non-optional parameter still holds a failure value afterwards. Optional
+parameters are not checked, so the parser may leave them unset.

diff  --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index de4969353ad01..75cccdaf8fd45 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -363,4 +363,18 @@ def TestTypeDefaultValuedType : Test_Type<"TestTypeDefaultValuedType"> {
   let assemblyFormat = "`<` (`(` $type^ `)`)? `>`";
 }
 
+def TestTypeCustom : Test_Type<"TestTypeCustom"> {
+  let parameters = (ins "int":$a, OptionalParameter<"mlir::Optional<int>">:$b);
+  let mnemonic = "custom_type";
+  let assemblyFormat = [{ `<` custom<CustomTypeA>($a)
+                              custom<CustomTypeB>(ref($a), $b) `>` }];
+}
+
+def TestTypeCustomString : Test_Type<"TestTypeCustomString"> {
+  let parameters = (ins StringRefParameter<>:$foo);
+  let mnemonic = "custom_type_string";
+  let assemblyFormat = [{ `<` custom<FooString>($foo)
+                              custom<BarString>(ref($foo)) `>` }];
+}
+
 #endif // TEST_TYPEDEFS

diff  --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 9b65f2fc06a37..5bcf62a3fa272 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -208,6 +208,59 @@ unsigned TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params,
   return 1;
 }
 
+//===----------------------------------------------------------------------===//
+// TestCustomType
+//===----------------------------------------------------------------------===//
+
+static LogicalResult parseCustomTypeA(AsmParser &parser,
+                                      FailureOr<int> &a_result) {
+  a_result.emplace();
+  return parser.parseInteger(*a_result);
+}
+
+static void printCustomTypeA(AsmPrinter &printer, int a) { printer << a; }
+
+static LogicalResult parseCustomTypeB(AsmParser &parser, int a,
+                                      FailureOr<Optional<int>> &b_result) {
+  if (a < 0)
+    return success();
+  for (int i : llvm::seq(0, a))
+    if (failed(parser.parseInteger(i)))
+      return failure();
+  b_result.emplace(0);
+  return parser.parseInteger(**b_result);
+}
+
+static void printCustomTypeB(AsmPrinter &printer, int a, Optional<int> b) {
+  if (a < 0)
+    return;
+  printer << ' ';
+  for (int i : llvm::seq(0, a))
+    printer << i << ' ';
+  printer << *b;
+}
+
+static LogicalResult parseFooString(AsmParser &parser,
+                                    FailureOr<std::string> &foo) {
+  std::string result;
+  if (parser.parseString(&result))
+    return failure();
+  foo = std::move(result);
+  return success();
+}
+
+static void printFooString(AsmPrinter &printer, StringRef foo) {
+  printer << '"' << foo << '"';
+}
+
+static LogicalResult parseBarString(AsmParser &parser, StringRef foo) {
+  return parser.parseKeyword(foo);
+}
+
+static void printBarString(AsmPrinter &printer, StringRef foo) {
+  printer << ' ' << foo;
+}
+
 //===----------------------------------------------------------------------===//
 // Tablegen Generated Definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
index d92ae1677100c..ac8e5974387f3 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
@@ -107,3 +107,27 @@ def InvalidTypeN : InvalidType<"InvalidTypeN", "invalid_n"> {
   // CHECK: optional group anchor must be a parameter or directive
   let assemblyFormat = "(`(` $a `)`^)?";
 }
+
+def InvalidTypeO : InvalidType<"InvalidTypeO", "invalid_o"> {
+  let parameters = (ins "int":$a);
+  // CHECK: `ref` is only allowed inside custom directives
+  let assemblyFormat = "$a ref($a)";
+}
+
+def InvalidTypeP : InvalidType<"InvalidTypeP", "invalid_p"> {
+  let parameters = (ins "int":$a);
+  // CHECK: parameter 'a' must be bound before it is referenced
+  let assemblyFormat = "custom<Foo>(ref($a)) $a";
+}
+
+def InvalidTypeQ : InvalidType<"InvalidTypeQ", "invalid_q"> {
+  let parameters = (ins "int":$a);
+  // CHECK: `params` can only be used at the top-level context or within a `struct` directive
+  let assemblyFormat = "custom<Foo>(params)";
+}
+
+def InvalidTypeR : InvalidType<"InvalidTypeR", "invalid_r"> {
+  let parameters = (ins "int":$a);
+  // CHECK: `struct` can only be used at the top-level context
+  let assemblyFormat = "custom<Foo>(struct(params))";
+}

diff  --git a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
index b626ca44f36a1..bb8fffd9134fb 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
@@ -48,6 +48,10 @@ attributes {
 // CHECK: !test.ap_float<>
 // CHECK: !test.default_valued_type<(i64)>
 // CHECK: !test.default_valued_type<>
+// CHECK: !test.custom_type<-5>
+// CHECK: !test.custom_type<2 0 1 5>
+// CHECK: !test.custom_type_string<"foo" foo>
+// CHECK: !test.custom_type_string<"bar" bar>
 
 func private @test_roundtrip_default_parsers_struct(
   !test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>
@@ -79,5 +83,9 @@ func private @test_roundtrip_default_parsers_struct(
   !test.ap_float<5.0>,
   !test.ap_float<>,
   !test.default_valued_type<(i64)>,
-  !test.default_valued_type<>
+  !test.default_valued_type<>,
+  !test.custom_type<-5>,
+  !test.custom_type<2 9 9 5>,
+  !test.custom_type_string<"foo" foo>,
+  !test.custom_type_string<"bar" bar>
 )

diff  --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
index f0c1aac4af446..ba8df2b593f9b 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -499,3 +499,27 @@ def TypeI : TestType<"TestK"> {
   let mnemonic = "type_k";
   let assemblyFormat = "$a";
 }
+
+// TYPE: ::mlir::Type TestLType::parse
+// TYPE:   auto odsCustomLoc = odsParser.getCurrentLocation()
+// TYPE:   auto odsCustomResult = parseA(odsParser,
+// TYPE-NEXT: _result_a
+// TYPE:   if (::mlir::failed(odsCustomResult)) return {}
+// TYPE:   if (::mlir::failed(_result_a))
+// TYPE-NEXT: odsParser.emitError(odsCustomLoc,
+// TYPE:   auto odsCustomResult = parseB(odsParser,
+// TYPE-NEXT: _result_b
+// TYPE-NEXT: *_result_a
+
+// TYPE: void TestLType::print
+// TYPE:   printA(odsPrinter
+// TYPE-NEXT: getA()
+// TYPE:   printB(odsPrinter
+// TYPE-NEXT: getB()
+// TYPE-NEXT: getA()
+
+def TypeJ : TestType<"TestL"> {
+  let parameters = (ins "int":$a, OptionalParameter<"Attribute">:$b);
+  let mnemonic = "type_j";
+  let assemblyFormat = "custom<A>($a) custom<B>($b, ref($a))";
+}

diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index f8b3b2b007a8d..0c314b33caf82 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -199,6 +199,8 @@ class DefFormat {
   void genParamsParser(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
   /// Generate the parser code for a `struct` directive.
   void genStructParser(StructDirective *el, FmtContext &ctx, MethodBody &os);
+  /// Generate the parser code for a `custom` directive.
+  void genCustomParser(CustomDirective *el, FmtContext &ctx, MethodBody &os);
   /// Generate the parser code for an optional group.
   void genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
                               MethodBody &os);
@@ -218,6 +220,8 @@ class DefFormat {
   void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
   /// Generate the printer code for a `struct` directive.
   void genStructPrinter(StructDirective *el, FmtContext &ctx, MethodBody &os);
+  /// Generate the printer code for a `custom` directive.
+  void genCustomPrinter(CustomDirective *el, FmtContext &ctx, MethodBody &os);
   /// Generate the printer code for an optional group.
   void genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
                                MethodBody &os);
@@ -313,6 +317,8 @@ void DefFormat::genElementParser(FormatElement *el, FmtContext &ctx,
     return genParamsParser(params, ctx, os);
   if (auto *strct = dyn_cast<StructDirective>(el))
     return genStructParser(strct, ctx, os);
+  if (auto *custom = dyn_cast<CustomDirective>(el))
+    return genCustomParser(custom, ctx, os);
   if (auto *optional = dyn_cast<OptionalElement>(el))
     return genOptionalGroupParser(optional, ctx, os);
   if (isa<WhitespaceElement>(el))
@@ -566,6 +572,47 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
   os.unindent() << "}\n";
 }
 
+void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
+                                MethodBody &os) {
+  os << "{\n";
+  os.indent();
+
+  // Bound variables are passed directly to the parser as `FailureOr<T> &`.
+  // Referenced variables are passed as `T`. The custom parser fails if it
+  // returns failure or if any of the required parameters failed.
+  os << tgfmt("auto odsCustomLoc = $_parser.getCurrentLocation();\n", &ctx);
+  os << "(void)odsCustomLoc;\n";
+  os << tgfmt("auto odsCustomResult = parse$0($_parser", &ctx, el->getName());
+  os.indent();
+  for (FormatElement *arg : el->getArguments()) {
+    os << ",\n";
+    FormatElement *param;
+    if (auto *ref = dyn_cast<RefDirective>(arg)) {
+      os << "*";
+      param = ref->getArg();
+    } else {
+      param = arg;
+    }
+    os << "_result_" << cast<ParameterElement>(param)->getName();
+  }
+  os.unindent() << ");\n";
+  os << "if (::mlir::failed(odsCustomResult)) return {};\n";
+  for (FormatElement *arg : el->getArguments()) {
+    if (auto *param = dyn_cast<ParameterElement>(arg)) {
+      if (param->isOptional())
+        continue;
+      os << formatv("if (::mlir::failed(_result_{0})) {{\n", param->getName());
+      os.indent() << tgfmt("$_parser.emitError(odsCustomLoc, ", &ctx)
+                  << "\"custom parser failed to parse parameter '"
+                  << param->getName() << "'\");\n";
+      os << "return {};\n";
+      os.unindent() << "}\n";
+    }
+  }
+
+  os.unindent() << "}\n";
+}
+
 void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
                                        MethodBody &os) {
   ArrayRef<FormatElement *> elements =
@@ -634,6 +681,8 @@ void DefFormat::genElementPrinter(FormatElement *el, FmtContext &ctx,
     return genParamsPrinter(params, ctx, os);
   if (auto *strct = dyn_cast<StructDirective>(el))
     return genStructPrinter(strct, ctx, os);
+  if (auto *custom = dyn_cast<CustomDirective>(el))
+    return genCustomPrinter(custom, ctx, os);
   if (auto *var = dyn_cast<ParameterElement>(el))
     return genVariablePrinter(var, ctx, os);
   if (auto *optional = dyn_cast<OptionalElement>(el))
@@ -746,6 +795,21 @@ void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
       });
 }
 
+void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
+                                 MethodBody &os) {
+  os << tgfmt("print$0($_printer", &ctx, el->getName());
+  os.indent();
+  for (FormatElement *arg : el->getArguments()) {
+    FormatElement *param = arg;
+    if (auto *ref = dyn_cast<RefDirective>(arg))
+      param = ref->getArg();
+    os << ",\n"
+       << getParameterAccessorName(cast<ParameterElement>(param)->getName())
+       << "()";
+  }
+  os.unindent() << ");\n";
+}
+
 void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
                                         MethodBody &os) {
   FormatElement *anchor = el->getAnchor();
@@ -805,9 +869,7 @@ class DefFormatParser : public FormatParser {
   /// Verify the elements of a custom directive.
   LogicalResult
   verifyCustomDirectiveArguments(SMLoc loc,
-                                 ArrayRef<FormatElement *> arguments) override {
-    return emitError(loc, "'custom' not supported (yet)");
-  }
+                                 ArrayRef<FormatElement *> arguments) override;
   /// Verify the elements of an optional group.
   LogicalResult
   verifyOptionalGroupElements(SMLoc loc, ArrayRef<FormatElement *> elements,
@@ -822,11 +884,13 @@ class DefFormatParser : public FormatParser {
 
 private:
   /// Parse a `params` directive.
-  FailureOr<FormatElement *> parseParamsDirective(SMLoc loc);
+  FailureOr<FormatElement *> parseParamsDirective(SMLoc loc, Context ctx);
   /// Parse a `qualified` directive.
   FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc, Context ctx);
   /// Parse a `struct` directive.
-  FailureOr<FormatElement *> parseStructDirective(SMLoc loc);
+  FailureOr<FormatElement *> parseStructDirective(SMLoc loc, Context ctx);
+  /// Parse a `ref` directive.
+  FailureOr<FormatElement *> parseRefDirective(SMLoc loc, Context ctx);
 
   /// Attribute or type tablegen def.
   const AttrOrTypeDef &def;
@@ -862,6 +926,12 @@ LogicalResult DefFormatParser::verify(SMLoc loc,
   return success();
 }
 
+LogicalResult DefFormatParser::verifyCustomDirectiveArguments(
+    SMLoc loc, ArrayRef<FormatElement *> arguments) {
+  // Arguments are fully verified by the parser context.
+  return success();
+}
+
 LogicalResult
 DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
                                              ArrayRef<FormatElement *> elements,
@@ -915,9 +985,18 @@ DefFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
                      def.getName() + " has no parameter named '" + name + "'");
   }
   auto idx = std::distance(params.begin(), it);
-  if (seenParams.test(idx))
-    return emitError(loc, "duplicate parameter '" + name + "'");
-  seenParams.set(idx);
+
+  if (ctx != RefDirectiveContext) {
+    // Check that the variable has not already been bound.
+    if (seenParams.test(idx))
+      return emitError(loc, "duplicate parameter '" + name + "'");
+    seenParams.set(idx);
+
+    // Otherwise, to be referenced, a variable must have been bound.
+  } else if (!seenParams.test(idx)) {
+    return emitError(loc, "parameter '" + name +
+                              "' must be bound before it is referenced");
+  }
 
   return create<ParameterElement>(*it);
 }
@@ -930,14 +1009,13 @@ DefFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
   case FormatToken::kw_qualified:
     return parseQualifiedDirective(loc, ctx);
   case FormatToken::kw_params:
-    return parseParamsDirective(loc);
+    return parseParamsDirective(loc, ctx);
   case FormatToken::kw_struct:
-    if (ctx != TopLevelContext) {
-      return emitError(
-          loc,
-          "`struct` may only be used in the top-level section of the format");
-    }
-    return parseStructDirective(loc);
+    return parseStructDirective(loc, ctx);
+  case FormatToken::kw_ref:
+    return parseRefDirective(loc, ctx);
+  case FormatToken::kw_custom:
+    return parseCustomDirective(loc, ctx);
 
   default:
     return emitError(loc, "unsupported directive kind");
@@ -961,10 +1039,18 @@ DefFormatParser::parseQualifiedDirective(SMLoc loc, Context ctx) {
   return var;
 }
 
-FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc) {
-  // Collect all of the attribute's or type's parameters.
+FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc,
+                                                                 Context ctx) {
+  // It doesn't make sense to allow references to all parameters in a custom
+  // directive because parameters are the only things that can be bound.
+  if (ctx != TopLevelContext && ctx != StructDirectiveContext) {
+    return emitError(loc, "`params` can only be used at the top-level context "
+                          "or within a `struct` directive");
+  }
+
+  // Collect all of the attribute's or type's parameters and ensure that none of
+  // the parameters have already been captured.
   std::vector<ParameterElement *> vars;
-  // Ensure that none of the parameters have already been captured.
   for (const auto &it : llvm::enumerate(def.getParameters())) {
     if (seenParams.test(it.index())) {
       return emitError(loc, "`params` captures duplicate parameter: " +
@@ -976,7 +1062,11 @@ FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc) {
   return create<ParamsDirective>(std::move(vars));
 }
 
-FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc) {
+FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc,
+                                                                 Context ctx) {
+  if (ctx != TopLevelContext)
+    return emitError(loc, "`struct` can only be used at the top-level context");
+
   if (failed(parseToken(FormatToken::l_paren,
                         "expected '(' before `struct` argument list")))
     return failure();
@@ -1012,6 +1102,22 @@ FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc) {
   return create<StructDirective>(std::move(vars));
 }
 
+FailureOr<FormatElement *> DefFormatParser::parseRefDirective(SMLoc loc,
+                                                              Context ctx) {
+  if (ctx != CustomDirectiveContext)
+    return emitError(loc, "`ref` is only allowed inside custom directives");
+
+  // Parse the child parameter element.
+  FailureOr<FormatElement *> child;
+  if (failed(parseToken(FormatToken::l_paren, "expected '('")) ||
+      failed(child = parseElement(RefDirectiveContext)) ||
+      failed(parseToken(FormatToken::r_paren, "expeced ')'")))
+    return failure();
+
+  // Only parameter elements are allowed to be parsed under a `ref` directive.
+  return create<RefDirective>(*child);
+}
+
 //===----------------------------------------------------------------------===//
 // Interface
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h
index 741e2716f0388..f180f2da48e8d 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.h
+++ b/mlir/tools/mlir-tblgen/FormatGen.h
@@ -338,6 +338,22 @@ class CustomDirective : public DirectiveElementBase<DirectiveElement::Custom> {
   std::vector<FormatElement *> arguments;
 };
 
+/// This class represents a reference directive. This directive can be used to
+/// reference but not bind a previously bound variable or format object. Its
+/// current only use is to pass variables as arguments to the custom directive.
+class RefDirective : public DirectiveElementBase<DirectiveElement::Ref> {
+public:
+  /// Create a reference directive with the single referenced child.
+  RefDirective(FormatElement *arg) : arg(arg) {}
+
+  /// Get the reference argument.
+  FormatElement *getArg() const { return arg; }
+
+private:
+  /// The referenced argument.
+  FormatElement *arg;
+};
+
 /// This class represents a group of elements that are optionally emitted based
 /// on an optional variable "anchor" and a group of elements that are emitted
 /// when the anchor element is not present.

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 0d970d82aa3f3..044f1b01c3bd3 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -153,18 +153,6 @@ class FunctionalTypeDirective
   FormatElement *inputs, *results;
 };
 
-/// This class represents the `ref` directive.
-class RefDirective : public DirectiveElementBase<DirectiveElement::Ref> {
-public:
-  RefDirective(FormatElement *arg) : arg(arg) {}
-
-  FormatElement *getArg() const { return arg; }
-
-private:
-  /// The argument that is used to format the directive.
-  FormatElement *arg;
-};
-
 /// This class represents the `type` directive.
 class TypeDirective : public DirectiveElementBase<DirectiveElement::Type> {
 public:


        


More information about the Mlir-commits mailing list