[Mlir-commits] [mlir] 1b2c16f - [mlir][DeclarativeParser] Add support for attributes with buildable types.

River Riddle llvmlistbot at llvm.org
Sat Feb 8 15:52:50 PST 2020


Author: River Riddle
Date: 2020-02-08T15:46:46-08:00
New Revision: 1b2c16f2ae41eff124a11b8a5c343fb7688a2a85

URL: https://github.com/llvm/llvm-project/commit/1b2c16f2ae41eff124a11b8a5c343fb7688a2a85
DIFF: https://github.com/llvm/llvm-project/commit/1b2c16f2ae41eff124a11b8a5c343fb7688a2a85.diff

LOG: [mlir][DeclarativeParser] Add support for attributes with buildable types.

This revision adds support in the declarative assembly form for printing attributes with buildable types without the type, and moves several more parsers over to the declarative form.

Differential Revision: https://reviews.llvm.org/D74276

Added: 
    

Modified: 
    mlir/docs/OpDefinitions.md
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/IR/OpImplementation.h
    mlir/include/mlir/TableGen/Attribute.h
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/lib/TableGen/Attribute.cpp
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 6e2bfa53827c..57ad925b1b07 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -616,6 +616,9 @@ A variable is an entity that has been registered on the operation itself, i.e.
 an argument(attribute or operand), result, etc. In the `CallOp` example above,
 the variables would be `$callee`  and `$args`.
 
+Attribute variables are printed with their respective value type, unless that
+value type is buildable. In those cases, the type of the attribute is elided.
+
 #### Requirements
 
 The format specification has a certain set of requirements that must be adhered

diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 399bd478322c..c610e0b4e911 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -629,6 +629,10 @@ class Attr<Pred condition, string descr = ""> :
   // Requires a constBuilderCall defined.
   string defaultValue = ?;
 
+  // The value type of this attribute. This corresponds to the mlir::Type that
+  // this attribute returns via `getType()`.
+  Type valueType = ?;
+
   // Whether the attribute is optional. Typically requires a custom
   // convertFromStorage method to handle the case where the attribute is
   // not present.
@@ -660,6 +664,7 @@ class DefaultValuedAttr<Attr attr, string val> :
   let convertFromStorage = attr.convertFromStorage;
   let constBuilderCall = attr.constBuilderCall;
   let defaultValue = val;
+  let valueType = attr.valueType;
 
   let baseAttr = attr;
 }
