[Mlir-commits] [mlir] 60d2769 - [mlir][ods] OpFormat: ensure that regions don't follow `attr-dict`
Jeff Niu
llvmlistbot at llvm.org
Fri Aug 12 18:00:30 PDT 2022
Author: Jeff Niu
Date: 2022-08-12T21:00:25-04:00
New Revision: 60d276923902051192eba692e5312e605c9d9f65
URL: https://github.com/llvm/llvm-project/commit/60d276923902051192eba692e5312e605c9d9f65
DIFF: https://github.com/llvm/llvm-project/commit/60d276923902051192eba692e5312e605c9d9f65.diff
LOG: [mlir][ods] OpFormat: ensure that regions don't follow `attr-dict`
An optional attribute dictionary before a region in an assembly format
is a potential format ambiguity because they both start with `{`.
Fixes #53077
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D131636
Added:
Modified:
mlir/test/mlir-tblgen/op-format-invalid.td
mlir/tools/mlir-tblgen/OpFormatGen.cpp
Removed:
################################################################################
diff --git a/mlir/test/mlir-tblgen/op-format-invalid.td b/mlir/test/mlir-tblgen/op-format-invalid.td
index aae9ed45b92d..b790360d6df8 100644
--- a/mlir/test/mlir-tblgen/op-format-invalid.td
+++ b/mlir/test/mlir-tblgen/op-format-invalid.td
@@ -208,6 +208,13 @@ def DirectiveRegionsInvalidB : TestFormat_Op<[{
def DirectiveRegionsInvalidC : TestFormat_Op<[{
type(regions)
}]>;
+// CHECK: error: format ambiguity caused by `attr-dict` directive followed by region `foo`
+// CHECK: note: try using `attr-dict-with-keyword` instead
+def DirectiveRegionsInvalidD : TestFormat_Op<[{
+ attr-dict $foo
+}]> {
+ let regions = (region AnyRegion:$foo);
+}
//===----------------------------------------------------------------------===//
// results
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 1e03bad9bd47..c1bb87712323 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -2213,6 +2213,9 @@ class OpFormatParser : public FormatParser {
/// Verify that attributes elements aren't followed by colon literals.
LogicalResult verifyAttributeColonType(SMLoc loc,
ArrayRef<FormatElement *> elements);
+ /// Verify that the attribute dictionary directive isn't followed by a region.
+ LogicalResult verifyAttrDictRegion(SMLoc loc,
+ ArrayRef<FormatElement *> elements);
/// Verify the state of operation operands within the format.
LogicalResult
@@ -2349,6 +2352,11 @@ OpFormatParser::verifyAttributes(SMLoc loc,
// better to just error out here instead.
if (failed(verifyAttributeColonType(loc, elements)))
return failure();
+ // Check that there are no region variables following an attribute dictionary.
+ // Both start with `{` and so the optional attribute dictionary can cause
+ // format ambiguities.
+ if (failed(verifyAttrDictRegion(loc, elements)))
+ return failure();
// Check for VariadicOfVariadic variables. The segment attribute of those
// variables will be infered.
@@ -2380,48 +2388,46 @@ static bool isOptionallyParsed(FormatElement *el) {
return isa<WhitespaceElement, AttrDictDirective>(el);
}
-/// Scan the given range of elements from the start for a colon literal,
-/// skipping any optionally-parsed elements. If an optional group is
-/// encountered, this function recurses into the 'then' and 'else' elements to
-/// check if they are invalid. Returns `success` if the range is known to be
-/// valid or `None` if scanning reached the end.
+/// Scan the given range of elements from the start for an invalid format
+/// element that satisfies `isInvalid`, skipping any optionally-parsed elements.
+/// If an optional group is encountered, recurse into its 'then' and 'else'
+/// elements. Returns `failure` if an invalid element is found, `success` if
+/// the range is known to be valid, or `None` if scanning reached the end.
///
/// Since the guard element of an optional group is required, this function
/// accepts an optional element pointer to mark it as required.
-static Optional<LogicalResult> checkElementRangeForColon(
- function_ref<LogicalResult(const Twine &)> emitError, StringRef attrName,
+static Optional<LogicalResult> checkRangeForElement(
+ FormatElement *base,
+ function_ref<bool(FormatElement *, FormatElement *)> isInvalid,
iterator_range<ArrayRef<FormatElement *>::iterator> elementRange,
FormatElement *optionalGuard = nullptr) {
for (FormatElement *element : elementRange) {
- // Skip optionally parsed elements.
- if (element != optionalGuard && isOptionallyParsed(element))
- continue;
+ // If we encounter an invalid element, return an error.
+ if (isInvalid(base, element))
+ return failure();
// Recurse on optional groups.
if (auto *optional = dyn_cast<OptionalElement>(element)) {
- if (Optional<LogicalResult> result = checkElementRangeForColon(
- emitError, attrName, optional->getThenElements(),
+ if (Optional<LogicalResult> result = checkRangeForElement(
+ base, isInvalid, optional->getThenElements(),
// The optional group guard is required for the group.
optional->getThenElements().front()))
if (failed(*result))
return failure();
- if (Optional<LogicalResult> result = checkElementRangeForColon(
- emitError, attrName, optional->getElseElements()))
+ if (Optional<LogicalResult> result = checkRangeForElement(
+ base, isInvalid, optional->getElseElements()))
if (failed(*result))
return failure();
// Skip the optional group.
continue;
}
- // If we encounter anything other than `:`, this range is range.
- auto *literal = dyn_cast<LiteralElement>(element);
- if (!literal || literal->getSpelling() != ":")
- return success();
- // If we encounter `:`, the range is known to be invalid.
- return emitError(
- llvm::formatv("format ambiguity caused by `:` literal found after "
- "attribute `{0}` which does not have a buildable type",
- attrName));
+ // Skip optionally parsed elements.
+ if (element != optionalGuard && isOptionallyParsed(element))
+ continue;
+
+ // We found a closing element that is valid.
+ return success();
}
// Return None to indicate that we reached the end.
return llvm::None;
@@ -2431,46 +2437,42 @@ static Optional<LogicalResult> checkElementRangeForColon(
/// literal, resulting in an ambiguous assembly format. Returns a non-null
/// attribute if verification of said attribute reached the end of the range.
/// Returns null if all attribute elements are verified.
-static FailureOr<AttributeVariable *>
-verifyAttributeColon(function_ref<LogicalResult(const Twine &)> emitError,
- ArrayRef<FormatElement *> elements) {
+static FailureOr<FormatElement *> verifyAdjacentElements(
+ function_ref<bool(FormatElement *)> isBase,
+ function_ref<bool(FormatElement *, FormatElement *)> isInvalid,
+ ArrayRef<FormatElement *> elements) {
for (auto *it = elements.begin(), *e = elements.end(); it != e; ++it) {
// The current attribute being verified.
- AttributeVariable *attr = nullptr;
-
- if ((attr = dyn_cast<AttributeVariable>(*it))) {
- // Check only attributes without type builders or that are known to call
- // the generic attribute parser.
- if (attr->getTypeBuilder() ||
- !(attr->shouldBeQualified() ||
- attr->getVar()->attr.getStorageType() == "::mlir::Attribute"))
- continue;
+ FormatElement *base;
+
+ if (isBase(*it)) {
+ base = *it;
} else if (auto *optional = dyn_cast<OptionalElement>(*it)) {
// Recurse on optional groups.
- FailureOr<AttributeVariable *> thenResult =
- verifyAttributeColon(emitError, optional->getThenElements());
+ FailureOr<FormatElement *> thenResult = verifyAdjacentElements(
+ isBase, isInvalid, optional->getThenElements());
if (failed(thenResult))
return failure();
- FailureOr<AttributeVariable *> elseResult =
- verifyAttributeColon(emitError, optional->getElseElements());
+ FailureOr<FormatElement *> elseResult = verifyAdjacentElements(
+ isBase, isInvalid, optional->getElseElements());
if (failed(elseResult))
return failure();
// If either optional group has an unverified attribute, save it.
// Otherwise, move on to the next element.
- if (!(attr = *thenResult) && !(attr = *elseResult))
+ if (!(base = *thenResult) && !(base = *elseResult))
continue;
} else {
continue;
}
// Verify subsequent elements for potential ambiguities.
- if (Optional<LogicalResult> result = checkElementRangeForColon(
- emitError, attr->getVar()->name, {std::next(it), e})) {
+ if (Optional<LogicalResult> result =
+ checkRangeForElement(base, isInvalid, {std::next(it), e})) {
if (failed(*result))
return failure();
} else {
// Since we reached the end, return the attribute as unverified.
- return attr;
+ return base;
}
}
// All attribute elements are known to be verified.
@@ -2480,8 +2482,52 @@ verifyAttributeColon(function_ref<LogicalResult(const Twine &)> emitError,
LogicalResult
OpFormatParser::verifyAttributeColonType(SMLoc loc,
ArrayRef<FormatElement *> elements) {
- return verifyAttributeColon(
- [&](const Twine &msg) { return emitError(loc, msg); }, elements);
+ auto isBase = [](FormatElement *el) {
+ auto attr = dyn_cast<AttributeVariable>(el);
+ if (!attr)
+ return false;
+ // Check only attributes without type builders or that are known to call
+ // the generic attribute parser.
+ return !attr->getTypeBuilder() &&
+ (attr->shouldBeQualified() ||
+ attr->getVar()->attr.getStorageType() == "::mlir::Attribute");
+ };
+ auto isInvalid = [&](FormatElement *base, FormatElement *el) {
+ auto *literal = dyn_cast<LiteralElement>(el);
+ if (!literal || literal->getSpelling() != ":")
+ return false;
+ // If we encounter `:`, the range is known to be invalid.
+ (void)emitError(
+ loc,
+ llvm::formatv("format ambiguity caused by `:` literal found after "
+ "attribute `{0}` which does not have a buildable type",
+ cast<AttributeVariable>(base)->getVar()->name));
+ return true;
+ };
+ return verifyAdjacentElements(isBase, isInvalid, elements);
+}
+
+LogicalResult
+OpFormatParser::verifyAttrDictRegion(SMLoc loc,
+ ArrayRef<FormatElement *> elements) {
+ auto isBase = [](FormatElement *el) {
+ if (auto *attrDict = dyn_cast<AttrDictDirective>(el))
+ return !attrDict->isWithKeyword();
+ return false;
+ };
+ auto isInvalid = [&](FormatElement *base, FormatElement *el) {
+ auto *region = dyn_cast<RegionVariable>(el);
+ if (!region)
+ return false;
+ (void)emitErrorAndNote(
+ loc,
+ llvm::formatv("format ambiguity caused by `attr-dict` directive "
+ "followed by region `{0}`",
+ region->getVar()->name),
+ "try using `attr-dict-with-keyword` instead");
+ return true;
+ };
+ return verifyAdjacentElements(isBase, isInvalid, elements);
}
LogicalResult OpFormatParser::verifyOperands(
More information about the Mlir-commits
mailing list