[Mlir-commits] [mlir] 6e3292b - [mlir][OpFormatGen] Refactor `type_ref` into a more general `ref` directive

River Riddle llvmlistbot at llvm.org
Tue Feb 9 14:41:51 PST 2021


Author: River Riddle
Date: 2021-02-09T14:33:48-08:00
New Revision: 6e3292b0b718a4fec524796e3899a4df5e7ccfb7

URL: https://github.com/llvm/llvm-project/commit/6e3292b0b718a4fec524796e3899a4df5e7ccfb7
DIFF: https://github.com/llvm/llvm-project/commit/6e3292b0b718a4fec524796e3899a4df5e7ccfb7.diff

LOG: [mlir][OpFormatGen] Refactor `type_ref` into a more general `ref` directive

This allows for referencing nearly every component of an operation from within a custom directive.

It also fixes a bug with the current type_ref implementation, PR48478

Differential Revision: https://reviews.llvm.org/D96189

Added: 
    

Modified: 
    mlir/docs/OpDefinitions.md
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/op-format-spec.td
    mlir/test/mlir-tblgen/op-format.mlir
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index acc45a5335b0..07d79e0da918 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -613,6 +613,15 @@ The available directives are as follows:
 
     -   Represents all of the operands of an operation.
 
+*   `ref` ( input )
+
+    -   Represents a reference to a variable or directive that must have
+        already been resolved, and that may be used as a parameter to a
+        `custom` directive.
+    -   Used to pass previously parsed entities to custom directives.
+    -   The input may be any directive or variable, aside from `functional-type`
+        and `custom`.
+
 *   `regions`
 
     -   Represents all of the regions of an operation.