@@ -673,6 +678,7 @@ class OptionalAttr<Attr attr> : Attr<attr.predicate, attr.description> {
   let returnType = "Optional<" # attr.returnType #">";
   let convertFromStorage = "$_self ? " # returnType # "(" #
                            attr.convertFromStorage # ") : (llvm::None)";
+  let valueType = attr.valueType;
   let isOptional = 1;
 
   let baseAttr = attr;
@@ -681,14 +687,15 @@ class OptionalAttr<Attr attr> : Attr<attr.predicate, attr.description> {
 //===----------------------------------------------------------------------===//
 // Primitive attribute kinds
 
-// A generic attribute that must be constructed around a specific type
+// A generic attribute that must be constructed around a specific buildable type
 // `attrValType`. Backed by MLIR attribute kind `attrKind`.
-class TypedAttrBase<BuildableType attrValType, string attrKind,
-                    Pred condition, string descr> :
+class TypedAttrBase<Type attrValType, string attrKind, Pred condition,
+                    string descr> :
     Attr<condition, descr> {
   let constBuilderCall = "$_builder.get" # attrKind # "(" #
                          attrValType.builderCall # ", $0)";
   let storageType = attrKind;
+  let valueType = attrValType;
 }
 
 // Any attribute.
@@ -1227,6 +1234,7 @@ class Confined<Attr attr, list<AttrConstraint> constraints> : Attr<
   let convertFromStorage = attr.convertFromStorage;
   let constBuilderCall = attr.constBuilderCall;
   let defaultValue = attr.defaultValue;
+  let valueType = attr.valueType;
   let isOptional = attr.isOptional;
 
   let baseAttr = attr;

diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index c3bd683d1d59..47bfe69487d9 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -58,6 +58,10 @@ class OpAsmPrinter {
   virtual void printType(Type type) = 0;
   virtual void printAttribute(Attribute attr) = 0;
 
+  /// Print the given attribute without its type. The corresponding parser must
+  /// provide a valid type for the attribute.
+  virtual void printAttributeWithoutType(Attribute attr) = 0;
+
   /// Print a successor, and use list, of a terminator operation given the
   /// terminator and the successor index.
   virtual void printSuccessorAndUseList(Operation *term, unsigned index) = 0;

diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h
index cd41109c6410..dbc018a09323 100644
--- a/mlir/include/mlir/TableGen/Attribute.h
+++ b/mlir/include/mlir/TableGen/Attribute.h
@@ -25,6 +25,7 @@ class Record;
 
 namespace mlir {
 namespace tblgen {
+class Type;
 
 // Wrapper class with helper methods for accessing attribute constraints defined
 // in TableGen.
@@ -54,6 +55,10 @@ class Attribute : public AttrConstraint {
   // Returns the return type for this attribute.
   StringRef getReturnType() const;
 
+  // Return the type constraint corresponding to the type of this attribute, or
+  // None if this is not a TypedAttr.
+  llvm::Optional<Type> getValueType() const;
+
   // Returns the template getter method call which reads this attribute's
   // storage and returns the value as of the desired return type.
   // The call will contain a `{0}` which will be expanded to this attribute.

diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index f8abf9a3e95b..f3d884f75781 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -855,10 +855,21 @@ class ModulePrinter {
     mlir::interleaveComma(c, os, each_fn);
   }
 
-  /// Print the given attribute. If 'mayElideType' is true, some attributes are
-  /// printed without the type when the type matches the default used in the
-  /// parser (for example i64 is the default for integer attributes).
-  void printAttribute(Attribute attr, bool mayElideType = false);
+  /// This enum describes the different kinds of elision for the type of an
+  /// attribute when printing it.
+  enum class AttrTypeElision {
+    /// The type must not be elided.
+    Never,
+    /// The type may be elided when it matches the default used in the parser
+    /// (for example i64 is the default for integer attributes).
+    May,
+    /// The type must be elided.
+    Must
+  };
+
+  /// Print the given attribute.
+  void printAttribute(Attribute attr,
+                      AttrTypeElision typeElision = AttrTypeElision::Never);
 
   void printType(Type type);
   void printLocation(LocationAttr loc);
@@ -1185,7 +1196,8 @@ static void printElidedElementsAttr(raw_ostream &os) {
   os << R"(opaque<"", "0xDEADBEEF">)";
 }
 
-void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
+void ModulePrinter::printAttribute(Attribute attr,
+                                   AttrTypeElision typeElision) {
   if (!attr) {
     os << "<<NULL ATTRIBUTE>>";
     return;
@@ -1200,6 +1212,7 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
     }
   }
 
+  auto attrType = attr.getType();
   switch (attr.getKind()) {
   default:
     return printDialectAttribute(attr);
@@ -1236,12 +1249,11 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
   case StandardAttributes::Integer: {
     auto intAttr = attr.cast<IntegerAttr>();
     // Print all integer attributes as signed unless i1.
-    bool isSigned = intAttr.getType().isIndex() ||
-                    intAttr.getType().getIntOrFloatBitWidth() != 1;
+    bool isSigned = attrType.isIndex() || attrType.getIntOrFloatBitWidth() != 1;
     intAttr.getValue().print(os, isSigned);
 
     // IntegerAttr elides the type if I64.
-    if (mayElideType && intAttr.getType().isInteger(64))
+    if (typeElision == AttrTypeElision::May && attrType.isInteger(64))
       return;
     break;
   }
@@ -1250,7 +1262,7 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
     printFloatValue(floatAttr.getValue(), os);
 
     // FloatAttr elides the type if F64.
-    if (mayElideType && floatAttr.getType().isF64())
+    if (typeElision == AttrTypeElision::May && attrType.isF64())
       return;
     break;
   }
@@ -1262,7 +1274,7 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
   case StandardAttributes::Array:
     os << '[';
     interleaveComma(attr.cast<ArrayAttr>().getValue(), [&](Attribute attr) {
-      printAttribute(attr, /*mayElideType=*/true);
+      printAttribute(attr, AttrTypeElision::May);
     });
     os << ']';
     break;
@@ -1339,9 +1351,8 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
     break;
   }
 
-  // Print the type if it isn't a 'none' type.
-  auto attrType = attr.getType();
-  if (!attrType.isa<NoneType>()) {
+  // Don't print the type if we must elide it, or if it is a None type.
+  if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {
     os << " : ";
     printType(attrType);
   }
@@ -1904,6 +1915,12 @@ class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
     ModulePrinter::printAttribute(attr);
   }
 
+  /// Print the given attribute without its type. The corresponding parser must
+  /// provide a valid type for the attribute.
+  void printAttributeWithoutType(Attribute attr) override {
+    ModulePrinter::printAttribute(attr, AttrTypeElision::Must);
+  }
+
   /// Print the ID for the given value.
   void printOperand(Value value) override { printValueID(value); }
 

diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 1a61f8a398a3..7d8535f3cee3 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -285,14 +285,14 @@ class Parser {
   Attribute parseDecOrHexAttr(Type type, bool isNegative);
 
   /// Parse an opaque elements attribute.
-  Attribute parseOpaqueElementsAttr();
+  Attribute parseOpaqueElementsAttr(Type attrType);
 
   /// Parse a dense elements attribute.
-  Attribute parseDenseElementsAttr();
-  ShapedType parseElementsLiteralType();
+  Attribute parseDenseElementsAttr(Type attrType);
+  ShapedType parseElementsLiteralType(Type type);
 
   /// Parse a sparse elements attribute.
-  Attribute parseSparseElementsAttr();
+  Attribute parseSparseElementsAttr(Type attrType);
 
   //===--------------------------------------------------------------------===//
   // Location Parsing
@@ -1505,7 +1505,7 @@ Attribute Parser::parseAttribute(Type type) {
 
   // Parse a dense elements attribute.
   case Token::kw_dense:
-    return parseDenseElementsAttr();
+    return parseDenseElementsAttr(type);
 
   // Parse a dictionary attribute.
   case Token::l_brace: {
@@ -1543,11 +1543,11 @@ Attribute Parser::parseAttribute(Type type) {
 
   // Parse an opaque elements attribute.
   case Token::kw_opaque:
-    return parseOpaqueElementsAttr();
+    return parseOpaqueElementsAttr(type);
 
   // Parse a sparse elements attribute.
   case Token::kw_sparse:
-    return parseSparseElementsAttr();
+    return parseSparseElementsAttr(type);
 
   // Parse a string attribute.
   case Token::string: {
@@ -1783,7 +1783,7 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
 }
 
 /// Parse an opaque elements attribute.
-Attribute Parser::parseOpaqueElementsAttr() {
+Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
   consumeToken(Token::kw_opaque);
   if (parseToken(Token::less, "expected '<' after 'opaque'"))
     return nullptr;
@@ -1816,11 +1816,10 @@ Attribute Parser::parseOpaqueElementsAttr() {
     return (emitError("opaque string only contains hex digits"), nullptr);
 
   consumeToken(Token::string);
-  if (parseToken(Token::greater, "expected '>'") ||
-      parseToken(Token::colon, "expected ':'"))
+  if (parseToken(Token::greater, "expected '>'"))
     return nullptr;
 
-  auto type = parseElementsLiteralType();
+  auto type = parseElementsLiteralType(attrType);
   if (!type)
     return nullptr;
 
@@ -2086,7 +2085,7 @@ ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
 }
 
 /// Parse a dense elements attribute.
-Attribute Parser::parseDenseElementsAttr() {
+Attribute Parser::parseDenseElementsAttr(Type attrType) {
   consumeToken(Token::kw_dense);
   if (parseToken(Token::less, "expected '<' after 'dense'"))
     return nullptr;
@@ -2096,12 +2095,11 @@ Attribute Parser::parseDenseElementsAttr() {
   if (literalParser.parse())
     return nullptr;
 
-  if (parseToken(Token::greater, "expected '>'") ||
-      parseToken(Token::colon, "expected ':'"))
+  if (parseToken(Token::greater, "expected '>'"))
     return nullptr;
 
   auto typeLoc = getToken().getLoc();
-  auto type = parseElementsLiteralType();
+  auto type = parseElementsLiteralType(attrType);
   if (!type)
     return nullptr;
   return literalParser.getAttr(typeLoc, type);
@@ -2112,10 +2110,14 @@ Attribute Parser::parseDenseElementsAttr() {
 ///   elements-literal-type ::= vector-type | ranked-tensor-type
 ///
 /// This method also checks the type has static shape.
-ShapedType Parser::parseElementsLiteralType() {
-  auto type = parseType();
-  if (!type)
-    return nullptr;
+ShapedType Parser::parseElementsLiteralType(Type type) {
+  // If the user didn't provide a type, parse the colon type for the literal.
+  if (!type) {
+    if (parseToken(Token::colon, "expected ':'"))
+      return nullptr;
+    if (!(type = parseType()))
+      return nullptr;
+  }
 
   if (!type.isa<RankedTensorType>() && !type.isa<VectorType>()) {
     emitError("elements literal must be a ranked tensor or vector type");
@@ -2130,7 +2132,7 @@ ShapedType Parser::parseElementsLiteralType() {
 }
 
 /// Parse a sparse elements attribute.
-Attribute Parser::parseSparseElementsAttr() {
+Attribute Parser::parseSparseElementsAttr(Type attrType) {
   consumeToken(Token::kw_sparse);
   if (parseToken(Token::less, "Expected '<' after 'sparse'"))
     return nullptr;
@@ -2150,11 +2152,10 @@ Attribute Parser::parseSparseElementsAttr() {
   if (valuesParser.parse())
     return nullptr;
 
-  if (parseToken(Token::greater, "expected '>'") ||
-      parseToken(Token::colon, "expected ':'"))
+  if (parseToken(Token::greater, "expected '>'"))
     return nullptr;
 
-  auto type = parseElementsLiteralType();
+  auto type = parseElementsLiteralType(attrType);
   if (!type)
     return nullptr;
 

diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp
index 6f25c8433cda..b11438c2dc02 100644
--- a/mlir/lib/TableGen/Attribute.cpp
+++ b/mlir/lib/TableGen/Attribute.cpp
@@ -75,6 +75,14 @@ StringRef tblgen::Attribute::getReturnType() const {
   return getValueAsString(init);
 }
 
+// Return the type constraint corresponding to the type of this attribute, or
+// None if this is not a TypedAttr.
+llvm::Optional<tblgen::Type> tblgen::Attribute::getValueType() const {
+  if (auto *defInit = dyn_cast<llvm::DefInit>(def->getValueInit("valueType")))
+    return tblgen::Type(defInit->getDef());
+  return llvm::None;
+}
+
 StringRef tblgen::Attribute::getConvertFromStorageCall() const {
   const auto *init = def->getValueInit("convertFromStorage");
   return getValueAsString(init);

diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index d1210d64b7b6..22e834f4caa5 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -268,9 +268,10 @@ struct OperationFormat {
 ///
 /// {0}: The storage type of the attribute.
 /// {1}: The name of the attribute.
+/// {2}: The type for the attribute.
 const char *const attrParserCode = R"(
   {0} {1}Attr;
-  if (parser.parseAttribute({1}Attr, "{1}", result.attributes))
+  if (parser.parseAttribute({1}Attr{2}, "{1}", result.attributes))
     return failure();
 )";
 
@@ -368,6 +369,10 @@ void OperationFormat::genParser(Operator &op, OpClass &opClass) {
       OpMethod::MP_Static);
   auto &body = method.body();
 
+  // A format context used when parsing attributes with buildable types.
+  FmtContext attrTypeCtx;
+  attrTypeCtx.withBuilder("parser.getBuilder()");
+
   // Generate parsers for each of the elements.
   for (auto &element : elements) {
     /// Literals.
@@ -377,7 +382,19 @@ void OperationFormat::genParser(Operator &op, OpClass &opClass) {
       /// Arguments.
     } else if (auto *attr = dyn_cast<AttributeVariable>(element.get())) {
       const NamedAttribute *var = attr->getVar();
-      body << formatv(attrParserCode, var->attr.getStorageType(), var->name);
+
+      // If this attribute has a buildable type, use that when parsing the
+      // attribute.
+      std::string attrTypeStr;
+      if (Optional<Type> attrType = var->attr.getValueType()) {
+        if (Optional<StringRef> typeBuilder = attrType->getBuilderCall()) {
+          llvm::raw_string_ostream os(attrTypeStr);
+          os << ", " << tgfmt(*typeBuilder, &attrTypeCtx);
+        }
+      }
+
+      body << formatv(attrParserCode, var->attr.getStorageType(), var->name,
+                      attrTypeStr);
     } else if (auto *operand = dyn_cast<OperandVariable>(element.get())) {
       bool isVariadic = operand->getVar()->isVariadic();
       body << formatv(isVariadic ? variadicOperandParserCode
@@ -615,7 +632,14 @@ void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
     shouldEmitSpace = true;
 
     if (auto *attr = dyn_cast<AttributeVariable>(element.get())) {
-      body << "  p << " << attr->getVar()->name << "Attr();\n";
+      const NamedAttribute *var = attr->getVar();
+
+      // Elide the attribute type if it is buildable.
+      Optional<Type> attrType = var->attr.getValueType();
+      if (attrType && attrType->getBuilderCall())
+        body << "  p.printAttributeWithoutType(" << var->name << "Attr());\n";
+      else
+        body << "  p.printAttribute(" << var->name << "Attr());\n";
     } else if (auto *operand = dyn_cast<OperandVariable>(element.get())) {
       body << "  p << " << operand->getVar()->name << "();\n";
     } else if (isa<OperandsDirective>(element.get())) {


        


More information about the Mlir-commits mailing list