[Mlir-commits] [mlir] e3bb363 - [mlir][DeclarativeParser] Emit an error if a `:` follows an attribute with a non-constant type.

River Riddle llvmlistbot at llvm.org
Fri Apr 3 19:25:09 PDT 2020


Author: River Riddle
Date: 2020-04-03T19:23:56-07:00
New Revision: e3bb36370d59ce1bf4c22d2258e830999d8e833d

URL: https://github.com/llvm/llvm-project/commit/e3bb36370d59ce1bf4c22d2258e830999d8e833d
DIFF: https://github.com/llvm/llvm-project/commit/e3bb36370d59ce1bf4c22d2258e830999d8e833d.diff

LOG: [mlir][DeclarativeParser] Emit an error if a `:` follows an attribute with a non-constant type.

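Summary: The attribute grammar includes an optional trailing colon type (`: type`). When an attribute without a constant buildable type is followed by a `:` literal in a declarative format, the attribute parser may consume that `:` as the start of the attribute's type, which generally leads to unexpected and undesired behavior. Given that, it is better to emit an error for these cases.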

Differential Revision: https://reviews.llvm.org/D77293

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Affine/IR/AffineOpsBase.td
    mlir/include/mlir/IR/OpBase.td
    mlir/test/IR/attribute.mlir
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/op-format-spec.td
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOpsBase.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOpsBase.td
index 60613e4e39f8..2883072d4aa9 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOpsBase.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOpsBase.td
@@ -20,6 +20,7 @@ def AffineMapAttr : Attr<
     CPred<"$_self.isa<AffineMapAttr>()">, "AffineMap attribute"> {
   let storageType = [{ AffineMapAttr }];
   let returnType = [{ AffineMap }];
+  let valueType = Index;
   let constBuilderCall = "AffineMapAttr::get($0)";
 }
 

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 845fa93878a1..38a402dfe9dc 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -319,7 +319,8 @@ class BuildableType<code builder> {
 def AnyType : Type<CPred<"true">, "any type">;
 
 // None type
-def NoneType : Type<CPred<"$_self.isa<NoneType>()">, "none type">;
+def NoneType : Type<CPred<"$_self.isa<NoneType>()">, "none type">,
+      BuildableType<"$_builder.getType<NoneType>()">;
 
 // Any type from the given list
 class AnyTypeOf<list<Type> allowedTypes, string description = ""> : Type<
@@ -835,6 +836,7 @@ def AnyAttr : Attr<CPred<"true">, "any attribute"> {
 def BoolAttr : Attr<CPred<"$_self.isa<BoolAttr>()">, "bool attribute"> {
   let storageType = [{ BoolAttr }];
   let returnType = [{ bool }];
+  let valueType = I1;
   let constBuilderCall = "$_builder.getBoolAttr($0)";
 }
 
@@ -942,11 +944,18 @@ class StringBasedAttr<Pred condition, string descr> : Attr<condition, descr> {
   let constBuilderCall = "$_builder.getStringAttr(\"$0\")";
   let storageType = [{ StringAttr }];
   let returnType = [{ StringRef }];
+  let valueType = NoneType;
 }
 
 def StrAttr : StringBasedAttr<CPred<"$_self.isa<StringAttr>()">,
                               "string attribute">;
 
+// String attribute that has a specific value type.
+class TypedStrAttr<Type ty> : StringBasedAttr<CPred<"$_self.isa<StringAttr>()">,
+                                                    "string attribute"> {
+  let valueType = ty;
+}
+
 // Base class for attributes containing types. Example:
 //   def IntTypeAttr : TypeAttrBase<"IntegerType", "integer type attribute">
 // defines a type attribute containing an integer type.
@@ -957,6 +966,7 @@ class TypeAttrBase<string retType, string description> :
     description> {
   let storageType = [{ TypeAttr }];
   let returnType = retType;
+  let valueType = NoneType;
   let convertFromStorage = "$_self.getValue().cast<" # retType # ">()";
 }
 
@@ -970,6 +980,7 @@ def UnitAttr : Attr<CPred<"$_self.isa<UnitAttr>()">, "unit attribute"> {
   let constBuilderCall = "$_builder.getUnitAttr()";
   let convertFromStorage = "$_self != nullptr";
   let returnType = "bool";
+  let valueType = NoneType;
   let isOptional = 1;
 }
 
@@ -1166,6 +1177,7 @@ class DictionaryAttrBase : Attr<CPred<"$_self.isa<DictionaryAttr>()">,
                           "dictionary of named attribute values"> {
   let storageType = [{ DictionaryAttr }];
   let returnType = [{ DictionaryAttr }];
+  let valueType = NoneType;
   let convertFromStorage = "$_self";
 }
 
@@ -1285,6 +1297,7 @@ class ArrayAttrBase<Pred condition, string description> :
     Attr<condition, description> {
   let storageType = [{ ArrayAttr }];
   let returnType = [{ ArrayAttr }];
+  let valueType = NoneType;
   let convertFromStorage = "$_self";
 }
 
@@ -1364,6 +1377,7 @@ def SymbolRefAttr : Attr<CPred<"$_self.isa<SymbolRefAttr>()">,
                         "symbol reference attribute"> {
   let storageType = [{ SymbolRefAttr }];
   let returnType = [{ SymbolRefAttr }];
+  let valueType = NoneType;
   let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
   let convertFromStorage = "$_self";
 }
@@ -1371,6 +1385,7 @@ def FlatSymbolRefAttr : Attr<CPred<"$_self.isa<FlatSymbolRefAttr>()">,
                                    "flat symbol reference attribute"> {
   let storageType = [{ FlatSymbolRefAttr }];
   let returnType = [{ StringRef }];
+  let valueType = NoneType;
   let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
   let convertFromStorage = "$_self.getValue()";
 }

diff  --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index a5133abfda2c..3e4f0db65942 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -247,7 +247,7 @@ func @non_type_in_type_array_attr_fail() {
 // CHECK-LABEL: func @string_attr_custom_type
 func @string_attr_custom_type() {
   // CHECK: "string_data" : !foo.string
-  test.string_attr_with_type "string_data"
+  test.string_attr_with_type "string_data" : !foo.string
   return
 }
 

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 80783001a4c9..061960959cd8 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -158,15 +158,8 @@ def TypeArrayAttrOp : TEST_Op<"type_array_attr"> {
   let arguments = (ins TypeArrayAttr:$attr);
 }
 def TypeStringAttrWithTypeOp : TEST_Op<"string_attr_with_type"> {
-  let arguments = (ins StrAttr:$attr);
-  let printer = [{ p << getAttr("attr"); }];
-  let parser = [{
-    Attribute attr;
-    Type stringType = OpaqueType::get(Identifier::get("foo",
-                                      result.getContext()), "string",
-                                      result.getContext());
-    return parser.parseAttribute(attr, stringType, "attr", result.attributes);
-  }];
+  let arguments = (ins TypedStrAttr<AnyType>:$attr);
+  let assemblyFormat = "$attr attr-dict";
 }
 
 def StrCaseA: StrEnumAttrCase<"A">;

diff  --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
index ac5aa259a907..5c3e344f68d7 100644
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -1,4 +1,4 @@
-// RUN: mlir-tblgen -gen-op-decls -asmformat-error-is-fatal=false -I %S/../../include %s 2>&1 | FileCheck %s --dump-input-on-failure
+// RUN: mlir-tblgen -gen-op-decls -asmformat-error-is-fatal=false -I %S/../../include %s -o=%t 2>&1 | FileCheck %s --dump-input-on-failure
 
 // This file contains tests for the specification of the declarative op format.
 
@@ -275,6 +275,21 @@ def VariableInvalidG : TestFormat_Op<"variable_invalid_g", [{
 }]> {
   let successors = (successor AnySuccessor:$successor);
 }
+// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` which does not have a buildable type
+def VariableInvalidH : TestFormat_Op<"variable_invalid_h", [{
+  $attr `:` attr-dict
+}]>, Arguments<(ins ElementsAttr:$attr)>;
+// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` which does not have a buildable type
+def VariableInvalidI : TestFormat_Op<"variable_invalid_i", [{
+  (`foo` $attr^)? `:` attr-dict
+}]>, Arguments<(ins OptionalAttr<ElementsAttr>:$attr)>;
+// CHECK-NOT: error:
+def VariableInvalidJ : TestFormat_Op<"variable_invalid_j", [{
+  $attr `:` attr-dict
+}]>, Arguments<(ins OptionalAttr<I1Attr>:$attr)>;
+def VariableInvalidK : TestFormat_Op<"variable_invalid_k", [{
+  (`foo` $attr^)? `:` attr-dict
+}]>, Arguments<(ins OptionalAttr<I1Attr>:$attr)>;
 
 //===----------------------------------------------------------------------===//
 // Coverage Checks

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index b9303979d3f4..965a6961b16e 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -92,13 +92,23 @@ class VariableElement : public Element {
   }
   const VarT *getVar() { return var; }
 
-private:
+protected:
   const VarT *var;
 };
 
 /// This class represents a variable that refers to an attribute argument.
-using AttributeVariable =
-    VariableElement<NamedAttribute, Element::Kind::AttributeVariable>;
+struct AttributeVariable
+    : public VariableElement<NamedAttribute, Element::Kind::AttributeVariable> {
+  using VariableElement<NamedAttribute,
+                        Element::Kind::AttributeVariable>::VariableElement;
+
+  /// Return the constant builder call for the type of this attribute, or None
+  /// if it doesn't have one.
+  Optional<StringRef> getTypeBuilder() const {
+    Optional<Type> attrType = var->attr.getValueType();
+    return attrType ? attrType->getBuilderCall() : llvm::None;
+  }
+};
 
 /// This class represents a variable that refers to an operand argument.
 using OperandVariable =
@@ -574,11 +584,9 @@ static void genElementParser(Element *element, OpMethodBody &body,
     // If this attribute has a buildable type, use that when parsing the
     // attribute.
     std::string attrTypeStr;
-    if (Optional<Type> attrType = var->attr.getValueType()) {
-      if (Optional<StringRef> typeBuilder = attrType->getBuilderCall()) {
-        llvm::raw_string_ostream os(attrTypeStr);
-        os << ", " << tgfmt(*typeBuilder, &attrTypeCtx);
-      }
+    if (Optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
+      llvm::raw_string_ostream os(attrTypeStr);
+      os << ", " << tgfmt(*typeBuilder, &attrTypeCtx);
     }
 
     body << formatv(attrParserCode, var->attr.getStorageType(), var->name,
@@ -932,8 +940,7 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
     }
 
     // Elide the attribute type if it is buildable.
-    Optional<Type> attrType = var->attr.getValueType();
-    if (attrType && attrType->getBuilderCall())
+    if (attr->getTypeBuilder())
       body << "  p.printAttributeWithoutType(" << var->name << "Attr());\n";
     else
       body << "  p.printAttribute(" << var->name << "Attr());\n";
@@ -1234,6 +1241,22 @@ class FormatParser {
     Optional<StringRef> transformer;
   };
 
+  /// Verify the state of operation attributes within the format.
+  LogicalResult verifyAttributes(llvm::SMLoc loc);
+
+  /// Verify the state of operation operands within the format.
+  LogicalResult
+  verifyOperands(llvm::SMLoc loc,
+                 llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
+
+  /// Verify the state of operation results within the format.
+  LogicalResult
+  verifyResults(llvm::SMLoc loc,
+                llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
+
+  /// Verify the state of operation successors within the format.
+  LogicalResult verifySuccessors(llvm::SMLoc loc);
+
   /// Given the values of an `AllTypesMatch` trait, check for inferable type
   /// resolution.
   void handleAllTypesMatchConstraint(
@@ -1357,37 +1380,86 @@ LogicalResult FormatParser::parse() {
     }
   }
 
-  // Check that all of the result types can be inferred.
-  auto &buildableTypes = fmt.buildableTypes;
-  if (!fmt.allResultTypes) {
-    for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
-      if (seenResultTypes.test(i))
-        continue;
+  // Verify the state of the various operation components.
+  if (failed(verifyAttributes(loc)) ||
+      failed(verifyResults(loc, variableTyResolver)) ||
+      failed(verifyOperands(loc, variableTyResolver)) ||
+      failed(verifySuccessors(loc)))
+    return failure();
 
-      // Check to see if we can infer this type from another variable.
-      auto varResolverIt = variableTyResolver.find(op.getResultName(i));
-      if (varResolverIt != variableTyResolver.end()) {
-        fmt.resultTypes[i].setVariable(varResolverIt->second.type,
-                                       varResolverIt->second.transformer);
-        continue;
+  // Check to see if we are formatting all of the operands.
+  fmt.allOperands = llvm::any_of(fmt.elements, [](auto &elt) {
+    return isa<OperandsDirective>(elt.get());
+  });
+  return success();
+}
+
+LogicalResult FormatParser::verifyAttributes(llvm::SMLoc loc) {
+  // Check that there are no `:` literals after an attribute without a constant
+  // type. The attribute grammar contains an optional trailing colon type, which
+  // can lead to unexpected and generally unintended behavior. Given that, it is
+  // better to just error out here instead.
+  using ElementsIterT = llvm::pointee_iterator<
+      std::vector<std::unique_ptr<Element>>::const_iterator>;
+  SmallVector<std::pair<ElementsIterT, ElementsIterT>, 1> iteratorStack;
+  iteratorStack.emplace_back(fmt.elements.begin(), fmt.elements.end());
+  while (!iteratorStack.empty()) {
+    auto &stackIt = iteratorStack.back();
+    ElementsIterT &it = stackIt.first, e = stackIt.second;
+    while (it != e) {
+      Element *element = &*(it++);
+
+      // Traverse into optional groups.
+      if (auto *optional = dyn_cast<OptionalElement>(element)) {
+        auto elements = optional->getElements();
+        iteratorStack.emplace_back(elements.begin(), elements.end());
+        break;
       }
 
-      // If the result is not variadic, allow for the case where the type has a
-      // builder that we can use.
-      NamedTypeConstraint &result = op.getResult(i);
-      Optional<StringRef> builder = result.constraint.getBuilderCall();
-      if (!builder || result.constraint.isVariadic()) {
-        return emitError(loc, "format missing instance of result #" + Twine(i) +
-                                  "('" + result.name + "') type");
+      // We are checking for an attribute element followed by a `:`, so there is
+      // no need to check the end.
+      if (it == e && iteratorStack.size() == 1)
+        break;
+
+      // Check for an attribute without a constant type builder followed by `:`.
+      auto *prevAttr = dyn_cast<AttributeVariable>(element);
+      if (!prevAttr || prevAttr->getTypeBuilder())
+        continue;
+
+      // Check the next iterator within the stack for literal elements.
+      for (auto &nextItPair : iteratorStack) {
+        ElementsIterT nextIt = nextItPair.first, nextE = nextItPair.second;
+        for (; nextIt != nextE; ++nextIt) {
+          // Skip any trailing optional groups or attribute dictionaries.
+          if (isa<AttrDictDirective>(*nextIt) || isa<OptionalElement>(*nextIt))
+            continue;
+
+          // We are only interested in `:` literals.
+          auto *literal = dyn_cast<LiteralElement>(&*nextIt);
+          if (!literal || literal->getLiteral() != ":")
+            break;
+
+          // TODO: Use the location of the literal element itself.
+          return emitError(
+              loc, llvm::formatv("format ambiguity caused by `:` literal found "
+                                 "after attribute `{0}` which does not have "
+                                 "a buildable type",
+                                 prevAttr->getVar()->name));
+        }
       }
-      // Note in the format that this result uses the custom builder.
-      auto it = buildableTypes.insert({*builder, buildableTypes.size()});
-      fmt.resultTypes[i].setBuilderIdx(it.first->second);
     }
+    if (it == e)
+      iteratorStack.pop_back();
   }
+  return success();
+}
 
+LogicalResult FormatParser::verifyOperands(
+    llvm::SMLoc loc,
+    llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
   // Check that all of the operands are within the format, and their types can
   // be inferred.
+  auto &buildableTypes = fmt.buildableTypes;
   for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
     NamedTypeConstraint &operand = op.getOperand(i);
 
@@ -1419,22 +1491,57 @@ LogicalResult FormatParser::parse() {
     auto it = buildableTypes.insert({*builder, buildableTypes.size()});
     fmt.operandTypes[i].setBuilderIdx(it.first->second);
   }
+  return success();
+}
 
-  // Check that all of the successors are within the format.
-  if (!hasAllSuccessors) {
-    for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) {
-      const NamedSuccessor &successor = op.getSuccessor(i);
-      if (!seenSuccessors.count(&successor)) {
-        return emitError(loc, "format missing instance of successor #" +
-                                  Twine(i) + "('" + successor.name + "')");
-      }
+LogicalResult FormatParser::verifyResults(
+    llvm::SMLoc loc,
+    llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
+  // If we format all of the types together, there is nothing to check.
+  if (fmt.allResultTypes)
+    return success();
+
+  // Check that all of the result types can be inferred.
+  auto &buildableTypes = fmt.buildableTypes;
+  for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
+    if (seenResultTypes.test(i))
+      continue;
+
+    // Check to see if we can infer this type from another variable.
+    auto varResolverIt = variableTyResolver.find(op.getResultName(i));
+    if (varResolverIt != variableTyResolver.end()) {
+      fmt.resultTypes[i].setVariable(varResolverIt->second.type,
+                                     varResolverIt->second.transformer);
+      continue;
     }
+
+    // If the result is not variadic, allow for the case where the type has a
+    // builder that we can use.
+    NamedTypeConstraint &result = op.getResult(i);
+    Optional<StringRef> builder = result.constraint.getBuilderCall();
+    if (!builder || result.constraint.isVariadic()) {
+      return emitError(loc, "format missing instance of result #" + Twine(i) +
+                                "('" + result.name + "') type");
+    }
+    // Note in the format that this result uses the custom builder.
+    auto it = buildableTypes.insert({*builder, buildableTypes.size()});
+    fmt.resultTypes[i].setBuilderIdx(it.first->second);
   }
+  return success();
+}
 
-  // Check to see if we are formatting all of the operands.
-  fmt.allOperands = llvm::any_of(fmt.elements, [](auto &elt) {
-    return isa<OperandsDirective>(elt.get());
-  });
+LogicalResult FormatParser::verifySuccessors(llvm::SMLoc loc) {
+  // Check that all of the successors are within the format.
+  if (hasAllSuccessors)
+    return success();
+
+  for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) {
+    const NamedSuccessor &successor = op.getSuccessor(i);
+    if (!seenSuccessors.count(&successor)) {
+      return emitError(loc, "format missing instance of successor #" +
+                                Twine(i) + "('" + successor.name + "')");
+    }
+  }
   return success();
 }
 


        


More information about the Mlir-commits mailing list