@@ -631,14 +640,6 @@ The available directives are as follows:
     -   `input` must be either an operand or result [variable](#variables), the
         `operands` directive, or the `results` directive.
 
-*   `type_ref` ( input )
-
-    -   Represents a reference to the type of the given input that must have
-        already been resolved.
-    -   `input` must be either an operand or result [variable](#variables), the
-        `operands` directive, or the `results` directive.
-    -   Used to pass previously parsed types to custom directives.
-
 #### Literals
 
 A literal is either a keyword or punctuation surrounded by \`\`.
@@ -716,6 +717,10 @@ declarative parameter to `parse` method argument is detailed below:
     -   Single: `OpAsmParser::OperandType &`
     -   Optional: `Optional<OpAsmParser::OperandType> &`
     -   Variadic: `SmallVectorImpl<OpAsmParser::OperandType> &`
+*   Ref Directives
+    -   A reference directive is passed to the parser using the same mapping as
+        the input operand. For example, a single region would be passed as a
+        `Region &`.
 *   Region Variables
     -   Single: `Region &`
     -   Variadic: `SmallVectorImpl<std::unique_ptr<Region>> &`
@@ -726,10 +731,6 @@ declarative parameter to `parse` method argument is detailed below:
     -   Single: `Type &`
     -   Optional: `Type &`
     -   Variadic: `SmallVectorImpl<Type> &`
-*   TypeRef Directives
-    -   Single: `Type`
-    -   Optional: `Type`
-    -   Variadic: `const SmallVectorImpl<Type> &`
 *   `attr-dict` Directive: `NamedAttrList &`
 
 When a variable is optional, the value should only be specified if the variable
@@ -748,6 +749,10 @@ declarative parameter to `print` method argument is detailed below:
     -   Single: `Value`
     -   Optional: `Value`
     -   Variadic: `OperandRange`
+*   Ref Directives
+    -   A reference directive is passed to the printer using the same mapping as
+        the input operand. For example, a single region would be passed as a
+        `Region &`.
 *   Region Variables
     -   Single: `Region &`
     -   Variadic: `MutableArrayRef<Region>`
@@ -758,10 +763,6 @@ declarative parameter to `print` method argument is detailed below:
     -   Single: `Type`
     -   Optional: `Type`
     -   Variadic: `TypeRange`
-*   TypeRef Directives
-    -   Single: `Type`
-    -   Optional: `Type`
-    -   Variadic: `TypeRange`
 *   `attr-dict` Directive: `DictionaryAttr`
 
 When a variable is optional, the provided value may be null.

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index b13f4f44f1bd..3ed13bacf66a 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -352,6 +352,14 @@ static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
                                                 NamedAttrList &attrs) {
   return parser.parseOptionalAttrDict(attrs);
 }
+static ParseResult parseCustomDirectiveOptionalOperandRef(
+    OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) {
+  int64_t operandCount = 0; // The printer emits 1 if present, 0 otherwise.
+  if (parser.parseInteger(operandCount))
+    return failure();
+  bool expectsOperand = operandCount != 0; // Nonzero means present.
+  return success(expectsOperand == optOperand.hasValue()); // Must agree.
+}
 
 //===----------------------------------------------------------------------===//
 // Printing
@@ -417,6 +425,13 @@ static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
                                          DictionaryAttr attrs) {
   printer.printOptionalAttrDict(attrs.getValue());
 }
+
+static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
+                                                   Operation *op,
+                                                   Value optOperand) {
+  printer << (!optOperand ? "0" : "1"); // Operand count: 0 or 1.
+}
+
 //===----------------------------------------------------------------------===//
 // Test IsolatedRegionOp - parse passthrough region arguments.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index f5df4ac62df2..44df4ab8e10b 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1698,12 +1698,22 @@ def FormatCustomDirectiveResultsWithTypeRefs
       type($result), type($optResult), type($varResults)
     )
     custom<CustomDirectiveWithTypeRefs>(
-      type_ref($result), type_ref($optResult), type_ref($varResults)
+      ref(type($result)), ref(type($optResult)), ref(type($varResults))
     )
     attr-dict
   }];
 }
 
+def FormatCustomDirectiveWithOptionalOperandRef
+    : TEST_Op<"format_custom_directive_with_optional_operand_ref"> {
+  let arguments = (ins Optional<I64>:$optOperand);
+  let assemblyFormat = [{
+    ($optOperand^)? `:`
+    custom<CustomDirectiveOptionalOperandRef>(ref($optOperand))
+    attr-dict
+  }];
+}
+
 def FormatCustomDirectiveSuccessors
     : TEST_Op<"format_custom_directive_successors", [Terminator]> {
   let successors = (successor AnySuccessor:$successor,

diff  --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
index 652bbd08679d..4f5ca63c4e72 100644
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -7,8 +7,8 @@ include "mlir/IR/OpBase.td"
 def TestDialect : Dialect {
   let name = "test";
 }
-class TestFormat_Op<string name, string fmt, list<OpTrait> traits = []>
-    : Op<TestDialect, name, traits> {
+class TestFormat_Op<string fmt, list<OpTrait> traits = []>
+    : Op<TestDialect, "format_op", traits> {
   let assemblyFormat = fmt;
 }
 
@@ -20,25 +20,25 @@ class TestFormat_Op<string name, string fmt, list<OpTrait> traits = []>
 // attr-dict
 
 // CHECK: error: 'attr-dict' directive not found
-def DirectiveAttrDictInvalidA : TestFormat_Op<"attrdict_invalid_a", [{
+def DirectiveAttrDictInvalidA : TestFormat_Op<[{
 }]>;
 // CHECK: error: 'attr-dict' directive has already been seen
-def DirectiveAttrDictInvalidB : TestFormat_Op<"attrdict_invalid_b", [{
+def DirectiveAttrDictInvalidB : TestFormat_Op<[{
   attr-dict attr-dict
 }]>;
 // CHECK: error: 'attr-dict' directive has already been seen
-def DirectiveAttrDictInvalidC : TestFormat_Op<"attrdict_invalid_c", [{
+def DirectiveAttrDictInvalidC : TestFormat_Op<[{
   attr-dict attr-dict-with-keyword
 }]>;
 // CHECK: error: 'attr-dict' directive can only be used as a top-level directive
-def DirectiveAttrDictInvalidD : TestFormat_Op<"attrdict_invalid_d", [{
+def DirectiveAttrDictInvalidD : TestFormat_Op<[{
   type(attr-dict)
 }]>;
 // CHECK-NOT: error
-def DirectiveAttrDictValidA : TestFormat_Op<"attrdict_valid_a", [{
+def DirectiveAttrDictValidA : TestFormat_Op<[{
   attr-dict
 }]>;
-def DirectiveAttrDictValidB : TestFormat_Op<"attrdict_valid_b", [{
+def DirectiveAttrDictValidB : TestFormat_Op<[{
   attr-dict-with-keyword
 }]>;
 
@@ -46,42 +46,42 @@ def DirectiveAttrDictValidB : TestFormat_Op<"attrdict_valid_b", [{
 // custom
 
 // CHECK: error: expected '<' before custom directive name
-def DirectiveCustomInvalidA : TestFormat_Op<"custom_invalid_a", [{
+def DirectiveCustomInvalidA : TestFormat_Op<[{
   custom(
 }]>;
 // CHECK: error: expected custom directive name identifier
-def DirectiveCustomInvalidB : TestFormat_Op<"custom_invalid_b", [{
+def DirectiveCustomInvalidB : TestFormat_Op<[{
   custom<>
 }]>;
 // CHECK: error: expected '>' after custom directive name
-def DirectiveCustomInvalidC : TestFormat_Op<"custom_invalid_c", [{
+def DirectiveCustomInvalidC : TestFormat_Op<[{
   custom<MyDirective(
 }]>;
 // CHECK: error: expected '(' before custom directive parameters
-def DirectiveCustomInvalidD : TestFormat_Op<"custom_invalid_d", [{
+def DirectiveCustomInvalidD : TestFormat_Op<[{
   custom<MyDirective>)
 }]>;
 // CHECK: error: only variables and types may be used as parameters to a custom directive
-def DirectiveCustomInvalidE : TestFormat_Op<"custom_invalid_e", [{
+def DirectiveCustomInvalidE : TestFormat_Op<[{
   custom<MyDirective>(operands)
 }]>;
 // CHECK: error: expected ')' after custom directive parameters
-def DirectiveCustomInvalidF : TestFormat_Op<"custom_invalid_f", [{
+def DirectiveCustomInvalidF : TestFormat_Op<[{
   custom<MyDirective>($operand<
 }]>, Arguments<(ins I64:$operand)>;
 // CHECK: error: type directives within a custom directive may only refer to variables
-def DirectiveCustomInvalidH : TestFormat_Op<"custom_invalid_h", [{
+def DirectiveCustomInvalidH : TestFormat_Op<[{
   custom<MyDirective>(type(operands))
 }]>;
 
 // CHECK-NOT: error
-def DirectiveCustomValidA : TestFormat_Op<"custom_valid_a", [{
+def DirectiveCustomValidA : TestFormat_Op<[{
   custom<MyDirective>($operand) attr-dict
 }]>, Arguments<(ins Optional<I64>:$operand)>;
-def DirectiveCustomValidB : TestFormat_Op<"custom_valid_b", [{
+def DirectiveCustomValidB : TestFormat_Op<[{
   custom<MyDirective>($operand, type($operand), type($result)) attr-dict
 }]>, Arguments<(ins I64:$operand)>, Results<(outs I64:$result)>;
-def DirectiveCustomValidC : TestFormat_Op<"custom_valid_c", [{
+def DirectiveCustomValidC : TestFormat_Op<[{
   custom<MyDirective>($attr) attr-dict
 }]>, Arguments<(ins I64Attr:$attr)>;
 
@@ -89,31 +89,31 @@ def DirectiveCustomValidC : TestFormat_Op<"custom_valid_c", [{
 // functional-type
 
 // CHECK: error: 'functional-type' is only valid as a top-level directive
-def DirectiveFunctionalTypeInvalidA : TestFormat_Op<"functype_invalid_a", [{
+def DirectiveFunctionalTypeInvalidA : TestFormat_Op<[{
   functional-type(functional-type)
 }]>;
 // CHECK: error: expected '(' before argument list
-def DirectiveFunctionalTypeInvalidB : TestFormat_Op<"functype_invalid_b", [{
+def DirectiveFunctionalTypeInvalidB : TestFormat_Op<[{
   functional-type
 }]>;
 // CHECK: error: expected directive, literal, variable, or optional group
-def DirectiveFunctionalTypeInvalidC : TestFormat_Op<"functype_invalid_c", [{
+def DirectiveFunctionalTypeInvalidC : TestFormat_Op<[{
   functional-type(
 }]>;
 // CHECK: error: expected ',' after inputs argument
-def DirectiveFunctionalTypeInvalidD : TestFormat_Op<"functype_invalid_d", [{
+def DirectiveFunctionalTypeInvalidD : TestFormat_Op<[{
   functional-type(operands
 }]>;
 // CHECK: error: expected directive, literal, variable, or optional group
-def DirectiveFunctionalTypeInvalidE : TestFormat_Op<"functype_invalid_e", [{
+def DirectiveFunctionalTypeInvalidE : TestFormat_Op<[{
   functional-type(operands,
 }]>;
 // CHECK: error: expected ')' after argument list
-def DirectiveFunctionalTypeInvalidF : TestFormat_Op<"functype_invalid_f", [{
+def DirectiveFunctionalTypeInvalidF : TestFormat_Op<[{
   functional-type(operands, results
 }]>;
 // CHECK-NOT: error
-def DirectiveFunctionalTypeValid : TestFormat_Op<"functype_invalid_a", [{
+def DirectiveFunctionalTypeValid : TestFormat_Op<[{
   functional-type(operands, results) attr-dict
 }]>;
 
@@ -121,45 +121,128 @@ def DirectiveFunctionalTypeValid : TestFormat_Op<"functype_invalid_a", [{
 // operands
 
 // CHECK: error: 'operands' directive creates overlap in format
-def DirectiveOperandsInvalidA : TestFormat_Op<"operands_invalid_a", [{
+def DirectiveOperandsInvalidA : TestFormat_Op<[{
   operands operands
 }]>;
 // CHECK: error: 'operands' directive creates overlap in format
-def DirectiveOperandsInvalidB : TestFormat_Op<"operands_invalid_b", [{
+def DirectiveOperandsInvalidB : TestFormat_Op<[{
   $operand operands
 }]>, Arguments<(ins I64:$operand)>;
 // CHECK-NOT: error:
-def DirectiveOperandsValid : TestFormat_Op<"operands_valid", [{
+def DirectiveOperandsValid : TestFormat_Op<[{
   operands attr-dict
 }]>;
 
+//===----------------------------------------------------------------------===//
+// ref
+
+// CHECK: error: 'ref' is only valid within a `custom` directive
+def DirectiveRefInvalidA : TestFormat_Op<[{
+  ref(type($operand))
+}]>, Arguments<(ins I64:$operand)>;
+
+// CHECK: error: 'ref' of 'type($operand)' is not bound by a prior 'type' directive
+def DirectiveRefInvalidB : TestFormat_Op<[{
+  custom<Foo>(ref(type($operand)))
+}]>, Arguments<(ins I64:$operand)>;
+
+// CHECK: error: 'ref' of 'type(operands)' is not bound by a prior 'type' directive
+def DirectiveRefInvalidC : TestFormat_Op<[{
+  custom<Foo>(ref(type(operands)))
+}]>;
+
+// CHECK: error: 'ref' of 'type($result)' is not bound by a prior 'type' directive
+def DirectiveRefInvalidD : TestFormat_Op<[{
+  custom<Foo>(ref(type($result)))
+}]>, Results<(outs I64:$result)>;
+
+// CHECK: error: 'ref' of 'type(results)' is not bound by a prior 'type' directive
+def DirectiveRefInvalidE : TestFormat_Op<[{
+  custom<Foo>(ref(type(results)))
+}]>;
+
+// CHECK: error: 'ref' of 'successors' is not bound by a prior 'successors' directive
+def DirectiveRefInvalidF : TestFormat_Op<[{
+  custom<Foo>(ref(successors))
+}]>;
+
+// CHECK: error: 'ref' of 'regions' is not bound by a prior 'regions' directive
+def DirectiveRefInvalidG : TestFormat_Op<[{
+  custom<Foo>(ref(regions))
+}]>;
+
+// CHECK: error: expected '(' before argument list
+def DirectiveRefInvalidH : TestFormat_Op<[{
+  custom<Foo>(ref)
+}]>;
+
+// CHECK: error: expected ')' after argument list
+def DirectiveRefInvalidI : TestFormat_Op<[{
+  operands custom<Foo>(ref(operands(
+}]>;
+
+// CHECK: error: 'ref' of 'operands' is not bound by a prior 'operands' directive
+def DirectiveRefInvalidJ : TestFormat_Op<[{
+  custom<Foo>(ref(operands))
+}]>;
+
+// CHECK: error: 'ref' of 'attr-dict' is not bound by a prior 'attr-dict' directive
+def DirectiveRefInvalidK : TestFormat_Op<[{
+  custom<Foo>(ref(attr-dict))
+}]>;
+
+// CHECK: error: successor 'successor' must be bound before it is referenced
+def DirectiveRefInvalidL : TestFormat_Op<[{
+  custom<Foo>(ref($successor))
+}]> {
+  let successors = (successor AnySuccessor:$successor);
+}
+
+// CHECK: error: region 'region' must be bound before it is referenced
+def DirectiveRefInvalidM : TestFormat_Op<[{
+  custom<Foo>(ref($region))
+}]> {
+  let regions = (region AnyRegion:$region);
+}
+
+// CHECK: error: attribute 'attr' must be bound before it is referenced
+def DirectiveRefInvalidN : TestFormat_Op<[{
+  custom<Foo>(ref($attr))
+}]>, Arguments<(ins I64Attr:$attr)>;
+
+
+// CHECK: error: operand 'operand' must be bound before it is referenced
+def DirectiveRefInvalidO : TestFormat_Op<[{
+  custom<Foo>(ref($operand))
+}]>, Arguments<(ins I64:$operand)>;
+
 //===----------------------------------------------------------------------===//
 // regions
 
 // CHECK: error: 'regions' directive creates overlap in format
-def DirectiveRegionsInvalidA : TestFormat_Op<"regions_invalid_a", [{
+def DirectiveRegionsInvalidA : TestFormat_Op<[{
   regions regions attr-dict
 }]>;
 // CHECK: error: 'regions' directive creates overlap in format
-def DirectiveRegionsInvalidB : TestFormat_Op<"regions_invalid_b", [{
+def DirectiveRegionsInvalidB : TestFormat_Op<[{
   $region regions attr-dict
 }]> {
   let regions = (region AnyRegion:$region);
 }
 // CHECK: error: 'regions' is only valid as a top-level directive
-def DirectiveRegionsInvalidC : TestFormat_Op<"regions_invalid_c", [{
+def DirectiveRegionsInvalidC : TestFormat_Op<[{
   type(regions)
 }]>;
 // CHECK-NOT: error:
-def DirectiveRegionsValid : TestFormat_Op<"regions_valid", [{
+def DirectiveRegionsValid : TestFormat_Op<[{
   regions attr-dict
 }]>;
 
 //===----------------------------------------------------------------------===//
 // results
 
-// CHECK: error: 'results' directive can not be used as a top-level directive
-def DirectiveResultsInvalidA : TestFormat_Op<"results_invalid_a", [{
+// CHECK: error: 'results' directive can can only be used as a child to a 'type' directive
+def DirectiveResultsInvalidA : TestFormat_Op<[{
   results
 }]>;
 
@@ -167,7 +250,7 @@ def DirectiveResultsInvalidA : TestFormat_Op<"results_invalid_a", [{
 // successors
 
 // CHECK: error: 'successors' is only valid as a top-level directive
-def DirectiveSuccessorsInvalidA : TestFormat_Op<"successors_invalid_a", [{
+def DirectiveSuccessorsInvalidA : TestFormat_Op<[{
   type(successors)
 }]>;
 
@@ -175,140 +258,78 @@ def DirectiveSuccessorsInvalidA : TestFormat_Op<"successors_invalid_a", [{
 // type
 
 // CHECK: error: expected '(' before argument list
-def DirectiveTypeInvalidA : TestFormat_Op<"type_invalid_a", [{
+def DirectiveTypeInvalidA : TestFormat_Op<[{
   type
 }]>;
 // CHECK: error: expected directive, literal, variable, or optional group
-def DirectiveTypeInvalidB : TestFormat_Op<"type_invalid_b", [{
+def DirectiveTypeInvalidB : TestFormat_Op<[{
   type(
 }]>;
 // CHECK: error: expected ')' after argument list
-def DirectiveTypeInvalidC : TestFormat_Op<"type_invalid_c", [{
+def DirectiveTypeInvalidC : TestFormat_Op<[{
   type(operands
 }]>;
 // CHECK-NOT: error:
-def DirectiveTypeValid : TestFormat_Op<"type_valid", [{
+def DirectiveTypeValid : TestFormat_Op<[{
   type(operands) attr-dict
 }]>;
 
 //===----------------------------------------------------------------------===//
 // functional-type/type operands
 
-// CHECK: error: 'type' directive operand expects variable or directive operand
-def DirectiveTypeZOperandInvalidA : TestFormat_Op<"type_operand_invalid_a", [{
+// CHECK: error: literals may only be used in a top-level section of the format
+def DirectiveTypeZOperandInvalidA : TestFormat_Op<[{
   type(`literal`)
 }]>;
 // CHECK: error: 'operands' 'type' is already bound
-def DirectiveTypeZOperandInvalidB : TestFormat_Op<"type_operand_invalid_b", [{
+def DirectiveTypeZOperandInvalidB : TestFormat_Op<[{
   type(operands) type(operands)
 }]>;
 // CHECK: error: 'operands' 'type' is already bound
-def DirectiveTypeZOperandInvalidC : TestFormat_Op<"type_operand_invalid_c", [{
+def DirectiveTypeZOperandInvalidC : TestFormat_Op<[{
   type($operand) type(operands)
 }]>, Arguments<(ins I64:$operand)>;
 // CHECK: error: 'type' of 'operand' is already bound
-def DirectiveTypeZOperandInvalidD : TestFormat_Op<"type_operand_invalid_d", [{
+def DirectiveTypeZOperandInvalidD : TestFormat_Op<[{
   type(operands) type($operand)
 }]>, Arguments<(ins I64:$operand)>;
 // CHECK: error: 'type' of 'operand' is already bound
-def DirectiveTypeZOperandInvalidE : TestFormat_Op<"type_operand_invalid_e", [{
+def DirectiveTypeZOperandInvalidE : TestFormat_Op<[{
   type($operand) type($operand)
 }]>, Arguments<(ins I64:$operand)>;
 // CHECK: error: 'results' 'type' is already bound
-def DirectiveTypeZOperandInvalidF : TestFormat_Op<"type_operand_invalid_f", [{
+def DirectiveTypeZOperandInvalidF : TestFormat_Op<[{
   type(results) type(results)
 }]>;
 // CHECK: error: 'results' 'type' is already bound
-def DirectiveTypeZOperandInvalidG : TestFormat_Op<"type_operand_invalid_g", [{
+def DirectiveTypeZOperandInvalidG : TestFormat_Op<[{
   type($result) type(results)
 }]>, Results<(outs I64:$result)>;
 // CHECK: error: 'type' of 'result' is already bound
-def DirectiveTypeZOperandInvalidH : TestFormat_Op<"type_operand_invalid_h", [{
+def DirectiveTypeZOperandInvalidH : TestFormat_Op<[{
   type(results) type($result)
 }]>, Results<(outs I64:$result)>;
 // CHECK: error: 'type' of 'result' is already bound
-def DirectiveTypeZOperandInvalidI : TestFormat_Op<"type_operand_invalid_i", [{
+def DirectiveTypeZOperandInvalidI : TestFormat_Op<[{
   type($result) type($result)
 }]>, Results<(outs I64:$result)>;
 
-//===----------------------------------------------------------------------===//
-// type_ref
-
-// CHECK: error: 'type_ref' of 'operand' is not bound by a prior 'type' directive
-def DirectiveTypeZZTypeRefOperandInvalidC : TestFormat_Op<"type_ref_operand_invalid_c", [{
-  type_ref($operand) type(operands)
-}]>, Arguments<(ins I64:$operand)>;
-// CHECK: error: 'operands' 'type_ref' is not bound by a prior 'type' directive
-def DirectiveTypeZZTypeRefOperandInvalidD : TestFormat_Op<"type_ref_operand_invalid_d", [{
-  type_ref(operands) type($operand)
-}]>, Arguments<(ins I64:$operand)>;
-// CHECK: error: 'type_ref' of 'operand' is not bound by a prior 'type' directive
-def DirectiveTypeZZTypeRefOperandInvalidE : TestFormat_Op<"type_ref_operand_invalid_e", [{
-  type_ref($operand) type($operand)
-}]>, Arguments<(ins I64:$operand)>;
-// CHECK: error: 'type_ref' of 'result' is not bound by a prior 'type' directive
-def DirectiveTypeZZTypeRefOperandInvalidG : TestFormat_Op<"type_ref_operand_invalid_g", [{
-  type_ref($result) type(results)
-}]>, Results<(outs I64:$result)>;
-// CHECK: error: 'results' 'type_ref' is not bound by a prior 'type' directive
-def DirectiveTypeZZTypeRefOperandInvalidH : TestFormat_Op<"type_ref_operand_invalid_h", [{
-  type_ref(results) type($result)
-}]>, Results<(outs I64:$result)>;
-// CHECK: error: 'type_ref' of 'result' is not bound by a prior 'type' directive
-def DirectiveTypeZZTypeRefOperandInvalidI : TestFormat_Op<"type_ref_operand_invalid_i", [{
-  type_ref($result) type($result)
-}]>, Results<(outs I64:$result)>;
-
-// CHECK-NOT: error
-def DirectiveTypeZZTypeRefOperandB : TestFormat_Op<"type_ref_operand_valid_b", [{
-  type_ref(operands) attr-dict
-}]>;
-// CHECK-NOT: error
-def DirectiveTypeZZTypeRefOperandD : TestFormat_Op<"type_ref_operand_valid_d", [{
-  type(operands) type_ref($operand) attr-dict
-}]>, Arguments<(ins I64:$operand)>;
-// CHECK-NOT: error
-def DirectiveTypeZZTypeRefOperandE : TestFormat_Op<"type_ref_operand_valid_e", [{
-  type($operand) type_ref($operand) attr-dict
-}]>, Arguments<(ins I64:$operand)>;
-// CHECK-NOT: error
-def DirectiveTypeZZTypeRefOperandF : TestFormat_Op<"type_ref_operand_valid_f", [{
-  type(results) type_ref(results) attr-dict
-}]>;
-// CHECK-NOT: error
-def DirectiveTypeZZTypeRefOperandG : TestFormat_Op<"type_ref_operand_valid_g", [{
-  type($result) type_ref(results) attr-dict
-}]>, Results<(outs I64:$result)>;
-// CHECK-NOT: error
-def DirectiveTypeZZTypeRefOperandH : TestFormat_Op<"type_ref_operand_valid_h", [{
-  type(results) type_ref($result) attr-dict
-}]>, Results<(outs I64:$result)>;
-// CHECK-NOT: error
-def DirectiveTypeZZTypeRefOperandI : TestFormat_Op<"type_ref_operand_valid_i", [{
-  type($result) type_ref($result) attr-dict
-}]>, Results<(outs I64:$result)>;
-
-// CHECK-NOT: error:
-def DirectiveTypeZZZOperandValid : TestFormat_Op<"type_operand_valid", [{
-  type(operands) type(results) attr-dict
-}]>;
-
 //===----------------------------------------------------------------------===//
 // Literals
 //===----------------------------------------------------------------------===//
 
 // Test all of the valid literals.
 // CHECK: error: expected valid literal
-def LiteralInvalidA : TestFormat_Op<"literal_invalid_a", [{
+def LiteralInvalidA : TestFormat_Op<[{
   `1`
 }]>;
 // CHECK: error: unexpected end of file in literal
 // CHECK: error: expected directive, literal, variable, or optional group
-def LiteralInvalidB : TestFormat_Op<"literal_invalid_b", [{
+def LiteralInvalidB : TestFormat_Op<[{
   `
 }]>;
 // CHECK-NOT: error
