[llvm-branch-commits] [mlir] 29d420e - [mlir][OpFormatGen] Add support for anchoring optional groups with types

River Riddle via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Jan 22 12:15:54 PST 2021


Author: River Riddle
Date: 2021-01-22T12:07:27-08:00
New Revision: 29d420e0bf0273cdef35b2d2453f0f574d1e8313

URL: https://github.com/llvm/llvm-project/commit/29d420e0bf0273cdef35b2d2453f0f574d1e8313
DIFF: https://github.com/llvm/llvm-project/commit/29d420e0bf0273cdef35b2d2453f0f574d1e8313.diff

LOG: [mlir][OpFormatGen] Add support for anchoring optional groups with types

This revision adds support for using either operand or result types to anchor an optional group. It also removes the arbitrary restriction that type directives must refer to variables in the same group, which is overly limiting for a declarative format syntax.

Fixes PR#48784

Differential Revision: https://reviews.llvm.org/D95109

Added: 
    

Modified: 
    mlir/docs/OpDefinitions.md
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/op-format-spec.td
    mlir/test/mlir-tblgen/op-format.mlir
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index dd522904dd73..8a7f6a238732 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -777,8 +777,8 @@ information. An optional group is defined by wrapping a set of elements within
 *   The first element of the group must either be a attribute, literal, operand,
     or region.
     -   This is because the first element must be optionally parsable.
-*   Exactly one argument variable within the group must be marked as the anchor
-    of the group.
+*   Exactly one argument variable or type directive within the group must be
+    marked as the anchor of the group.
     -   The anchor is the element whose presence controls whether the group
         should be printed/parsed.
     -   An element is marked as the anchor by adding a trailing `^`.
@@ -789,11 +789,9 @@ information. An optional group is defined by wrapping a set of elements within
     valid elements within the group.
     -   Any attribute variable may be used, but only optional attributes can be
         marked as the anchor.
-    -   Only variadic or optional operand arguments can be used.
+    -   Only variadic or optional results and operand arguments can be used.
     -   All region variables can be used. When a non-variable length region is
         used, if the group is not present the region is empty.
-    -   The operands to a type directive must be defined within the optional
-        group.
 
 An example of an operation with an optional group is `std.return`, which has a
 variadic number of operands.

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index d1cbe77ac21b..89d2ee87356b 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1571,6 +1571,25 @@ def FormatOptionalOperandResultBOp : FormatOptionalOperandResultOpBase<"b", [{
   (`[` $variadic^ `]`)? attr-dict
 }]>;
 
