[Mlir-commits] [mlir] e3bb363 - [mlir][DeclarativeParser] Emit an error if a `:` follows an attribute with a non-constant type.
River Riddle
llvmlistbot at llvm.org
Fri Apr 3 19:25:09 PDT 2020
Author: River Riddle
Date: 2020-04-03T19:23:56-07:00
New Revision: e3bb36370d59ce1bf4c22d2258e830999d8e833d
URL: https://github.com/llvm/llvm-project/commit/e3bb36370d59ce1bf4c22d2258e830999d8e833d
DIFF: https://github.com/llvm/llvm-project/commit/e3bb36370d59ce1bf4c22d2258e830999d8e833d.diff
LOG: [mlir][DeclarativeParser] Emit an error if a `:` follows an attribute with a non-constant type.
Summary: The attribute grammar includes an optional trailing colon type. As a result, a `:` literal that follows an attribute without a constant buildable type is parsed as the start of that attribute's type, which generally leads to unexpected and undesired behavior. Given that, it's better to just error out in these cases.
Differential Revision: https://reviews.llvm.org/D77293
Added:
Modified:
mlir/include/mlir/Dialect/Affine/IR/AffineOpsBase.td
mlir/include/mlir/IR/OpBase.td
mlir/test/IR/attribute.mlir
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/op-format-spec.td
mlir/tools/mlir-tblgen/OpFormatGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOpsBase.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOpsBase.td
index 60613e4e39f8..2883072d4aa9 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOpsBase.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOpsBase.td
@@ -20,6 +20,7 @@ def AffineMapAttr : Attr<
CPred<"$_self.isa<AffineMapAttr>()">, "AffineMap attribute"> {
let storageType = [{ AffineMapAttr }];
let returnType = [{ AffineMap }];
+ let valueType = Index;
let constBuilderCall = "AffineMapAttr::get($0)";
}
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 845fa93878a1..38a402dfe9dc 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -319,7 +319,8 @@ class BuildableType<code builder> {
def AnyType : Type<CPred<"true">, "any type">;
// None type
-def NoneType : Type<CPred<"$_self.isa<NoneType>()">, "none type">;
+def NoneType : Type<CPred<"$_self.isa<NoneType>()">, "none type">,
+ BuildableType<"$_builder.getType<NoneType>()">;
// Any type from the given list
class AnyTypeOf<list<Type> allowedTypes, string description = ""> : Type<
@@ -835,6 +836,7 @@ def AnyAttr : Attr<CPred<"true">, "any attribute"> {
def BoolAttr : Attr<CPred<"$_self.isa<BoolAttr>()">, "bool attribute"> {
let storageType = [{ BoolAttr }];
let returnType = [{ bool }];
+ let valueType = I1;
let constBuilderCall = "$_builder.getBoolAttr($0)";
}
@@ -942,11 +944,18 @@ class StringBasedAttr<Pred condition, string descr> : Attr<condition, descr> {
let constBuilderCall = "$_builder.getStringAttr(\"$0\")";
let storageType = [{ StringAttr }];
let returnType = [{ StringRef }];
+ let valueType = NoneType;
}
def StrAttr : StringBasedAttr<CPred<"$_self.isa<StringAttr>()">,
"string attribute">;
+// String attribute that has a specific value type.
+class TypedStrAttr<Type ty> : StringBasedAttr<CPred<"$_self.isa<StringAttr>()">,
+ "string attribute"> {
+ let valueType = ty;
+}
+
// Base class for attributes containing types. Example:
// def IntTypeAttr : TypeAttrBase<"IntegerType", "integer type attribute">
// defines a type attribute containing an integer type.
@@ -957,6 +966,7 @@ class TypeAttrBase<string retType, string description> :
description> {
let storageType = [{ TypeAttr }];
let returnType = retType;
+ let valueType = NoneType;
let convertFromStorage = "$_self.getValue().cast<" # retType # ">()";
}
@@ -970,6 +980,7 @@ def UnitAttr : Attr<CPred<"$_self.isa<UnitAttr>()">, "unit attribute"> {
let constBuilderCall = "$_builder.getUnitAttr()";
let convertFromStorage = "$_self != nullptr";
let returnType = "bool";
+ let valueType = NoneType;
let isOptional = 1;
}
@@ -1166,6 +1177,7 @@ class DictionaryAttrBase : Attr<CPred<"$_self.isa<DictionaryAttr>()">,
"dictionary of named attribute values"> {
let storageType = [{ DictionaryAttr }];
let returnType = [{ DictionaryAttr }];
+ let valueType = NoneType;
let convertFromStorage = "$_self";
}
@@ -1285,6 +1297,7 @@ class ArrayAttrBase<Pred condition, string description> :
Attr<condition, description> {
let storageType = [{ ArrayAttr }];
let returnType = [{ ArrayAttr }];
+ let valueType = NoneType;
let convertFromStorage = "$_self";
}
@@ -1364,6 +1377,7 @@ def SymbolRefAttr : Attr<CPred<"$_self.isa<SymbolRefAttr>()">,
"symbol reference attribute"> {
let storageType = [{ SymbolRefAttr }];
let returnType = [{ SymbolRefAttr }];
+ let valueType = NoneType;
let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
let convertFromStorage = "$_self";
}
@@ -1371,6 +1385,7 @@ def FlatSymbolRefAttr : Attr<CPred<"$_self.isa<FlatSymbolRefAttr>()">,
"flat symbol reference attribute"> {
let storageType = [{ FlatSymbolRefAttr }];
let returnType = [{ StringRef }];
+ let valueType = NoneType;
let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
let convertFromStorage = "$_self.getValue()";
}
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index a5133abfda2c..3e4f0db65942 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -247,7 +247,7 @@ func @non_type_in_type_array_attr_fail() {
// CHECK-LABEL: func @string_attr_custom_type
func @string_attr_custom_type() {
// CHECK: "string_data" : !foo.string
- test.string_attr_with_type "string_data"
+ test.string_attr_with_type "string_data" : !foo.string
return
}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 80783001a4c9..061960959cd8 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -158,15 +158,8 @@ def TypeArrayAttrOp : TEST_Op<"type_array_attr"> {
let arguments = (ins TypeArrayAttr:$attr);
}
def TypeStringAttrWithTypeOp : TEST_Op<"string_attr_with_type"> {
- let arguments = (ins StrAttr:$attr);
- let printer = [{ p << getAttr("attr"); }];
- let parser = [{
- Attribute attr;
- Type stringType = OpaqueType::get(Identifier::get("foo",
- result.getContext()), "string",
- result.getContext());
- return parser.parseAttribute(attr, stringType, "attr", result.attributes);
- }];
+ let arguments = (ins TypedStrAttr<AnyType>:$attr);
+ let assemblyFormat = "$attr attr-dict";
}
def StrCaseA: StrEnumAttrCase<"A">;
diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
index ac5aa259a907..5c3e344f68d7 100644
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -1,4 +1,4 @@
-// RUN: mlir-tblgen -gen-op-decls -asmformat-error-is-fatal=false -I %S/../../include %s 2>&1 | FileCheck %s --dump-input-on-failure
+// RUN: mlir-tblgen -gen-op-decls -asmformat-error-is-fatal=false -I %S/../../include %s -o=%t 2>&1 | FileCheck %s --dump-input-on-failure
// This file contains tests for the specification of the declarative op format.
@@ -275,6 +275,21 @@ def VariableInvalidG : TestFormat_Op<"variable_invalid_g", [{
}]> {
let successors = (successor AnySuccessor:$successor);
}
+// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` which does not have a buildable type
+def VariableInvalidH : TestFormat_Op<"variable_invalid_h", [{
+ $attr `:` attr-dict
+}]>, Arguments<(ins ElementsAttr:$attr)>;
+// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` which does not have a buildable type
+def VariableInvalidI : TestFormat_Op<"variable_invalid_i", [{
+ (`foo` $attr^)? `:` attr-dict
+}]>, Arguments<(ins OptionalAttr<ElementsAttr>:$attr)>;
+// CHECK-NOT: error:
+def VariableInvalidJ : TestFormat_Op<"variable_invalid_j", [{
+ $attr `:` attr-dict
+}]>, Arguments<(ins OptionalAttr<I1Attr>:$attr)>;
+def VariableInvalidK : TestFormat_Op<"variable_invalid_k", [{
+ (`foo` $attr^)? `:` attr-dict
+}]>, Arguments<(ins OptionalAttr<I1Attr>:$attr)>;
//===----------------------------------------------------------------------===//
// Coverage Checks
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index b9303979d3f4..965a6961b16e 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -92,13 +92,23 @@ class VariableElement : public Element {
}
const VarT *getVar() { return var; }
-private:
+protected:
const VarT *var;
};
/// This class represents a variable that refers to an attribute argument.
-using AttributeVariable =
- VariableElement<NamedAttribute, Element::Kind::AttributeVariable>;
+struct AttributeVariable
+ : public VariableElement<NamedAttribute, Element::Kind::AttributeVariable> {
+ using VariableElement<NamedAttribute,
+ Element::Kind::AttributeVariable>::VariableElement;
+
+ /// Return the constant builder call for the type of this attribute, or None
+ /// if it doesn't have one.
+ Optional<StringRef> getTypeBuilder() const {
+ Optional<Type> attrType = var->attr.getValueType();
+ return attrType ? attrType->getBuilderCall() : llvm::None;
+ }
+};
/// This class represents a variable that refers to an operand argument.
using OperandVariable =
@@ -574,11 +584,9 @@ static void genElementParser(Element *element, OpMethodBody &body,
// If this attribute has a buildable type, use that when parsing the
// attribute.
std::string attrTypeStr;
- if (Optional<Type> attrType = var->attr.getValueType()) {
- if (Optional<StringRef> typeBuilder = attrType->getBuilderCall()) {
- llvm::raw_string_ostream os(attrTypeStr);
- os << ", " << tgfmt(*typeBuilder, &attrTypeCtx);
- }
+ if (Optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
+ llvm::raw_string_ostream os(attrTypeStr);
+ os << ", " << tgfmt(*typeBuilder, &attrTypeCtx);
}
body << formatv(attrParserCode, var->attr.getStorageType(), var->name,
@@ -932,8 +940,7 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
}
// Elide the attribute type if it is buildable.
- Optional<Type> attrType = var->attr.getValueType();
- if (attrType && attrType->getBuilderCall())
+ if (attr->getTypeBuilder())
body << " p.printAttributeWithoutType(" << var->name << "Attr());\n";
else
body << " p.printAttribute(" << var->name << "Attr());\n";
@@ -1234,6 +1241,22 @@ class FormatParser {
Optional<StringRef> transformer;
};
+ /// Verify the state of operation attributes within the format.
+ LogicalResult verifyAttributes(llvm::SMLoc loc);
+
+ /// Verify the state of operation operands within the format.
+ LogicalResult
+ verifyOperands(llvm::SMLoc loc,
+ llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
+
+ /// Verify the state of operation results within the format.
+ LogicalResult
+ verifyResults(llvm::SMLoc loc,
+ llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
+
+ /// Verify the state of operation successors within the format.
+ LogicalResult verifySuccessors(llvm::SMLoc loc);
+
/// Given the values of an `AllTypesMatch` trait, check for inferable type
/// resolution.
void handleAllTypesMatchConstraint(
@@ -1357,37 +1380,86 @@ LogicalResult FormatParser::parse() {
}
}
- // Check that all of the result types can be inferred.
- auto &buildableTypes = fmt.buildableTypes;
- if (!fmt.allResultTypes) {
- for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
- if (seenResultTypes.test(i))
- continue;
+ // Verify the state of the various operation components.
+ if (failed(verifyAttributes(loc)) ||
+ failed(verifyResults(loc, variableTyResolver)) ||
+ failed(verifyOperands(loc, variableTyResolver)) ||
+ failed(verifySuccessors(loc)))
+ return failure();
- // Check to see if we can infer this type from another variable.
- auto varResolverIt = variableTyResolver.find(op.getResultName(i));
- if (varResolverIt != variableTyResolver.end()) {
- fmt.resultTypes[i].setVariable(varResolverIt->second.type,
- varResolverIt->second.transformer);
- continue;
+ // Check to see if we are formatting all of the operands.
+ fmt.allOperands = llvm::any_of(fmt.elements, [](auto &elt) {
+ return isa<OperandsDirective>(elt.get());
+ });
+ return success();
+}
+
+LogicalResult FormatParser::verifyAttributes(llvm::SMLoc loc) {
+ // Check that there are no `:` literals after an attribute without a constant
+ // type. The attribute grammar contains an optional trailing colon type, which
+ // can lead to unexpected and generally unintended behavior. Given that, it is
+ // better to just error out here instead.
+ using ElementsIterT = llvm::pointee_iterator<
+ std::vector<std::unique_ptr<Element>>::const_iterator>;
+ SmallVector<std::pair<ElementsIterT, ElementsIterT>, 1> iteratorStack;
+ iteratorStack.emplace_back(fmt.elements.begin(), fmt.elements.end());
+ while (!iteratorStack.empty()) {
+ auto &stackIt = iteratorStack.back();
+ ElementsIterT &it = stackIt.first, e = stackIt.second;
+ while (it != e) {
+ Element *element = &*(it++);
+
+ // Traverse into optional groups.
+ if (auto *optional = dyn_cast<OptionalElement>(element)) {
+ auto elements = optional->getElements();
+ iteratorStack.emplace_back(elements.begin(), elements.end());
+ break;
}
- // If the result is not variadic, allow for the case where the type has a
- // builder that we can use.
- NamedTypeConstraint &result = op.getResult(i);
- Optional<StringRef> builder = result.constraint.getBuilderCall();
- if (!builder || result.constraint.isVariadic()) {
- return emitError(loc, "format missing instance of result #" + Twine(i) +
- "('" + result.name + "') type");
+ // We are checking for an attribute element followed by a `:`, so there is
+ // no need to check the end.
+ if (it == e && iteratorStack.size() == 1)
+ break;
+
+ // Check for an attribute with a constant type builder, followed by a `:`.
+ auto *prevAttr = dyn_cast<AttributeVariable>(element);
+ if (!prevAttr || prevAttr->getTypeBuilder())
+ continue;
+
+ // Check the next iterator within the stack for literal elements.
+ for (auto &nextItPair : iteratorStack) {
+ ElementsIterT nextIt = nextItPair.first, nextE = nextItPair.second;
+ for (; nextIt != nextE; ++nextIt) {
+ // Skip any trailing optional groups or attribute dictionaries.
+ if (isa<AttrDictDirective>(*nextIt) || isa<OptionalElement>(*nextIt))
+ continue;
+
+ // We are only interested in `:` literals.
+ auto *literal = dyn_cast<LiteralElement>(&*nextIt);
+ if (!literal || literal->getLiteral() != ":")
+ break;
+
+ // TODO: Use the location of the literal element itself.
+ return emitError(
+ loc, llvm::formatv("format ambiguity caused by `:` literal found "
+ "after attribute `{0}` which does not have "
+ "a buildable type",
+ prevAttr->getVar()->name));
+ }
}
- // Note in the format that this result uses the custom builder.
- auto it = buildableTypes.insert({*builder, buildableTypes.size()});
- fmt.resultTypes[i].setBuilderIdx(it.first->second);
}
+ if (it == e)
+ iteratorStack.pop_back();
}
+ return success();
+}
+LogicalResult FormatParser::verifyOperands(
+ llvm::SMLoc loc,
+ llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
// Check that all of the operands are within the format, and their types can
// be inferred.
+ auto &buildableTypes = fmt.buildableTypes;
for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
NamedTypeConstraint &operand = op.getOperand(i);
@@ -1419,22 +1491,57 @@ LogicalResult FormatParser::parse() {
auto it = buildableTypes.insert({*builder, buildableTypes.size()});
fmt.operandTypes[i].setBuilderIdx(it.first->second);
}
+ return success();
+}
- // Check that all of the successors are within the format.
- if (!hasAllSuccessors) {
- for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) {
- const NamedSuccessor &successor = op.getSuccessor(i);
- if (!seenSuccessors.count(&successor)) {
- return emitError(loc, "format missing instance of successor #" +
- Twine(i) + "('" + successor.name + "')");
- }
+LogicalResult FormatParser::verifyResults(
+ llvm::SMLoc loc,
+ llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
+ // If we format all of the types together, there is nothing to check.
+ if (fmt.allResultTypes)
+ return success();
+
+ // Check that all of the result types can be inferred.
+ auto &buildableTypes = fmt.buildableTypes;
+ for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
+ if (seenResultTypes.test(i))
+ continue;
+
+ // Check to see if we can infer this type from another variable.
+ auto varResolverIt = variableTyResolver.find(op.getResultName(i));
+ if (varResolverIt != variableTyResolver.end()) {
+ fmt.resultTypes[i].setVariable(varResolverIt->second.type,
+ varResolverIt->second.transformer);
+ continue;
}
+
+ // If the result is not variadic, allow for the case where the type has a
+ // builder that we can use.
+ NamedTypeConstraint &result = op.getResult(i);
+ Optional<StringRef> builder = result.constraint.getBuilderCall();
+ if (!builder || result.constraint.isVariadic()) {
+ return emitError(loc, "format missing instance of result #" + Twine(i) +
+ "('" + result.name + "') type");
+ }
+ // Note in the format that this result uses the custom builder.
+ auto it = buildableTypes.insert({*builder, buildableTypes.size()});
+ fmt.resultTypes[i].setBuilderIdx(it.first->second);
}
+ return success();
+}
- // Check to see if we are formatting all of the operands.
- fmt.allOperands = llvm::any_of(fmt.elements, [](auto &elt) {
- return isa<OperandsDirective>(elt.get());
- });
+LogicalResult FormatParser::verifySuccessors(llvm::SMLoc loc) {
+ // Check that all of the successors are within the format.
+ if (hasAllSuccessors)
+ return success();
+
+ for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) {
+ const NamedSuccessor &successor = op.getSuccessor(i);
+ if (!seenSuccessors.count(&successor)) {
+ return emitError(loc, "format missing instance of successor #" +
+ Twine(i) + "('" + successor.name + "')");
+ }
+ }
return success();
}
More information about the Mlir-commits
mailing list