-def LiteralValid : TestFormat_Op<"literal_valid", [{
+def LiteralValid : TestFormat_Op<[{
   `_` `:` `,` `=` `<` `>` `(` `)` `[` `]` `?` `+` `*` ` ` `` `->` `\n` `abc$._`
   attr-dict
 }]>;
@@ -318,60 +339,60 @@ def LiteralValid : TestFormat_Op<"literal_valid", [{
 //===----------------------------------------------------------------------===//
 
 // CHECK: error: optional groups can only be used as top-level elements
-def OptionalInvalidA : TestFormat_Op<"optional_invalid_a", [{
+def OptionalInvalidA : TestFormat_Op<[{
   type(($attr^)?) attr-dict
 }]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
 // CHECK: error: expected directive, literal, variable, or optional group
-def OptionalInvalidB : TestFormat_Op<"optional_invalid_b", [{
+def OptionalInvalidB : TestFormat_Op<[{
   () attr-dict
 }]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
 // CHECK: error: optional group specified no anchor element
-def OptionalInvalidC : TestFormat_Op<"optional_invalid_c", [{
+def OptionalInvalidC : TestFormat_Op<[{
   ($attr)? attr-dict
 }]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
 // CHECK: error: first parsable element of an operand group must be an attribute, literal, operand, or region
-def OptionalInvalidD : TestFormat_Op<"optional_invalid_d", [{
+def OptionalInvalidD : TestFormat_Op<[{
   (type($operand) $operand^)? attr-dict
 }]>, Arguments<(ins Optional<I64>:$operand)>;
 // CHECK: error: only literals, types, and variables can be used within an optional group
-def OptionalInvalidE : TestFormat_Op<"optional_invalid_e", [{
+def OptionalInvalidE : TestFormat_Op<[{
   (`,` $attr^ type(operands))? attr-dict
 }]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
 // CHECK: error: only one element can be marked as the anchor of an optional group
-def OptionalInvalidF : TestFormat_Op<"optional_invalid_f", [{
+def OptionalInvalidF : TestFormat_Op<[{
   ($attr^ $attr2^) attr-dict
 }]>, Arguments<(ins OptionalAttr<I64Attr>:$attr, OptionalAttr<I64Attr>:$attr2)>;
 // CHECK: error: only optional attributes can be used to anchor an optional group
-def OptionalInvalidG : TestFormat_Op<"optional_invalid_g", [{
+def OptionalInvalidG : TestFormat_Op<[{
   ($attr^) attr-dict
 }]>, Arguments<(ins I64Attr:$attr)>;
 // CHECK: error: only variable length operands can be used within an optional group
-def OptionalInvalidH : TestFormat_Op<"optional_invalid_h", [{
+def OptionalInvalidH : TestFormat_Op<[{
   ($arg^) attr-dict
 }]>, Arguments<(ins I64:$arg)>;
 // CHECK: error: only literals, types, and variables can be used within an optional group
-def OptionalInvalidI : TestFormat_Op<"optional_invalid_i", [{
+def OptionalInvalidI : TestFormat_Op<[{
   (functional-type($arg, results)^)? attr-dict
 }]>, Arguments<(ins Variadic<I64>:$arg)>;
 // CHECK: error: only literals, types, and variables can be used within an optional group
-def OptionalInvalidJ : TestFormat_Op<"optional_invalid_j", [{
+def OptionalInvalidJ : TestFormat_Op<[{
   (attr-dict)
 }]>;
 // CHECK: error: expected '?' after optional group
-def OptionalInvalidK : TestFormat_Op<"optional_invalid_k", [{
+def OptionalInvalidK : TestFormat_Op<[{
   ($arg^)
 }]>, Arguments<(ins Variadic<I64>:$arg)>;
 // CHECK: error: only variables and types can be used to anchor an optional group
-def OptionalInvalidL : TestFormat_Op<"optional_invalid_l", [{
+def OptionalInvalidL : TestFormat_Op<[{
   (custom<MyDirective>($arg)^)?
 }]>, Arguments<(ins I64:$arg)>;
 // CHECK: error: only variables and types can be used to anchor an optional group
-def OptionalInvalidM : TestFormat_Op<"optional_invalid_m", [{
+def OptionalInvalidM : TestFormat_Op<[{
   (` `^)?
 }]>, Arguments<(ins)>;
 
 // CHECK-NOT: error
-def OptionalValidA : TestFormat_Op<"optional_valid_a", [{
+def OptionalValidA : TestFormat_Op<[{
   (` ` `` $arg^)?
 }]>;
 
@@ -380,78 +401,78 @@ def OptionalValidA : TestFormat_Op<"optional_valid_a", [{
 //===----------------------------------------------------------------------===//
 
 // CHECK: error: expected variable to refer to an argument, region, result, or successor
-def VariableInvalidA : TestFormat_Op<"variable_invalid_a", [{
+def VariableInvalidA : TestFormat_Op<[{
   $unknown_arg attr-dict
 }]>;
 // CHECK: error: attribute 'attr' is already bound
-def VariableInvalidB : TestFormat_Op<"variable_invalid_b", [{
+def VariableInvalidB : TestFormat_Op<[{
   $attr $attr attr-dict
 }]>, Arguments<(ins I64Attr:$attr)>;
 // CHECK: error: operand 'operand' is already bound
-def VariableInvalidC : TestFormat_Op<"variable_invalid_c", [{
+def VariableInvalidC : TestFormat_Op<[{
   $operand $operand attr-dict
 }]>, Arguments<(ins I64:$operand)>;
 // CHECK: error: operand 'operand' is already bound
-def VariableInvalidD : TestFormat_Op<"variable_invalid_d", [{
+def VariableInvalidD : TestFormat_Op<[{
   operands $operand attr-dict
 }]>, Arguments<(ins I64:$operand)>;
-// CHECK: error: results can not be used at the top level
-def VariableInvalidE : TestFormat_Op<"variable_invalid_e", [{
+// CHECK: error: result variables can can only be used as a child to a 'type' directive
+def VariableInvalidE : TestFormat_Op<[{
   $result attr-dict
 }]>, Results<(outs I64:$result)>;
 // CHECK: error: successor 'successor' is already bound
-def VariableInvalidF : TestFormat_Op<"variable_invalid_f", [{
+def VariableInvalidF : TestFormat_Op<[{
   $successor $successor attr-dict
 }]> {
   let successors = (successor AnySuccessor:$successor);
 }
 // CHECK: error: successor 'successor' is already bound
-def VariableInvalidG : TestFormat_Op<"variable_invalid_g", [{
+def VariableInvalidG : TestFormat_Op<[{
   successors $successor attr-dict
 }]> {
   let successors = (successor AnySuccessor:$successor);
 }
 // CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` which does not have a buildable type
-def VariableInvalidH : TestFormat_Op<"variable_invalid_h", [{
+def VariableInvalidH : TestFormat_Op<[{
   $attr `:` attr-dict
 }]>, Arguments<(ins ElementsAttr:$attr)>;
 // CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` which does not have a buildable type
-def VariableInvalidI : TestFormat_Op<"variable_invalid_i", [{
+def VariableInvalidI : TestFormat_Op<[{
   (`foo` $attr^)? `:` attr-dict
 }]>, Arguments<(ins OptionalAttr<ElementsAttr>:$attr)>;
 // CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` which does not have a buildable type
-def VariableInvalidJ : TestFormat_Op<"variable_invalid_j", [{
+def VariableInvalidJ : TestFormat_Op<[{
   $attr ` ` `:` attr-dict
 }]>, Arguments<(ins ElementsAttr:$attr)>;
 // CHECK: error: region 'region' is already bound
-def VariableInvalidK : TestFormat_Op<"variable_invalid_k", [{
+def VariableInvalidK : TestFormat_Op<[{
   $region $region attr-dict
 }]> {
   let regions = (region AnyRegion:$region);
 }
 // CHECK: error: region 'region' is already bound
-def VariableInvalidL : TestFormat_Op<"variable_invalid_l", [{
+def VariableInvalidL : TestFormat_Op<[{
   regions $region attr-dict
 }]> {
   let regions = (region AnyRegion:$region);
 }
 // CHECK: error: regions can only be used at the top level
-def VariableInvalidM : TestFormat_Op<"variable_invalid_m", [{
+def VariableInvalidM : TestFormat_Op<[{
   type($region)
 }]> {
   let regions = (region AnyRegion:$region);
 }
 // CHECK: error: region #0, named 'region', not found
-def VariableInvalidN : TestFormat_Op<"variable_invalid_n", [{
+def VariableInvalidN : TestFormat_Op<[{
   attr-dict
 }]> {
   let regions = (region AnyRegion:$region);
 }
 // CHECK-NOT: error:
-def VariableValidA : TestFormat_Op<"variable_valid_a", [{
+def VariableValidA : TestFormat_Op<[{
   $attr `:` attr-dict
 }]>, Arguments<(ins OptionalAttr<I1Attr>:$attr)>;
-def VariableValidB : TestFormat_Op<"variable_valid_b", [{
+def VariableValidB : TestFormat_Op<[{
   (`foo` $attr^)? `:` attr-dict
 }]>, Arguments<(ins OptionalAttr<I1Attr>:$attr)>;
 
@@ -461,75 +482,75 @@ def VariableValidB : TestFormat_Op<"variable_valid_b", [{
 
 // CHECK: error: type of result #0, named 'result', is not buildable and a buildable type cannot be inferred
 // CHECK: note: suggest adding a type constraint to the operation or adding a 'type($result)' directive to the custom assembly format
-def ZCoverageInvalidA : TestFormat_Op<"variable_invalid_a", [{
+def ZCoverageInvalidA : TestFormat_Op<[{
   attr-dict
 }]>, Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>;
 // CHECK: error: operand #0, named 'operand', not found
 // CHECK: note: suggest adding a '$operand' directive to the custom assembly format
-def ZCoverageInvalidB : TestFormat_Op<"variable_invalid_b", [{
+def ZCoverageInvalidB : TestFormat_Op<[{
   type($result) attr-dict
 }]>, Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>;
 // CHECK: error: type of operand #0, named 'operand', is not buildable and a buildable type cannot be inferred
 // CHECK: note: suggest adding a type constraint to the operation or adding a 'type($operand)' directive to the custom assembly format
-def ZCoverageInvalidC : TestFormat_Op<"variable_invalid_c", [{
+def ZCoverageInvalidC : TestFormat_Op<[{
   $operand type($result) attr-dict
 }]>, Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>;
 // CHECK: error: type of operand #0, named 'operand', is not buildable and a buildable type cannot be inferred
 // CHECK: note: suggest adding a type constraint to the operation or adding a 'type($operand)' directive to the custom assembly format
-def ZCoverageInvalidD : TestFormat_Op<"variable_invalid_d", [{
+def ZCoverageInvalidD : TestFormat_Op<[{
   operands attr-dict
 }]>, Arguments<(ins Variadic<I64>:$operand)>;
 // CHECK: error: type of result #0, named 'result', is not buildable and a buildable type cannot be inferred
 // CHECK: note: suggest adding a type constraint to the operation or adding a 'type($result)' directive to the custom assembly format
-def ZCoverageInvalidE : TestFormat_Op<"variable_invalid_e", [{
+def ZCoverageInvalidE : TestFormat_Op<[{
   attr-dict
 }]>, Results<(outs Variadic<I64>:$result)>;
 // CHECK: error: successor #0, named 'successor', not found
 // CHECK: note: suggest adding a '$successor' directive to the custom assembly format
-def ZCoverageInvalidF : TestFormat_Op<"variable_invalid_f", [{
+def ZCoverageInvalidF : TestFormat_Op<[{
 	 attr-dict
 }]> {
   let successors = (successor AnySuccessor:$successor);
 }
 // CHECK: error: type of operand #0, named 'operand', is not buildable and a buildable type cannot be inferred
 // CHECK: note: suggest adding a type constraint to the operation or adding a 'type($operand)' directive to the custom assembly format
-def ZCoverageInvalidG : TestFormat_Op<"variable_invalid_g", [{
+def ZCoverageInvalidG : TestFormat_Op<[{
   operands attr-dict
 }]>, Arguments<(ins Optional<I64>:$operand)>;
 // CHECK: error: type of result #0, named 'result', is not buildable and a buildable type cannot be inferred
 // CHECK: note: suggest adding a type constraint to the operation or adding a 'type($result)' directive to the custom assembly format
-def ZCoverageInvalidH : TestFormat_Op<"variable_invalid_h", [{
+def ZCoverageInvalidH : TestFormat_Op<[{
   attr-dict
 }]>, Results<(outs Optional<I64>:$result)>;
 
 // CHECK-NOT: error
-def ZCoverageValidA : TestFormat_Op<"variable_valid_a", [{
+def ZCoverageValidA : TestFormat_Op<[{
   $operand type($operand) type($result) attr-dict
 }]>, Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>;
-def ZCoverageValidB : TestFormat_Op<"variable_valid_b", [{
+def ZCoverageValidB : TestFormat_Op<[{
   $operand type(operands) type(results) attr-dict
 }]>, Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>;
-def ZCoverageValidC : TestFormat_Op<"variable_valid_c", [{
+def ZCoverageValidC : TestFormat_Op<[{
   operands functional-type(operands, results) attr-dict
 }]>, Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>;
 
 // Check that we can infer type equalities from certain traits.
-def ZCoverageValidD : TestFormat_Op<"variable_valid_d", [{
+def ZCoverageValidD : TestFormat_Op<[{
   operands type($result) attr-dict
 }], [SameOperandsAndResultType]>, Arguments<(ins AnyMemRef:$operand)>,
      Results<(outs AnyMemRef:$result)>;
-def ZCoverageValidE : TestFormat_Op<"variable_valid_e", [{
+def ZCoverageValidE : TestFormat_Op<[{
   $operand type($operand) attr-dict
 }], [SameOperandsAndResultType]>, Arguments<(ins AnyMemRef:$operand)>,
      Results<(outs AnyMemRef:$result)>;
-def ZCoverageValidF : TestFormat_Op<"variable_valid_f", [{
+def ZCoverageValidF : TestFormat_Op<[{
   operands type($other) attr-dict
 }], [SameTypeOperands]>, Arguments<(ins AnyMemRef:$operand, AnyMemRef:$other)>;
-def ZCoverageValidG : TestFormat_Op<"variable_valid_g", [{
+def ZCoverageValidG : TestFormat_Op<[{
   operands type($other) attr-dict
 }], [AllTypesMatch<["operand", "other"]>]>,
      Arguments<(ins AnyMemRef:$operand, AnyMemRef:$other)>;
-def ZCoverageValidH : TestFormat_Op<"variable_valid_h", [{
+def ZCoverageValidH : TestFormat_Op<[{
   operands type($result) attr-dict
 }], [AllTypesMatch<["operand", "result"]>]>,
      Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>;

diff  --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index b751b3f3d715..8043786faf08 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -291,6 +291,12 @@ test.format_custom_directive_results_with_type_refs : i64, i64 -> (i64) type_ref
 // CHECK: test.format_custom_directive_results_with_type_refs : i64 -> (i64) type_refs_capture : i64 -> (i64)
 test.format_custom_directive_results_with_type_refs : i64 -> (i64) type_refs_capture : i64 -> (i64)
 
+// CHECK: test.format_custom_directive_with_optional_operand_ref %[[I64]] : 1
+test.format_custom_directive_with_optional_operand_ref %i64 : 1
+
+// CHECK: test.format_custom_directive_with_optional_operand_ref : 0
+test.format_custom_directive_with_optional_operand_ref : 0
+
 func @foo() {
   // CHECK: test.format_custom_directive_successors ^bb1, ^bb2
   test.format_custom_directive_successors ^bb1, ^bb2

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 33c0c24f57c0..270db729ec55 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -58,11 +58,11 @@ class Element {
     CustomDirective,
     FunctionalTypeDirective,
     OperandsDirective,
+    RefDirective,
     RegionsDirective,
     ResultsDirective,
     SuccessorsDirective,
     TypeDirective,
-    TypeRefDirective,
 
     /// This element is a literal.
     Literal,
@@ -234,10 +234,10 @@ class FunctionalTypeDirective
   std::unique_ptr<Element> inputs, results;
 };
 
-/// This class represents the `type` directive.
-class TypeDirective : public DirectiveElement<Element::Kind::TypeDirective> {
+/// This class represents the `ref` directive.
+class RefDirective : public DirectiveElement<Element::Kind::RefDirective> {
 public:
-  TypeDirective(std::unique_ptr<Element> arg) : operand(std::move(arg)) {}
+  RefDirective(std::unique_ptr<Element> arg) : operand(std::move(arg)) {}
   Element *getOperand() const { return operand.get(); }
 
 private:
@@ -245,11 +245,10 @@ class TypeDirective : public DirectiveElement<Element::Kind::TypeDirective> {
   std::unique_ptr<Element> operand;
 };
 
-/// This class represents the `type_ref` directive.
-class TypeRefDirective
-    : public DirectiveElement<Element::Kind::TypeRefDirective> {
+/// This class represents the `type` directive.
+class TypeDirective : public DirectiveElement<Element::Kind::TypeDirective> {
 public:
-  TypeRefDirective(std::unique_ptr<Element> arg) : operand(std::move(arg)) {}
+  TypeDirective(std::unique_ptr<Element> arg) : operand(std::move(arg)) {}
   Element *getOperand() const { return operand.get(); }
 
 private:
@@ -873,19 +872,6 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
            << llvm::formatv(
                   "  ::llvm::ArrayRef<::mlir::Type> {0}Types({0}RawTypes);\n",
                   name);
-  } else if (auto *dir = dyn_cast<TypeRefDirective>(element)) {
-    ArgumentLengthKind lengthKind;
-    StringRef name = getTypeListName(dir->getOperand(), lengthKind);
-    // Refer to the previously encountered TypeDirective for name.
-    // Take a `const ::mlir::SmallVector<::mlir::Type, 1> &` in the declaration
-    // to properly track the types that will be parsed and pushed later on.
-    if (lengthKind != ArgumentLengthKind::Single)
-      body << "  const ::mlir::SmallVector<::mlir::Type, 1> &" << name
-           << "TypesRef(" << name << "Types);\n";
-    else
-      body << llvm::formatv(
-          "  ::llvm::ArrayRef<::mlir::Type> {0}RawTypesRef({0}RawTypes);\n",
-          name);
   } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
     ArgumentLengthKind ignored;
     body << "  ::llvm::ArrayRef<::mlir::Type> "
@@ -897,7 +883,6 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
 
 /// Generate the parser for a parameter to a custom directive.
 static void genCustomParameterParser(Element &param, OpMethodBody &body) {
-  body << ", ";
   if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
     body << attr->getVar()->name << "Attr";
   } else if (isa<AttrDictDirective>(&param)) {
@@ -926,15 +911,9 @@ static void genCustomParameterParser(Element &param, OpMethodBody &body) {
     else
       body << llvm::formatv("{0}Successor", name);
 
-  } else if (auto *dir = dyn_cast<TypeRefDirective>(&param)) {
-    ArgumentLengthKind lengthKind;
-    StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
-    if (lengthKind == ArgumentLengthKind::Variadic)
-      body << llvm::formatv("{0}TypesRef", listName);
-    else if (lengthKind == ArgumentLengthKind::Optional)
-      body << llvm::formatv("{0}TypeRef", listName);
-    else
-      body << formatv("{0}RawTypesRef[0]", listName);
+  } else if (auto *dir = dyn_cast<RefDirective>(&param)) {
+    genCustomParameterParser(*dir->getOperand(), body);
+
   } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
     ArgumentLengthKind lengthKind;
     StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
@@ -967,27 +946,39 @@ static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
             "{0}Operand;\n",
             operand->getVar()->name);
       }
-    } else if (auto *dir = dyn_cast<TypeRefDirective>(&param)) {
-      // Reference to an optional which may or may not have been set.
-      // Retrieve from vector if not empty.
-      ArgumentLengthKind lengthKind;
-      StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
-      if (lengthKind == ArgumentLengthKind::Optional)
-        body << llvm::formatv(
-            "    ::mlir::Type {0}TypeRef = {0}TypesRef.empty() "
-            "? Type() : {0}TypesRef[0];\n",
-            listName);
     } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
       ArgumentLengthKind lengthKind;
       StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
       if (lengthKind == ArgumentLengthKind::Optional)
         body << llvm::formatv("    ::mlir::Type {0}Type;\n", listName);
+    } else if (auto *dir = dyn_cast<RefDirective>(&param)) {
+      Element *input = dir->getOperand();
+      if (auto *operand = dyn_cast<OperandVariable>(input)) {
+        if (!operand->getVar()->isOptional())
+          continue;
+        body << llvm::formatv(
+            "    {0} {1}Operand = {1}Operands.empty() ? {0}() : "
+            "{1}Operands[0];\n",
+            "llvm::Optional<::mlir::OpAsmParser::OperandType>",
+            operand->getVar()->name);
+
+      } else if (auto *type = dyn_cast<TypeDirective>(input)) {
+        ArgumentLengthKind lengthKind;
+        StringRef listName = getTypeListName(type->getOperand(), lengthKind);
+        if (lengthKind == ArgumentLengthKind::Optional) {
+          body << llvm::formatv("    ::mlir::Type {0}Type = {0}Types.empty() ? "
+                                "::mlir::Type() : {0}Types[0];\n",
+                                listName);
+        }
+      }
     }
   }
 
   body << "    if (parse" << dir->getName() << "(parser";
-  for (Element &param : dir->getArguments())
+  for (Element &param : dir->getArguments()) {
+    body << ", ";
     genCustomParameterParser(param, body);
+  }
 
   body << "))\n"
        << "      return ::mlir::failure();\n";
@@ -1008,9 +999,6 @@ static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
       body << llvm::formatv("    if ({0}Operand.hasValue())\n"
                             "      {0}Operands.push_back(*{0}Operand);\n",
                             var->name);
-    } else if (isa<TypeRefDirective>(&param)) {
-      // In the `type_ref` case, do not parse a new Type that needs to be added.
-      // Just do nothing here.
     } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
       ArgumentLengthKind lengthKind;
       StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
@@ -1238,15 +1226,6 @@ void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
   } else if (isa<SuccessorsDirective>(element)) {
     body << llvm::formatv(successorListParserCode, "full");
 
-  } else if (auto *dir = dyn_cast<TypeRefDirective>(element)) {
-    ArgumentLengthKind lengthKind;
-    StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
-    if (lengthKind == ArgumentLengthKind::Variadic)
-      body << llvm::formatv(variadicTypeParserCode, listName);
-    else if (lengthKind == ArgumentLengthKind::Optional)
-      body << llvm::formatv(optionalTypeParserCode, listName);
-    else
-      body << formatv(typeParserCode, listName);
   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
     ArgumentLengthKind lengthKind;
     StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
@@ -1587,54 +1566,51 @@ static void genSpacePrinter(bool value, OpMethodBody &body,
   shouldEmitSpace = false;
 }
 
+/// Generate the printer for a custom directive parameter.
+static void genCustomDirectiveParameterPrinter(Element *element,
+                                               OpMethodBody &body) {
+  if (auto *attr = dyn_cast<AttributeVariable>(element)) {
+    body << attr->getVar()->name << "Attr()";
+
+  } else if (isa<AttrDictDirective>(element)) {
+    body << "getOperation()->getAttrDictionary()";
+
+  } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
+    body << operand->getVar()->name << "()";
+
+  } else if (auto *region = dyn_cast<RegionVariable>(element)) {
+    body << region->getVar()->name << "()";
+
+  } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
+    body << successor->getVar()->name << "()";
+
+  } else if (auto *dir = dyn_cast<RefDirective>(element)) {
+    genCustomDirectiveParameterPrinter(dir->getOperand(), body);
+
+  } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
+    auto *typeOperand = dir->getOperand();
+    auto *operand = dyn_cast<OperandVariable>(typeOperand);
+    auto *var = operand ? operand->getVar()
+                        : cast<ResultVariable>(typeOperand)->getVar();
+    if (var->isVariadic())
+      body << var->name << "().getTypes()";
+    else if (var->isOptional())
+      body << llvm::formatv("({0}() ? {0}().getType() : Type())", var->name);
+    else
+      body << var->name << "().getType()";
+  } else {
+    llvm_unreachable("unknown custom directive parameter");
+  }
+}
+
 /// Generate the printer for a custom directive.
 static void genCustomDirectivePrinter(CustomDirective *customDir,
                                       OpMethodBody &body) {
   body << "  print" << customDir->getName() << "(p, *this";
   for (Element &param : customDir->getArguments()) {
     body << ", ";
-    if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
-      body << attr->getVar()->name << "Attr()";
-
-    } else if (isa<AttrDictDirective>(&param)) {
-      body << "getOperation()->getAttrDictionary()";
-
-    } else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
-      body << operand->getVar()->name << "()";
-
-    } else if (auto *region = dyn_cast<RegionVariable>(&param)) {
-      body << region->getVar()->name << "()";
-
-    } else if (auto *successor = dyn_cast<SuccessorVariable>(&param)) {
-      body << successor->getVar()->name << "()";
-
-    } else if (auto *dir = dyn_cast<TypeRefDirective>(&param)) {
-      auto *typeOperand = dir->getOperand();
-      auto *operand = dyn_cast<OperandVariable>(typeOperand);
-      auto *var = operand ? operand->getVar()
-                          : cast<ResultVariable>(typeOperand)->getVar();
-      if (var->isVariadic())
-        body << var->name << "().getTypes()";
-      else if (var->isOptional())
-        body << llvm::formatv("({0}() ? {0}().getType() : Type())", var->name);
-      else
-        body << var->name << "().getType()";
-    } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
-      auto *typeOperand = dir->getOperand();
-      auto *operand = dyn_cast<OperandVariable>(typeOperand);
-      auto *var = operand ? operand->getVar()
-                          : cast<ResultVariable>(typeOperand)->getVar();
-      if (var->isVariadic())
-        body << var->name << "().getTypes()";
-      else if (var->isOptional())
-        body << llvm::formatv("({0}() ? {0}().getType() : Type())", var->name);
-      else
-        body << var->name << "().getType()";
-    } else {
-      llvm_unreachable("unknown custom directive parameter");
-    }
+    genCustomDirectiveParameterPrinter(&param, body);
   }
-
   body << ");\n";
 }
 
@@ -1886,9 +1862,6 @@ void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
     body << "  p << ";
     genTypeOperandPrinter(dir->getOperand(), body) << ";\n";
-  } else if (auto *dir = dyn_cast<TypeRefDirective>(element)) {
-    body << "  p << ";
-    genTypeOperandPrinter(dir->getOperand(), body) << ";\n";
   } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
     body << "  p.printFunctionalType(";
     genTypeOperandPrinter(dir->getInputs(), body) << ", ";
@@ -1951,11 +1924,11 @@ class Token {
     kw_custom,
     kw_functional_type,
     kw_operands,
+    kw_ref,
     kw_regions,
     kw_results,
     kw_successors,
     kw_type,
-    kw_type_ref,
     keyword_end,
 
     // String valued tokens.
@@ -2156,11 +2129,11 @@ Token FormatLexer::lexIdentifier(const char *tokStart) {
           .Case("custom", Token::kw_custom)
           .Case("functional-type", Token::kw_functional_type)
           .Case("operands", Token::kw_operands)
+          .Case("ref", Token::kw_ref)
           .Case("regions", Token::kw_regions)
           .Case("results", Token::kw_results)
           .Case("successors", Token::kw_successors)
           .Case("type", Token::kw_type)
-          .Case("type_ref", Token::kw_type_ref)
           .Default(Token::identifier);
   return Token(kind, str);
 }
@@ -2191,6 +2164,19 @@ class FormatParser {
   LogicalResult parse();
 
 private:
+  /// The current context of the parser when parsing an element.
+  enum ParserContext {
+    /// The element is being parsed in a "top-level" context, i.e. at the top of
+    /// the format or in an optional group.
+    TopLevelContext,
+    /// The element is being parsed as a custom directive child.
+    CustomDirectiveContext,
+    /// The element is being parsed as a type directive child.
+    TypeDirectiveContext,
+    /// The element is being parsed as a reference directive child.
+    RefDirectiveContext
+  };
+
   /// This struct represents a type resolution instance. It includes a specific
   /// type as well as an optional transformer to apply to that type in order to
   /// properly resolve the type of a variable.
@@ -2249,14 +2235,15 @@ class FormatParser {
 
   /// Parse a specific element.
   LogicalResult parseElement(std::unique_ptr<Element> &element,
-                             bool isTopLevel);
+                             ParserContext context);
   LogicalResult parseVariable(std::unique_ptr<Element> &element,
-                              bool isTopLevel);
+                              ParserContext context);
   LogicalResult parseDirective(std::unique_ptr<Element> &element,
-                               bool isTopLevel);
-  LogicalResult parseLiteral(std::unique_ptr<Element> &element);
+                               ParserContext context);
+  LogicalResult parseLiteral(std::unique_ptr<Element> &element,
+                             ParserContext context);
   LogicalResult parseOptional(std::unique_ptr<Element> &element,
-                              bool isTopLevel);
+                              ParserContext context);
   LogicalResult parseOptionalChildElement(
       std::vector<std::unique_ptr<Element>> &childElements,
       Optional<unsigned> &anchorIdx);
@@ -2265,26 +2252,29 @@ class FormatParser {
 
   /// Parse the various different directives.
   LogicalResult parseAttrDictDirective(std::unique_ptr<Element> &element,
-                                       llvm::SMLoc loc, bool isTopLevel,
+                                       llvm::SMLoc loc, ParserContext context,
                                        bool withKeyword);
   LogicalResult parseCustomDirective(std::unique_ptr<Element> &element,
-                                     llvm::SMLoc loc, bool isTopLevel);
+                                     llvm::SMLoc loc, ParserContext context);
   LogicalResult parseCustomDirectiveParameter(
       std::vector<std::unique_ptr<Element>> &parameters);
   LogicalResult parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
-                                             Token tok, bool isTopLevel);
+                                             Token tok, ParserContext context);
   LogicalResult parseOperandsDirective(std::unique_ptr<Element> &element,
-                                       llvm::SMLoc loc, bool isTopLevel);
+                                       llvm::SMLoc loc, ParserContext context);
+  LogicalResult parseReferenceDirective(std::unique_ptr<Element> &element,
+                                        llvm::SMLoc loc, ParserContext context);
   LogicalResult parseRegionsDirective(std::unique_ptr<Element> &element,
-                                      llvm::SMLoc loc, bool isTopLevel);
+                                      llvm::SMLoc loc, ParserContext context);
   LogicalResult parseResultsDirective(std::unique_ptr<Element> &element,
-                                      llvm::SMLoc loc, bool isTopLevel);
+                                      llvm::SMLoc loc, ParserContext context);
   LogicalResult parseSuccessorsDirective(std::unique_ptr<Element> &element,
-                                         llvm::SMLoc loc, bool isTopLevel);
+                                         llvm::SMLoc loc,
+                                         ParserContext context);
   LogicalResult parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
-                                   bool isTopLevel, bool isTypeRef = false);
+                                   ParserContext context);
   LogicalResult parseTypeDirectiveOperand(std::unique_ptr<Element> &element,
-                                          bool isTypeRef = false);
+                                          bool isRefChild = false);
 
   //===--------------------------------------------------------------------===//
   // Lexer Utilities
@@ -2340,7 +2330,7 @@ LogicalResult FormatParser::parse() {
   // Parse each of the format elements into the main format.
   while (curToken.getKind() != Token::eof) {
     std::unique_ptr<Element> element;
-    if (failed(parseElement(element, /*isTopLevel=*/true)))
+    if (failed(parseElement(element, TopLevelContext)))
       return ::mlir::failure();
     fmt.elements.push_back(std::move(element));
   }
@@ -2634,25 +2624,25 @@ ConstArgument FormatParser::findSeenArg(StringRef name) {
 }
 
 LogicalResult FormatParser::parseElement(std::unique_ptr<Element> &element,
-                                         bool isTopLevel) {
+                                         ParserContext context) {
   // Directives.
   if (curToken.isKeyword())
-    return parseDirective(element, isTopLevel);
+    return parseDirective(element, context);
   // Literals.
   if (curToken.getKind() == Token::literal)
-    return parseLiteral(element);
+    return parseLiteral(element, context);
   // Optionals.
   if (curToken.getKind() == Token::l_paren)
-    return parseOptional(element, isTopLevel);
+    return parseOptional(element, context);
   // Variables.
   if (curToken.getKind() == Token::variable)
-    return parseVariable(element, isTopLevel);
+    return parseVariable(element, context);
   return emitError(curToken.getLoc(),
                    "expected directive, literal, variable, or optional group");
 }
 
 LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
-                                          bool isTopLevel) {
+                                          ParserContext context) {
   Token varTok = curToken;
   consumeToken();
 
@@ -2663,42 +2653,67 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
   // op.
   /// Attributes
   if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) {
-    if (isTopLevel && !seenAttrs.insert(attr))
+    if (context == TypeDirectiveContext)
+      return emitError(
+          loc, "attributes cannot be used as children to a `type` directive");
+    if (context == RefDirectiveContext) {
+      if (!seenAttrs.count(attr))
+        return emitError(loc, "attribute '" + name +
+                                  "' must be bound before it is referenced");
+    } else if (!seenAttrs.insert(attr)) {
       return emitError(loc, "attribute '" + name + "' is already bound");
+    }
+
     element = std::make_unique<AttributeVariable>(attr);
     return ::mlir::success();
   }
   /// Operands
   if (const NamedTypeConstraint *operand = findArg(op.getOperands(), name)) {
-    if (isTopLevel) {
+    if (context == TopLevelContext || context == CustomDirectiveContext) {
       if (fmt.allOperands || !seenOperands.insert(operand).second)
         return emitError(loc, "operand '" + name + "' is already bound");
+    } else if (context == RefDirectiveContext && !seenOperands.count(operand)) {
+      return emitError(loc, "operand '" + name +
+                                "' must be bound before it is referenced");
     }
     element = std::make_unique<OperandVariable>(operand);
     return ::mlir::success();
   }
   /// Regions
   if (const NamedRegion *region = findArg(op.getRegions(), name)) {
-    if (!isTopLevel)
+    if (context == TopLevelContext || context == CustomDirectiveContext) {
+      if (hasAllRegions || !seenRegions.insert(region).second)
+        return emitError(loc, "region '" + name + "' is already bound");
+    } else if (context == RefDirectiveContext && !seenRegions.count(region)) {
+      return emitError(loc, "region '" + name +
+                                "' must be bound before it is referenced");
+    } else {
       return emitError(loc, "regions can only be used at the top level");
-    if (hasAllRegions || !seenRegions.insert(region).second)
-      return emitError(loc, "region '" + name + "' is already bound");
+    }
     element = std::make_unique<RegionVariable>(region);
     return ::mlir::success();
   }
   /// Results.
   if (const auto *result = findArg(op.getResults(), name)) {
-    if (isTopLevel)
-      return emitError(loc, "results can not be used at the top level");
+    if (context != TypeDirectiveContext)
+      return emitError(loc, "result variables can can only be used as a child "
+                            "to a 'type' directive");
     element = std::make_unique<ResultVariable>(result);
     return ::mlir::success();
   }
   /// Successors.
   if (const auto *successor = findArg(op.getSuccessors(), name)) {
-    if (!isTopLevel)
+    if (context == TopLevelContext || context == CustomDirectiveContext) {
+      if (hasAllSuccessors || !seenSuccessors.insert(successor).second)
+        return emitError(loc, "successor '" + name + "' is already bound");
+    } else if (context == RefDirectiveContext &&
+               !seenSuccessors.count(successor)) {
+      return emitError(loc, "successor '" + name +
+                                "' must be bound before it is referenced");
+    } else {
       return emitError(loc, "successors can only be used at the top level");
-    if (hasAllSuccessors || !seenSuccessors.insert(successor).second)
-      return emitError(loc, "successor '" + name + "' is already bound");
+    }
+
     element = std::make_unique<SuccessorVariable>(successor);
     return ::mlir::success();
   }
@@ -2707,41 +2722,47 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
 }
 
 LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
-                                           bool isTopLevel) {
+                                           ParserContext context) {
   Token dirTok = curToken;
   consumeToken();
 
   switch (dirTok.getKind()) {
   case Token::kw_attr_dict:
-    return parseAttrDictDirective(element, dirTok.getLoc(), isTopLevel,
+    return parseAttrDictDirective(element, dirTok.getLoc(), context,
                                   /*withKeyword=*/false);
   case Token::kw_attr_dict_w_keyword:
-    return parseAttrDictDirective(element, dirTok.getLoc(), isTopLevel,
+    return parseAttrDictDirective(element, dirTok.getLoc(), context,
                                   /*withKeyword=*/true);
   case Token::kw_custom:
-    return parseCustomDirective(element, dirTok.getLoc(), isTopLevel);
+    return parseCustomDirective(element, dirTok.getLoc(), context);
   case Token::kw_functional_type:
-    return parseFunctionalTypeDirective(element, dirTok, isTopLevel);
+    return parseFunctionalTypeDirective(element, dirTok, context);
   case Token::kw_operands:
-    return parseOperandsDirective(element, dirTok.getLoc(), isTopLevel);
+    return parseOperandsDirective(element, dirTok.getLoc(), context);
   case Token::kw_regions:
-    return parseRegionsDirective(element, dirTok.getLoc(), isTopLevel);
+    return parseRegionsDirective(element, dirTok.getLoc(), context);
   case Token::kw_results:
-    return parseResultsDirective(element, dirTok.getLoc(), isTopLevel);
+    return parseResultsDirective(element, dirTok.getLoc(), context);
   case Token::kw_successors:
-    return parseSuccessorsDirective(element, dirTok.getLoc(), isTopLevel);
-  case Token::kw_type_ref:
-    return parseTypeDirective(element, dirTok, isTopLevel, /*isTypeRef=*/true);
+    return parseSuccessorsDirective(element, dirTok.getLoc(), context);
+  case Token::kw_ref:
+    return parseReferenceDirective(element, dirTok.getLoc(), context);
   case Token::kw_type:
-    return parseTypeDirective(element, dirTok, isTopLevel);
+    return parseTypeDirective(element, dirTok, context);
 
   default:
     llvm_unreachable("unknown directive token");
   }
 }
 
-LogicalResult FormatParser::parseLiteral(std::unique_ptr<Element> &element) {
+LogicalResult FormatParser::parseLiteral(std::unique_ptr<Element> &element,
+                                         ParserContext context) {
   Token literalTok = curToken;
+  if (context != TopLevelContext) {
+    return emitError(
+        literalTok.getLoc(),
+        "literals may only be used in a top-level section of the format");
+  }
   consumeToken();
 
   StringRef value = literalTok.getSpelling().drop_front().drop_back();
@@ -2766,9 +2787,9 @@ LogicalResult FormatParser::parseLiteral(std::unique_ptr<Element> &element) {
 }
 
 LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
-                                          bool isTopLevel) {
+                                          ParserContext context) {
   llvm::SMLoc curLoc = curToken.getLoc();
-  if (!isTopLevel)
+  if (context != TopLevelContext)
     return emitError(curLoc, "optional groups can only be used as top-level "
                              "elements");
   consumeToken();
@@ -2812,7 +2833,7 @@ LogicalResult FormatParser::parseOptionalChildElement(
     Optional<unsigned> &anchorIdx) {
   llvm::SMLoc childLoc = curToken.getLoc();
   childElements.push_back({});
-  if (failed(parseElement(childElements.back(), /*isTopLevel=*/true)))
+  if (failed(parseElement(childElements.back(), TopLevelContext)))
     return ::mlir::failure();
 
   // Check to see if this element is the anchor of the optional group.
@@ -2843,7 +2864,7 @@ LogicalResult FormatParser::verifyOptionalChildElement(Element *element,
       })
       // Only optional-like(i.e. variadic) operands can be within an optional
       // group.
-      .Case<OperandVariable>([&](OperandVariable *ele) {
+      .Case([&](OperandVariable *ele) {
         if (!ele->getVar()->isVariableLength())
           return emitError(childLoc, "only variable length operands can be "
                                      "used within an optional group");
@@ -2851,22 +2872,22 @@ LogicalResult FormatParser::verifyOptionalChildElement(Element *element,
       })
       // Only optional-like(i.e. variadic) results can be within an optional
       // group.
-      .Case<ResultVariable>([&](ResultVariable *ele) {
+      .Case([&](ResultVariable *ele) {
         if (!ele->getVar()->isVariableLength())
           return emitError(childLoc, "only variable length results can be "
                                      "used within an optional group");
         return ::mlir::success();
       })
-      .Case<RegionVariable>([&](RegionVariable *) {
+      .Case([&](RegionVariable *) {
         // TODO: When ODS has proper support for marking "optional" regions, add
         // a check here.
         return ::mlir::success();
       })
-      .Case<TypeDirective>([&](TypeDirective *ele) {
+      .Case([&](TypeDirective *ele) {
         return verifyOptionalChildElement(ele->getOperand(), childLoc,
                                           /*isAnchor=*/false);
       })
-      .Case<FunctionalTypeDirective>([&](FunctionalTypeDirective *ele) {
+      .Case([&](FunctionalTypeDirective *ele) {
         if (failed(verifyOptionalChildElement(ele->getInputs(), childLoc,
                                               /*isAnchor=*/false)))
           return failure();
@@ -2876,13 +2897,12 @@ LogicalResult FormatParser::verifyOptionalChildElement(Element *element,
       // Literals, whitespace, and custom directives may be used, but they can't
       // anchor the group.
       .Case<LiteralElement, WhitespaceElement, CustomDirective,
-            FunctionalTypeDirective, OptionalElement, TypeRefDirective>(
-          [&](Element *) {
-            if (isAnchor)
-              return emitError(childLoc, "only variables and types can be used "
-                                         "to anchor an optional group");
-            return ::mlir::success();
-          })
+            FunctionalTypeDirective, OptionalElement>([&](Element *) {
+        if (isAnchor)
+          return emitError(childLoc, "only variables and types can be used "
+                                     "to anchor an optional group");
+        return ::mlir::success();
+      })
       .Default([&](Element *) {
         return emitError(childLoc, "only literals, types, and variables can be "
                                    "used within an optional group");
@@ -2891,23 +2911,34 @@ LogicalResult FormatParser::verifyOptionalChildElement(Element *element,
 
 LogicalResult
 FormatParser::parseAttrDictDirective(std::unique_ptr<Element> &element,
-                                     llvm::SMLoc loc, bool isTopLevel,
+                                     llvm::SMLoc loc, ParserContext context,
                                      bool withKeyword) {
-  if (!isTopLevel)
+  if (context == TypeDirectiveContext)
     return emitError(loc, "'attr-dict' directive can only be used as a "
                           "top-level directive");
-  if (hasAttrDict)
-    return emitError(loc, "'attr-dict' directive has already been seen");
 
-  hasAttrDict = true;
+  if (context == RefDirectiveContext) {
+    if (!hasAttrDict)
+      return emitError(loc, "'ref' of 'attr-dict' is not bound by a prior "
+                            "'attr-dict' directive");
+
+    // Otherwise, this is a top-level or custom directive context.
+  } else {
+    if (hasAttrDict)
+      return emitError(loc, "'attr-dict' directive has already been seen");
+    hasAttrDict = true;
+  }
+  // A binding use and a 'ref' of it both build the same element kind.
   element = std::make_unique<AttrDictDirective>(withKeyword);
   return ::mlir::success();
 }
 
 LogicalResult
 FormatParser::parseCustomDirective(std::unique_ptr<Element> &element,
-                                   llvm::SMLoc loc, bool isTopLevel) {
+                                   llvm::SMLoc loc, ParserContext context) {
   llvm::SMLoc curLoc = curToken.getLoc();
+  if (context != TopLevelContext)
+    return emitError(loc, "'custom' is only valid as a top-level directive");
 
   // Parse the custom directive name.
   if (failed(
@@ -2940,13 +2971,6 @@ FormatParser::parseCustomDirective(std::unique_ptr<Element> &element,
   // After parsing all of the elements, ensure that all type directives refer
   // only to variables.
   for (auto &ele : elements) {
-    if (auto *typeEle = dyn_cast<TypeRefDirective>(ele.get())) {
-      if (!isa<OperandVariable, ResultVariable>(typeEle->getOperand())) {
-        return emitError(curLoc,
-                         "type_ref directives within a custom directive "
-                         "may only refer to variables");
-      }
-    }
     if (auto *typeEle = dyn_cast<TypeDirective>(ele.get())) {
       if (!isa<OperandVariable, ResultVariable>(typeEle->getOperand())) {
         return emitError(curLoc, "type directives within a custom directive "
@@ -2964,13 +2988,13 @@ LogicalResult FormatParser::parseCustomDirectiveParameter(
     std::vector<std::unique_ptr<Element>> &parameters) {
   llvm::SMLoc childLoc = curToken.getLoc();
   parameters.push_back({});
-  if (failed(parseElement(parameters.back(), /*isTopLevel=*/true)))
+  if (failed(parseElement(parameters.back(), CustomDirectiveContext)))
     return ::mlir::failure();
 
   // Verify that the element can be placed within a custom directive.
-  if (!isa<TypeRefDirective, TypeDirective, AttrDictDirective,
-           AttributeVariable, OperandVariable, RegionVariable,
-           SuccessorVariable>(parameters.back().get())) {
+  if (!isa<RefDirective, TypeDirective, AttrDictDirective, AttributeVariable,
+           OperandVariable, RegionVariable, SuccessorVariable>(
+          parameters.back().get())) {
     return emitError(childLoc, "only variables and types may be used as "
                                "parameters to a custom directive");
   }
@@ -2979,9 +3003,9 @@ LogicalResult FormatParser::parseCustomDirectiveParameter(
 
 LogicalResult
 FormatParser::parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
-                                           Token tok, bool isTopLevel) {
+                                           Token tok, ParserContext context) {
   llvm::SMLoc loc = tok.getLoc();
-  if (!isTopLevel)
+  if (context != TopLevelContext)
     return emitError(
         loc, "'functional-type' is only valid as a top-level directive");
 
@@ -3000,8 +3024,13 @@ FormatParser::parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
 
 LogicalResult
 FormatParser::parseOperandsDirective(std::unique_ptr<Element> &element,
-                                     llvm::SMLoc loc, bool isTopLevel) {
-  if (isTopLevel) {
+                                     llvm::SMLoc loc, ParserContext context) {
+  if (context == RefDirectiveContext) {
+    if (!fmt.allOperands)
+      return emitError(loc, "'ref' of 'operands' is not bound by a prior "
+                            "'operands' directive");
+
+  } else if (context == TopLevelContext || context == CustomDirectiveContext) {
     if (fmt.allOperands || !seenOperands.empty())
       return emitError(loc, "'operands' directive creates overlap in format");
     fmt.allOperands = true;
@@ -3010,65 +3039,96 @@ FormatParser::parseOperandsDirective(std::unique_ptr<Element> &element,
   return ::mlir::success();
 }
 
+LogicalResult
+FormatParser::parseReferenceDirective(std::unique_ptr<Element> &element,
+                                      llvm::SMLoc loc, ParserContext context) {
+  if (context != CustomDirectiveContext)
+    return emitError(loc, "'ref' is only valid within a `custom` directive");
+  // Parse '(' element ')'; the element must already be bound by the format.
+  std::unique_ptr<Element> operand;
+  if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
+      failed(parseElement(operand, RefDirectiveContext)) ||
+      failed(parseToken(Token::r_paren, "expected ')' after argument list")))
+    return ::mlir::failure();
+
+  element = std::make_unique<RefDirective>(std::move(operand));
+  return ::mlir::success();
+}
+
 LogicalResult
 FormatParser::parseRegionsDirective(std::unique_ptr<Element> &element,
-                                    llvm::SMLoc loc, bool isTopLevel) {
-  if (!isTopLevel)
+                                    llvm::SMLoc loc, ParserContext context) {
+  if (context == TypeDirectiveContext)
     return emitError(loc, "'regions' is only valid as a top-level directive");
-  if (hasAllRegions || !seenRegions.empty())
-    return emitError(loc, "'regions' directive creates overlap in format");
-  hasAllRegions = true;
+  if (context == RefDirectiveContext) {
+    if (!hasAllRegions)
+      return emitError(loc, "'ref' of 'regions' is not bound by a prior "
+                            "'regions' directive");
+
+    // Otherwise, this is a top-level or custom directive context.
+  } else {
+    if (hasAllRegions || !seenRegions.empty())
+      return emitError(loc, "'regions' directive creates overlap in format");
+    hasAllRegions = true;
+  }
   element = std::make_unique<RegionsDirective>();
   return ::mlir::success();
 }
 
 LogicalResult
 FormatParser::parseResultsDirective(std::unique_ptr<Element> &element,
-                                    llvm::SMLoc loc, bool isTopLevel) {
-  if (isTopLevel)
-    return emitError(loc, "'results' directive can not be used as a "
-                          "top-level directive");
+                                    llvm::SMLoc loc, ParserContext context) {
+  if (context != TypeDirectiveContext)
+    return emitError(loc, "'results' directive can only be used as a child "
+                          "to a 'type' directive");
   element = std::make_unique<ResultsDirective>();
   return ::mlir::success();
 }
 
 LogicalResult
 FormatParser::parseSuccessorsDirective(std::unique_ptr<Element> &element,
-                                       llvm::SMLoc loc, bool isTopLevel) {
-  if (!isTopLevel)
+                                       llvm::SMLoc loc, ParserContext context) {
+  if (context == TypeDirectiveContext)
     return emitError(loc,
                      "'successors' is only valid as a top-level directive");
-  if (hasAllSuccessors || !seenSuccessors.empty())
-    return emitError(loc, "'successors' directive creates overlap in format");
-  hasAllSuccessors = true;
+  if (context == RefDirectiveContext) {
+    if (!hasAllSuccessors)
+      return emitError(loc, "'ref' of 'successors' is not bound by a prior "
+                            "'successors' directive");
+
+    // Otherwise, this is a top-level or custom directive context.
+  } else {
+    if (hasAllSuccessors || !seenSuccessors.empty())
+      return emitError(loc, "'successors' directive creates overlap in format");
+    hasAllSuccessors = true;
+  }
   element = std::make_unique<SuccessorsDirective>();
   return ::mlir::success();
 }
 
 LogicalResult
 FormatParser::parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
-                                 bool isTopLevel, bool isTypeRef) {
+                                 ParserContext context) {
   llvm::SMLoc loc = tok.getLoc();
-  if (!isTopLevel)
-    return emitError(loc, "'type' is only valid as a top-level directive");
+  if (context == TypeDirectiveContext)
+    return emitError(loc, "'type' cannot be used as a child of another `type`");
 
+  bool isRefChild = context == RefDirectiveContext;
   std::unique_ptr<Element> operand;
   if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
-      failed(parseTypeDirectiveOperand(operand, isTypeRef)) ||
+      failed(parseTypeDirectiveOperand(operand, isRefChild)) ||
       failed(parseToken(Token::r_paren, "expected ')' after argument list")))
     return ::mlir::failure();
-  if (isTypeRef)
-    element = std::make_unique<TypeRefDirective>(std::move(operand));
-  else
-    element = std::make_unique<TypeDirective>(std::move(operand));
+  // A 'ref' child builds the same TypeDirective; only bind checks differ.
+  element = std::make_unique<TypeDirective>(std::move(operand));
   return ::mlir::success();
 }
 
 LogicalResult
 FormatParser::parseTypeDirectiveOperand(std::unique_ptr<Element> &element,
-                                        bool isTypeRef) {
+                                        bool isRefChild) {
   llvm::SMLoc loc = curToken.getLoc();
-  if (failed(parseElement(element, /*isTopLevel=*/false)))
+  if (failed(parseElement(element, TypeDirectiveContext)))
     return ::mlir::failure();
   if (isa<LiteralElement>(element.get()))
     return emitError(
@@ -3076,36 +3136,35 @@ FormatParser::parseTypeDirectiveOperand(std::unique_ptr<Element> &element,
 
   if (auto *var = dyn_cast<OperandVariable>(element.get())) {
     unsigned opIdx = var->getVar() - op.operand_begin();
-    if (!isTypeRef && (fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
+    if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
       return emitError(loc, "'type' of '" + var->getVar()->name +
                                 "' is already bound");
-    if (isTypeRef && !(fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
-      return emitError(loc, "'type_ref' of '" + var->getVar()->name +
-                                "' is not bound by a prior 'type' directive");
+    if (isRefChild && !(fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
+      return emitError(loc, "'ref' of 'type($" + var->getVar()->name +
+                                ")' is not bound by a prior 'type' directive");
     seenOperandTypes.set(opIdx);
   } else if (auto *var = dyn_cast<ResultVariable>(element.get())) {
     unsigned resIdx = var->getVar() - op.result_begin();
-    if (!isTypeRef && (fmt.allResultTypes || seenResultTypes.test(resIdx)))
+    if (!isRefChild && (fmt.allResultTypes || seenResultTypes.test(resIdx)))
       return emitError(loc, "'type' of '" + var->getVar()->name +
                                 "' is already bound");
-    if (isTypeRef && !(fmt.allResultTypes || seenResultTypes.test(resIdx)))
-      return emitError(loc, "'type_ref' of '" + var->getVar()->name +
-                                "' is not bound by a prior 'type' directive");
+    if (isRefChild && !(fmt.allResultTypes || seenResultTypes.test(resIdx)))
+      return emitError(loc, "'ref' of 'type($" + var->getVar()->name +
+                                ")' is not bound by a prior 'type' directive");
     seenResultTypes.set(resIdx);
   } else if (isa<OperandsDirective>(&*element)) {
-    if (!isTypeRef && (fmt.allOperandTypes || seenOperandTypes.any()))
+    if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.any()))
       return emitError(loc, "'operands' 'type' is already bound");
-    if (isTypeRef && !(fmt.allOperandTypes || seenOperandTypes.all()))
-      return emitError(
-          loc,
-          "'operands' 'type_ref' is not bound by a prior 'type' directive");
+    if (isRefChild && !fmt.allOperandTypes)
+      return emitError(loc, "'ref' of 'type(operands)' is not bound by a prior "
+                            "'type' directive");
     fmt.allOperandTypes = true;
   } else if (isa<ResultsDirective>(&*element)) {
-    if (!isTypeRef && (fmt.allResultTypes || seenResultTypes.any()))
+    if (!isRefChild && (fmt.allResultTypes || seenResultTypes.any()))
       return emitError(loc, "'results' 'type' is already bound");
-    if (isTypeRef && !(fmt.allResultTypes || seenResultTypes.all()))
-      return emitError(
-          loc, "'results' 'type_ref' is not bound by a prior 'type' directive");
+    if (isRefChild && !fmt.allResultTypes)
+      return emitError(loc, "'ref' of 'type(results)' is not bound by a prior "
+                            "'type' directive");
     fmt.allResultTypes = true;
   } else {
     return emitError(loc, "invalid argument to 'type' directive");


        


More information about the Mlir-commits mailing list