+// Test optional result type formatting.
+class FormatOptionalResultOpBase<string suffix, string fmt>
+    : TEST_Op<"format_optional_result_" # suffix # "_op",
+              [AttrSizedResultSegments]> {
+  let results = (outs Optional<I64>:$optional, Variadic<I64>:$variadic);
+  let assemblyFormat = fmt;
+}
+def FormatOptionalResultAOp : FormatOptionalResultOpBase<"a", [{
+  (`:` type($optional)^ `->` type($variadic))? attr-dict
+}]>;
+
+def FormatOptionalResultBOp : FormatOptionalResultOpBase<"b", [{
+  (`:` type($optional) `->` type($variadic)^)? attr-dict
+}]>;
+
+def FormatOptionalResultCOp : FormatOptionalResultOpBase<"c", [{
+  (`:` functional-type($optional, $variadic)^)? attr-dict
+}]>;
+
 def FormatTwoVariadicOperandsNoBuildableTypeOp
     : TEST_Op<"format_two_variadic_operands_no_buildable_type_op",
               [AttrSizedOperandSegments]> {

diff  --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
index 424dbb83c276..652bbd08679d 100644
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -333,7 +333,7 @@ def OptionalInvalidC : TestFormat_Op<"optional_invalid_c", [{
 def OptionalInvalidD : TestFormat_Op<"optional_invalid_d", [{
   (type($operand) $operand^)? attr-dict
 }]>, Arguments<(ins Optional<I64>:$operand)>;
-// CHECK: error: type directive can only refer to variables within the optional group
+// CHECK: error: only literals, types, and variables can be used within an optional group
 def OptionalInvalidE : TestFormat_Op<"optional_invalid_e", [{
   (`,` $attr^ type(operands))? attr-dict
 }]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
@@ -349,9 +349,9 @@ def OptionalInvalidG : TestFormat_Op<"optional_invalid_g", [{
 def OptionalInvalidH : TestFormat_Op<"optional_invalid_h", [{
   ($arg^) attr-dict
 }]>, Arguments<(ins I64:$arg)>;
-// CHECK: error: only variables can be used to anchor an optional group
+// CHECK: error: only literals, types, and variables can be used within an optional group
 def OptionalInvalidI : TestFormat_Op<"optional_invalid_i", [{
-  ($arg type($arg)^) attr-dict
+  (functional-type($arg, results)^)? attr-dict
 }]>, Arguments<(ins Variadic<I64>:$arg)>;
 // CHECK: error: only literals, types, and variables can be used within an optional group
 def OptionalInvalidJ : TestFormat_Op<"optional_invalid_j", [{
@@ -361,11 +361,11 @@ def OptionalInvalidJ : TestFormat_Op<"optional_invalid_j", [{
 def OptionalInvalidK : TestFormat_Op<"optional_invalid_k", [{
   ($arg^)
 }]>, Arguments<(ins Variadic<I64>:$arg)>;
-// CHECK: error: only variables can be used to anchor an optional group
+// CHECK: error: only variables and types can be used to anchor an optional group
 def OptionalInvalidL : TestFormat_Op<"optional_invalid_l", [{
   (custom<MyDirective>($arg)^)?
 }]>, Arguments<(ins I64:$arg)>;
-// CHECK: error: only variables can be used to anchor an optional group
+// CHECK: error: only variables and types can be used to anchor an optional group
 def OptionalInvalidM : TestFormat_Op<"optional_invalid_m", [{
   (` `^)?
 }]>, Arguments<(ins)>;

diff  --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index 4ec07998ac7b..b751b3f3d715 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -220,6 +220,25 @@ test.format_optional_operand_result_b_op( : ) : i64
 // CHECK: test.format_optional_operand_result_b_op : i64
 test.format_optional_operand_result_b_op : i64
 
+//===----------------------------------------------------------------------===//
+// Format optional results
+//===----------------------------------------------------------------------===//
+
+// CHECK: test.format_optional_result_a_op
+test.format_optional_result_a_op
+
+// CHECK: test.format_optional_result_a_op : i64 -> i64, i64
+test.format_optional_result_a_op : i64 -> i64, i64
+
+// CHECK: test.format_optional_result_b_op
+test.format_optional_result_b_op
+
+// CHECK: test.format_optional_result_b_op : i64 -> i64, i64
+test.format_optional_result_b_op : i64 -> i64, i64
+
+// CHECK: test.format_optional_result_c_op : (i64) -> (i64, i64)
+test.format_optional_result_c_op : (i64) -> (i64, i64)
+
 //===----------------------------------------------------------------------===//
 // Format custom directives
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 6010bba255ec..81ac241513b0 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -1749,6 +1749,33 @@ static void genEnumAttrPrinter(const NamedAttribute *var, OpMethodBody &body) {
           "  }\n";
 }
 
+/// Generate the check for the anchor of an optional group.
+static void genOptionalGroupPrinterAnchor(Element *anchor, OpMethodBody &body) {
+  TypeSwitch<Element *>(anchor)
+      .Case<OperandVariable, ResultVariable>([&](auto *element) {
+        const NamedTypeConstraint *var = element->getVar();
+        if (var->isOptional())
+          body << "  if (" << var->name << "()) {\n";
+        else if (var->isVariadic())
+          body << "  if (!" << var->name << "().empty()) {\n";
+      })
+      .Case<RegionVariable>([&](RegionVariable *element) {
+        const NamedRegion *var = element->getVar();
+        // TODO: Add a check for optional regions here when ODS supports it.
+        body << "  if (!" << var->name << "().empty()) {\n";
+      })
+      .Case<TypeDirective>([&](TypeDirective *element) {
+        genOptionalGroupPrinterAnchor(element->getOperand(), body);
+      })
+      .Case<FunctionalTypeDirective>([&](FunctionalTypeDirective *element) {
+        genOptionalGroupPrinterAnchor(element->getInputs(), body);
+      })
+      .Case<AttributeVariable>([&](AttributeVariable *attr) {
+        body << "  if ((*this)->getAttr(\"" << attr->getVar()->name
+             << "\")) {\n";
+      });
+}
+
 void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
                                         Operator &op, bool &shouldEmitSpace,
                                         bool &lastWasPunctuation) {
@@ -1769,21 +1796,7 @@ void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
   if (OptionalElement *optional = dyn_cast<OptionalElement>(element)) {
     // Emit the check for the presence of the anchor element.
     Element *anchor = optional->getAnchor();
-    if (auto *operand = dyn_cast<OperandVariable>(anchor)) {
-      const NamedTypeConstraint *var = operand->getVar();
-      if (var->isOptional())
-        body << "  if (" << var->name << "()) {\n";
-      else if (var->isVariadic())
-        body << "  if (!" << var->name << "().empty()) {\n";
-    } else if (auto *region = dyn_cast<RegionVariable>(anchor)) {
-      const NamedRegion *var = region->getVar();
-      // TODO: Add a check for optional here when ODS supports it.
-      body << "  if (!" << var->name << "().empty()) {\n";
-
-    } else {
-      body << "  if ((*this)->getAttr(\""
-           << cast<AttributeVariable>(anchor)->getVar()->name << "\")) {\n";
-    }
+    genOptionalGroupPrinterAnchor(anchor, body);
 
     // If the anchor is a unit attribute, we don't need to print it. When
     // parsing, we will add this attribute if this group is present.
@@ -2244,8 +2257,9 @@ class FormatParser {
                               bool isTopLevel);
   LogicalResult parseOptionalChildElement(
       std::vector<std::unique_ptr<Element>> &childElements,
-      SmallPtrSetImpl<const NamedTypeConstraint *> &seenVariables,
       Optional<unsigned> &anchorIdx);
+  LogicalResult verifyOptionalChildElement(Element *element,
+                                           llvm::SMLoc childLoc, bool isAnchor);
 
   /// Parse the various different directives.
   LogicalResult parseAttrDictDirective(std::unique_ptr<Element> &element,
@@ -2315,7 +2329,6 @@ class FormatParser {
   llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
   llvm::DenseSet<const NamedRegion *> seenRegions;
   llvm::DenseSet<const NamedSuccessor *> seenSuccessors;
-  llvm::DenseSet<const NamedTypeConstraint *> optionalVariables;
 };
 } // end anonymous namespace
 
@@ -2760,10 +2773,9 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
 
   // Parse the child elements for this optional group.
   std::vector<std::unique_ptr<Element>> elements;
-  SmallPtrSet<const NamedTypeConstraint *, 8> seenVariables;
   Optional<unsigned> anchorIdx;
   do {
-    if (failed(parseOptionalChildElement(elements, seenVariables, anchorIdx)))
+    if (failed(parseOptionalChildElement(elements, anchorIdx)))
       return ::mlir::failure();
   } while (curToken.getKind() != Token::r_paren);
   consumeToken();
@@ -2787,31 +2799,6 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
                      "first parsable element of an operand group must be "
                      "an attribute, literal, operand, or region");
 
-  // After parsing all of the elements, ensure that all type directives refer
-  // only to elements within the group.
-  auto checkTypeOperand = [&](Element *typeEle) {
-    auto *opVar = dyn_cast<OperandVariable>(typeEle);
-    const NamedTypeConstraint *var = opVar ? opVar->getVar() : nullptr;
-    if (!seenVariables.count(var))
-      return emitError(curLoc, "type directive can only refer to variables "
-                               "within the optional group");
-    return ::mlir::success();
-  };
-  for (auto &ele : elements) {
-    if (auto *typeEle = dyn_cast<TypeRefDirective>(ele.get())) {
-      if (failed(checkTypeOperand(typeEle->getOperand())))
-        return failure();
-    } else if (auto *typeEle = dyn_cast<TypeDirective>(ele.get())) {
-      if (failed(checkTypeOperand(typeEle->getOperand())))
-        return ::mlir::failure();
-    } else if (auto *typeEle = dyn_cast<FunctionalTypeDirective>(ele.get())) {
-      if (failed(checkTypeOperand(typeEle->getInputs())) ||
-          failed(checkTypeOperand(typeEle->getResults())))
-        return ::mlir::failure();
-    }
-  }
-
-  optionalVariables.insert(seenVariables.begin(), seenVariables.end());
   auto parseStart = parseBegin - elements.begin();
   element = std::make_unique<OptionalElement>(std::move(elements), *anchorIdx,
                                               parseStart);
@@ -2820,7 +2807,6 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
 
 LogicalResult FormatParser::parseOptionalChildElement(
     std::vector<std::unique_ptr<Element>> &childElements,
-    SmallPtrSetImpl<const NamedTypeConstraint *> &seenVariables,
     Optional<unsigned> &anchorIdx) {
   llvm::SMLoc childLoc = curToken.getLoc();
   childElements.push_back({});
@@ -2837,7 +2823,14 @@ LogicalResult FormatParser::parseOptionalChildElement(
     consumeToken();
   }
 
-  return TypeSwitch<Element *, LogicalResult>(childElements.back().get())
+  return verifyOptionalChildElement(childElements.back().get(), childLoc,
+                                    isAnchor);
+}
+
+LogicalResult FormatParser::verifyOptionalChildElement(Element *element,
+                                                       llvm::SMLoc childLoc,
+                                                       bool isAnchor) {
+  return TypeSwitch<Element *, LogicalResult>(element)
       // All attributes can be within the optional group, but only optional
       // attributes can be the anchor.
       .Case([&](AttributeVariable *attrEle) {
@@ -2852,7 +2845,14 @@ LogicalResult FormatParser::parseOptionalChildElement(
         if (!ele->getVar()->isVariableLength())
           return emitError(childLoc, "only variable length operands can be "
                                      "used within an optional group");
-        seenVariables.insert(ele->getVar());
+        return ::mlir::success();
+      })
+      // Only optional-like (i.e. variadic) results can be within an optional
+      // group.
+      .Case<ResultVariable>([&](ResultVariable *ele) {
+        if (!ele->getVar()->isVariableLength())
+          return emitError(childLoc, "only variable length results can be "
+                                     "used within an optional group");
         return ::mlir::success();
       })
       .Case<RegionVariable>([&](RegionVariable *) {
@@ -2860,16 +2860,27 @@ LogicalResult FormatParser::parseOptionalChildElement(
         // a check here.
         return ::mlir::success();
       })
-      // Literals, whitespace, custom directives, and type directives may be
-      // used, but they can't anchor the group.
-      .Case<LiteralElement, WhitespaceElement, CustomDirective,
-            FunctionalTypeDirective, OptionalElement, TypeRefDirective,
-            TypeDirective>([&](Element *) {
-        if (isAnchor)
-          return emitError(childLoc, "only variables can be used to anchor "
-                                     "an optional group");
-        return ::mlir::success();
+      .Case<TypeDirective>([&](TypeDirective *ele) {
+        return verifyOptionalChildElement(ele->getOperand(), childLoc,
+                                          /*isAnchor=*/false);
       })
+      .Case<FunctionalTypeDirective>([&](FunctionalTypeDirective *ele) {
+        if (failed(verifyOptionalChildElement(ele->getInputs(), childLoc,
+                                              /*isAnchor=*/false)))
+          return failure();
+        return verifyOptionalChildElement(ele->getResults(), childLoc,
+                                          /*isAnchor=*/false);
+      })
+      // Literals, whitespace, and custom directives may be used, but they can't
+      // anchor the group.
+      .Case<LiteralElement, WhitespaceElement, CustomDirective,
+            FunctionalTypeDirective, OptionalElement, TypeRefDirective>(
+          [&](Element *) {
+            if (isAnchor)
+              return emitError(childLoc, "only variables and types can be used "
+                                         "to anchor an optional group");
+            return ::mlir::success();
+          })
       .Default([&](Element *) {
         return emitError(childLoc, "only literals, types, and variables can be "
                                    "used within an optional group");


        


More information about the llvm-branch-commits mailing list