[Mlir-commits] [mlir] 60d2769 - [mlir][ods] OpFormat: ensure that regions don't follow `attr-dict`

Jeff Niu llvmlistbot at llvm.org
Fri Aug 12 18:00:30 PDT 2022


Author: Jeff Niu
Date: 2022-08-12T21:00:25-04:00
New Revision: 60d276923902051192eba692e5312e605c9d9f65

URL: https://github.com/llvm/llvm-project/commit/60d276923902051192eba692e5312e605c9d9f65
DIFF: https://github.com/llvm/llvm-project/commit/60d276923902051192eba692e5312e605c9d9f65.diff

LOG: [mlir][ods] OpFormat: ensure that regions don't follow `attr-dict`

An optional attribute dictionary before a region in an assembly format
is a potential format ambiguity because they both start with `{`.

Fixes #53077

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D131636

Added: 
    

Modified: 
    mlir/test/mlir-tblgen/op-format-invalid.td
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/test/mlir-tblgen/op-format-invalid.td b/mlir/test/mlir-tblgen/op-format-invalid.td
index aae9ed45b92d..b790360d6df8 100644
--- a/mlir/test/mlir-tblgen/op-format-invalid.td
+++ b/mlir/test/mlir-tblgen/op-format-invalid.td
@@ -208,6 +208,13 @@ def DirectiveRegionsInvalidB : TestFormat_Op<[{
 def DirectiveRegionsInvalidC : TestFormat_Op<[{
   type(regions)
 }]>;
+// CHECK: error: format ambiguity caused by `attr-dict` directive followed by region `foo`
+// CHECK: note: try using `attr-dict-with-keyword` instead
+def DirectiveRegionsInvalidD : TestFormat_Op<[{
+  attr-dict $foo
+}]> {
+  let regions = (region AnyRegion:$foo);
+}
 
 //===----------------------------------------------------------------------===//
 // results

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 1e03bad9bd47..c1bb87712323 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -2213,6 +2213,9 @@ class OpFormatParser : public FormatParser {
   /// Verify that attributes elements aren't followed by colon literals.
   LogicalResult verifyAttributeColonType(SMLoc loc,
                                          ArrayRef<FormatElement *> elements);
+  /// Verify that the attribute dictionary directive isn't followed by a region.
+  LogicalResult verifyAttrDictRegion(SMLoc loc,
+                                     ArrayRef<FormatElement *> elements);
 
   /// Verify the state of operation operands within the format.
   LogicalResult
@@ -2349,6 +2352,11 @@ OpFormatParser::verifyAttributes(SMLoc loc,
   // better to just error out here instead.
   if (failed(verifyAttributeColonType(loc, elements)))
     return failure();
+  // Check that there are no region variables following an attribute dictionary.
+  // Both start with `{` and so the optional attribute dictionary can cause
+  // format ambiguities.
+  if (failed(verifyAttrDictRegion(loc, elements)))
+    return failure();
 
   // Check for VariadicOfVariadic variables. The segment attribute of those
   // variables will be infered.
@@ -2380,48 +2388,46 @@ static bool isOptionallyParsed(FormatElement *el) {
   return isa<WhitespaceElement, AttrDictDirective>(el);
 }
 
-/// Scan the given range of elements from the start for a colon literal,
-/// skipping any optionally-parsed elements. If an optional group is
-/// encountered, this function recurses into the 'then' and 'else' elements to
-/// check if they are invalid. Returns `success` if the range is known to be
-/// valid or `None` if scanning reached the end.
+/// Scan `elementRange` from the start for an element satisfying `isInvalid`,
+/// skipping optionally-parsed elements and recursing into the 'then' and
+/// 'else' elements of optional groups. Returns `failure` if an invalid element
+/// is found, `success` if a required element is reached first (the range is
+/// known to be valid), or `None` if scanning reached the end of the range.
 ///
 /// Since the guard element of an optional group is required, this function
 /// accepts an optional element pointer to mark it as required.
-static Optional<LogicalResult> checkElementRangeForColon(
-    function_ref<LogicalResult(const Twine &)> emitError, StringRef attrName,
+static Optional<LogicalResult> checkRangeForElement(
+    FormatElement *base,
+    function_ref<bool(FormatElement *, FormatElement *)> isInvalid,
     iterator_range<ArrayRef<FormatElement *>::iterator> elementRange,
     FormatElement *optionalGuard = nullptr) {
   for (FormatElement *element : elementRange) {
-    // Skip optionally parsed elements.
-    if (element != optionalGuard && isOptionallyParsed(element))
-      continue;
+    // If we encounter an invalid element, return an error.
+    if (isInvalid(base, element))
+      return failure();
 
     // Recurse on optional groups.
     if (auto *optional = dyn_cast<OptionalElement>(element)) {
-      if (Optional<LogicalResult> result = checkElementRangeForColon(
-              emitError, attrName, optional->getThenElements(),
+      if (Optional<LogicalResult> result = checkRangeForElement(
+              base, isInvalid, optional->getThenElements(),
               // The optional group guard is required for the group.
               optional->getThenElements().front()))
         if (failed(*result))
           return failure();
-      if (Optional<LogicalResult> result = checkElementRangeForColon(
-              emitError, attrName, optional->getElseElements()))
+      if (Optional<LogicalResult> result = checkRangeForElement(
+              base, isInvalid, optional->getElseElements()))
         if (failed(*result))
           return failure();
       // Skip the optional group.
       continue;
     }
 
-    // If we encounter anything other than `:`, this range is range.
-    auto *literal = dyn_cast<LiteralElement>(element);
-    if (!literal || literal->getSpelling() != ":")
-      return success();
-    // If we encounter `:`, the range is known to be invalid.
-    return emitError(
-        llvm::formatv("format ambiguity caused by `:` literal found after "
-                      "attribute `{0}` which does not have a buildable type",
-                      attrName));
+    // Skip optionally parsed elements.
+    if (element != optionalGuard && isOptionallyParsed(element))
+      continue;
+
+    // We found a closing element that is valid.
+    return success();
   }
   // Return None to indicate that we reached the end.
   return llvm::None;
@@ -2431,46 +2437,42 @@ static Optional<LogicalResult> checkElementRangeForColon(
 /// literal, resulting in an ambiguous assembly format. Returns a non-null
 /// attribute if verification of said attribute reached the end of the range.
 /// Returns null if all attribute elements are verified.
-static FailureOr<AttributeVariable *>
-verifyAttributeColon(function_ref<LogicalResult(const Twine &)> emitError,
-                     ArrayRef<FormatElement *> elements) {
+static FailureOr<FormatElement *> verifyAdjacentElements(
+    function_ref<bool(FormatElement *)> isBase,
+    function_ref<bool(FormatElement *, FormatElement *)> isInvalid,
+    ArrayRef<FormatElement *> elements) {
   for (auto *it = elements.begin(), *e = elements.end(); it != e; ++it) {
     // The current attribute being verified.
-    AttributeVariable *attr = nullptr;
-
-    if ((attr = dyn_cast<AttributeVariable>(*it))) {
-      // Check only attributes without type builders or that are known to call
-      // the generic attribute parser.
-      if (attr->getTypeBuilder() ||
-          !(attr->shouldBeQualified() ||
-            attr->getVar()->attr.getStorageType() == "::mlir::Attribute"))
-        continue;
+    FormatElement *base;
+
+    if (isBase(*it)) {
+      base = *it;
     } else if (auto *optional = dyn_cast<OptionalElement>(*it)) {
       // Recurse on optional groups.
-      FailureOr<AttributeVariable *> thenResult =
-          verifyAttributeColon(emitError, optional->getThenElements());
+      FailureOr<FormatElement *> thenResult = verifyAdjacentElements(
+          isBase, isInvalid, optional->getThenElements());
       if (failed(thenResult))
         return failure();
-      FailureOr<AttributeVariable *> elseResult =
-          verifyAttributeColon(emitError, optional->getElseElements());
+      FailureOr<FormatElement *> elseResult = verifyAdjacentElements(
+          isBase, isInvalid, optional->getElseElements());
       if (failed(elseResult))
         return failure();
       // If either optional group has an unverified attribute, save it.
       // Otherwise, move on to the next element.
-      if (!(attr = *thenResult) && !(attr = *elseResult))
+      if (!(base = *thenResult) && !(base = *elseResult))
         continue;
     } else {
       continue;
     }
 
     // Verify subsequent elements for potential ambiguities.
-    if (Optional<LogicalResult> result = checkElementRangeForColon(
-            emitError, attr->getVar()->name, {std::next(it), e})) {
+    if (Optional<LogicalResult> result =
+            checkRangeForElement(base, isInvalid, {std::next(it), e})) {
       if (failed(*result))
         return failure();
     } else {
       // Since we reached the end, return the attribute as unverified.
-      return attr;
+      return base;
     }
   }
   // All attribute elements are known to be verified.
@@ -2480,8 +2482,52 @@ verifyAttributeColon(function_ref<LogicalResult(const Twine &)> emitError,
 LogicalResult
 OpFormatParser::verifyAttributeColonType(SMLoc loc,
                                          ArrayRef<FormatElement *> elements) {
-  return verifyAttributeColon(
-      [&](const Twine &msg) { return emitError(loc, msg); }, elements);
+  auto isBase = [](FormatElement *el) {
+    auto attr = dyn_cast<AttributeVariable>(el);
+    if (!attr)
+      return false;
+    // Check only attributes without type builders or that are known to call
+    // the generic attribute parser.
+    return !attr->getTypeBuilder() &&
+           (attr->shouldBeQualified() ||
+            attr->getVar()->attr.getStorageType() == "::mlir::Attribute");
+  };
+  auto isInvalid = [&](FormatElement *base, FormatElement *el) {
+    auto *literal = dyn_cast<LiteralElement>(el);
+    if (!literal || literal->getSpelling() != ":")
+      return false;
+    // If we encounter `:`, the range is known to be invalid.
+    (void)emitError(
+        loc,
+        llvm::formatv("format ambiguity caused by `:` literal found after "
+                      "attribute `{0}` which does not have a buildable type",
+                      cast<AttributeVariable>(base)->getVar()->name));
+    return true;
+  };
+  return verifyAdjacentElements(isBase, isInvalid, elements);
+}
+
+LogicalResult
+OpFormatParser::verifyAttrDictRegion(SMLoc loc,
+                                     ArrayRef<FormatElement *> elements) {
+  auto isBase = [](FormatElement *el) {
+    if (auto *attrDict = dyn_cast<AttrDictDirective>(el))
+      return !attrDict->isWithKeyword();
+    return false;
+  };
+  auto isInvalid = [&](FormatElement *base, FormatElement *el) {
+    auto *region = dyn_cast<RegionVariable>(el);
+    if (!region)
+      return false;
+    (void)emitErrorAndNote(
+        loc,
+        llvm::formatv("format ambiguity caused by `attr-dict` directive "
+                      "followed by region `{0}`",
+                      region->getVar()->name),
+        "try using `attr-dict-with-keyword` instead");
+    return true;
+  };
+  return verifyAdjacentElements(isBase, isInvalid, elements);
 }
 
 LogicalResult OpFormatParser::verifyOperands(


        


More information about the Mlir-commits mailing list