[Mlir-commits] [mlir] 1447ec5 - [mlir][AttrDefGen] Add support for specifying the value type of an attribute

River Riddle llvmlistbot at llvm.org
Thu Mar 4 13:04:14 PST 2021


Author: River Riddle
Date: 2021-03-04T13:04:05-08:00
New Revision: 1447ec5182e61b117230ab39af6b53d08470aea3

URL: https://github.com/llvm/llvm-project/commit/1447ec5182e61b117230ab39af6b53d08470aea3
DIFF: https://github.com/llvm/llvm-project/commit/1447ec5182e61b117230ab39af6b53d08470aea3.diff

LOG: [mlir][AttrDefGen] Add support for specifying the value type of an attribute

The value type of the attribute can be specified either by overriding the typeBuilder field on the AttrDef or by providing a parameter of type `AttributeSelfTypeParameter`. This removes the need to define custom storage class constructors for attributes that have a value type other than NoneType.

Differential Revision: https://reviews.llvm.org/D97590

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/TableGen/AttrOrTypeDef.h
    mlir/lib/TableGen/AttrOrTypeDef.cpp
    mlir/test/lib/Dialect/Test/TestAttrDefs.td
    mlir/test/lib/Dialect/Test/TestAttributes.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/attrdefs.td
    mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
    mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index db38107bcc10..c18034c47d4f 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2639,6 +2639,14 @@ class AttrDef<Dialect dialect, string name,
   // The name of the C++ Attribute class.
   string cppClassName = name # "Attr";
 
+  // A code block used to build the value 'Type' of an Attribute when
+  // initializing its storage instance. This field is optional, and if not
+  // present the attribute will have its value type set to `NoneType`. This code
+  // block may reference any of the attribute's parameters via
+  // `$_<parameter-name>`. If one of the parameters of the attribute is of type
+  // `AttributeSelfTypeParameter`, this field is ignored.
+  code typeBuilder = ?;
+
   // The predicate for when this def is used as a constraint.
   let predicate = CPred<"$_self.isa<" # dialect.cppNamespace #
                                  "::" # cppClassName # ">()">;
@@ -2704,4 +2712,10 @@ class ArrayRefOfSelfAllocationParameter<string arrayOf, string desc> :
   }];
 }
 
+// This is a special parameter used for AttrDefs that represents a `mlir::Type`
+// that is also used as the value `Type` of the attribute. Only one parameter
+// of the attribute may be of this type.
+class AttributeSelfTypeParameter<string desc> :
+    AttrOrTypeParameter<"::mlir::Type", desc> {}
+
 #endif // OP_BASE

diff  --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
index 8ce752014ebd..84a5a04c909c 100644
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -130,7 +130,10 @@ class AttrOrTypeDef {
   // Returns whether the AttrOrTypeDef is defined.
   operator bool() const { return def != nullptr; }
 
-private:
+  // Return the underlying def.
+  const llvm::Record *getDef() const { return def; }
+
+protected:
   const llvm::Record *def;
 
   // The builders of this type definition.
@@ -145,6 +148,12 @@ class AttrOrTypeDef {
 class AttrDef : public AttrOrTypeDef {
 public:
   using AttrOrTypeDef::AttrOrTypeDef;
+
+  // Returns the attribute's value type builder code block, or None if it
+  // doesn't have one.
+  Optional<StringRef> getTypeBuilder() const;
+
+  static bool classof(const AttrOrTypeDef *def);
 };
 
 //===----------------------------------------------------------------------===//
@@ -183,6 +192,9 @@ class AttrOrTypeParameter {
   // Get the assembly syntax documentation.
   StringRef getSyntax() const;
 
+  // Return the underlying def of this parameter.
+  const llvm::Init *getDef() const;
+
 private:
   /// The underlying tablegen parameter list this parameter is a part of.
   const llvm::DagInit *def;
@@ -190,6 +202,17 @@ class AttrOrTypeParameter {
   unsigned index;
 };
 
+//===----------------------------------------------------------------------===//
+// AttributeSelfTypeParameter
+//===----------------------------------------------------------------------===//
+
+// A wrapper class for the AttributeSelfTypeParameter tblgen class. This
+// represents a parameter of mlir::Type that is the value type of an AttrDef.
+class AttributeSelfTypeParameter : public AttrOrTypeParameter {
+public:
+  static bool classof(const AttrOrTypeParameter *param);
+};
+
 } // end namespace tblgen
 } // end namespace mlir
 

diff  --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index e82f0f069ddf..037dc4d40cc5 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -153,6 +153,18 @@ bool AttrOrTypeDef::operator<(const AttrOrTypeDef &other) const {
   return getName() < other.getName();
 }
 
+//===----------------------------------------------------------------------===//
+// AttrDef
+//===----------------------------------------------------------------------===//
+
+Optional<StringRef> AttrDef::getTypeBuilder() const {
+  return def->getValueAsOptionalString("typeBuilder");
+}
+
+bool AttrDef::classof(const AttrOrTypeDef *def) {
+  return def->getDef()->isSubClassOf("AttrDef");
+}
+
 //===----------------------------------------------------------------------===//
 // AttrOrTypeParameter
 //===----------------------------------------------------------------------===//
@@ -219,3 +231,18 @@ StringRef AttrOrTypeParameter::getSyntax() const {
   llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
                         "defs which inherit from AttrOrTypeParameter");
 }
+
+const llvm::Init *AttrOrTypeParameter::getDef() const {
+  return def->getArg(index);
+}
+
+//===----------------------------------------------------------------------===//
+// AttributeSelfTypeParameter
+//===----------------------------------------------------------------------===//
+
+bool AttributeSelfTypeParameter::classof(const AttrOrTypeParameter *param) {
+  const llvm::Init *paramDef = param->getDef();
+  if (auto *paramDefInit = dyn_cast<llvm::DefInit>(paramDef))
+    return paramDefInit->getDef()->isSubClassOf("AttributeSelfTypeParameter");
+  return false;
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 8b3ebaa6e632..33a94a822a99 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -41,4 +41,17 @@ def CompoundAttrA : Test_Attr<"CompoundA"> {
   );
 }
 
+// An attribute testing AttributeSelfTypeParameter.
+def AttrWithSelfTypeParam : Test_Attr<"AttrWithSelfTypeParam"> {
+  let mnemonic = "attr_with_self_type_param";
+  let parameters = (ins AttributeSelfTypeParameter<"">:$type);
+}
+
+// An attribute testing the `typeBuilder` field.
+def AttrWithTypeBuilder : Test_Attr<"AttrWithTypeBuilder"> {
+  let mnemonic = "attr_with_type_builder";
+  let parameters = (ins "::mlir::IntegerAttr":$attr);
+  let typeBuilder = "$_attr.getType()";
+}
+
 #endif // TEST_ATTRDEFS

diff  --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 39328b6c1d10..d13cd5a0bd72 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -23,6 +23,43 @@
 using namespace mlir;
 using namespace mlir::test;
 
+//===----------------------------------------------------------------------===//
+// AttrWithSelfTypeParamAttr
+//===----------------------------------------------------------------------===//
+
+Attribute AttrWithSelfTypeParamAttr::parse(MLIRContext *context,
+                                           DialectAsmParser &parser,
+                                           Type type) {
+  Type selfType;
+  if (parser.parseType(selfType))
+    return Attribute();
+  return get(context, selfType);
+}
+
+void AttrWithSelfTypeParamAttr::print(DialectAsmPrinter &printer) const {
+  printer << "attr_with_self_type_param " << getType();
+}
+
+//===----------------------------------------------------------------------===//
+// AttrWithTypeBuilderAttr
+//===----------------------------------------------------------------------===//
+
+Attribute AttrWithTypeBuilderAttr::parse(MLIRContext *context,
+                                         DialectAsmParser &parser, Type type) {
+  IntegerAttr element;
+  if (parser.parseAttribute(element))
+    return Attribute();
+  return get(context, element);
+}
+
+void AttrWithTypeBuilderAttr::print(DialectAsmPrinter &printer) const {
+  printer << "attr_with_type_builder " << getAttr();
+}
+
+//===----------------------------------------------------------------------===//
+// CompoundAAttr
+//===----------------------------------------------------------------------===//
+
 Attribute CompoundAAttr::parse(MLIRContext *context, DialectAsmParser &parser,
                                Type type) {
   int widthOfSomething;

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index e09e003ff565..9d91dc4faad1 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -420,6 +420,14 @@ def OperandsHaveSameType :
   let arguments = (ins AnyType:$x, AnyType:$y);
 }
 
+def ResultHasSameTypeAsAttr :
+    TEST_Op<"result_has_same_type_as_attr",
+            [AllTypesMatch<["attr", "result"]>]> {
+  let arguments = (ins AnyAttr:$attr);
+  let results = (outs AnyType:$result);
+  let assemblyFormat = "$attr `->` type($result) attr-dict";
+}
+
 def OperandZeroAndResultHaveSameType :
     TEST_Op<"operand0_and_result_have_same_type",
             [AllTypesMatch<["x", "res"]>]> {

diff  --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td
index 36ea2cb46ece..802c2a1f793d 100644
--- a/mlir/test/mlir-tblgen/attrdefs.td
+++ b/mlir/test/mlir-tblgen/attrdefs.td
@@ -23,7 +23,7 @@ include "mlir/IR/OpBase.td"
 // DEF-NEXT: ::llvm::StringRef mnemonic, ::mlir::Type type) {
 // DEF: if (mnemonic == ::mlir::test::CompoundAAttr::getMnemonic()) return ::mlir::test::CompoundAAttr::parse(context, parser, type);
 // DEF-NEXT: if (mnemonic == ::mlir::test::IndexAttr::getMnemonic()) return ::mlir::test::IndexAttr::parse(context, parser, type);
-// DEF-NEXT: return ::mlir::Attribute();
+// DEF: return ::mlir::Attribute();
 
 def Test_Dialect: Dialect {
 // DECL-NOT: TestDialect
@@ -49,7 +49,7 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
       "::mlir::test::SimpleTypeA": $exampleTdType,
       "SomeCppStruct": $exampleCppType,
       ArrayRefParameter<"int", "Matrix dimensions">:$dims,
-      "::mlir::Type":$inner
+      AttributeSelfTypeParameter<"">:$inner
   );
 
   let genVerifyDecl = 1;
@@ -66,6 +66,20 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
 // DECL: int getWidthOfSomething() const;
 // DECL: ::mlir::test::SimpleTypeA getExampleTdType() const;
 // DECL: SomeCppStruct getExampleCppType() const;
+
+// Check that AttributeSelfTypeParameter is handled properly.
+// DEF-LABEL: struct CompoundAAttrStorage
+// DEF: CompoundAAttrStorage (
+// DEF-NEXT: : ::mlir::AttributeStorage(inner),
+
+// DEF: bool operator==(const KeyTy &key) const {
+// DEF-NEXT: return key == KeyTy(widthOfSomething, exampleTdType, exampleCppType, dims, getType());
+
+// DEF: static CompoundAAttrStorage *construct
+// DEF: return new (allocator.allocate<CompoundAAttrStorage>())
+// DEF-NEXT: CompoundAAttrStorage(widthOfSomething, exampleTdType, exampleCppType, dims, inner);
+
+// DEF: ::mlir::Type CompoundAAttr::getInner() const { return getImpl()->getType(); }
 }
 
 def C_IndexAttr : TestAttr<"Index"> {
@@ -94,3 +108,14 @@ def D_SingleParameterAttr : TestAttr<"SingleParameter"> {
 // DECL-LABEL: class SingleParameterAttr
 // DECL-NEXT:                   detail::SingleParameterAttrStorage
 }
+
+// An attribute testing the `typeBuilder` field.
+def E_AttrWithTypeBuilder : TestAttr<"AttrWithTypeBuilder"> {
+  let mnemonic = "attr_with_type_builder";
+  let parameters = (ins "::mlir::IntegerAttr":$attr);
+  let typeBuilder = "$_attr.getType()";
+}
+
+// DEF-LABEL: struct AttrWithTypeBuilderAttrStorage
+// DEF: AttrWithTypeBuilderAttrStorage (::mlir::IntegerAttr attr)
+// DEF-NEXT: : ::mlir::AttributeStorage(attr.getType()), attr(attr)

diff  --git a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
index 8c167ffc2854..75e16817b3e6 100644
--- a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
+++ b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
@@ -3,3 +3,9 @@
 // CHECK-LABEL: func private @compoundA()
 // CHECK-SAME: #test.cmpnd_a<1, !test.smpla, [5, 6]>
 func private @compoundA() attributes {foo = #test.cmpnd_a<1, !test.smpla, [5, 6]>}
+
+// CHECK: test.result_has_same_type_as_attr #test<"attr_with_self_type_param i32"> -> i32
+%a = test.result_has_same_type_as_attr #test<"attr_with_self_type_param i32"> -> i32
+
+// CHECK: test.result_has_same_type_as_attr #test<"attr_with_type_builder 10 : i16"> -> i16
+%b = test.result_has_same_type_as_attr #test<"attr_with_type_builder 10 : i16"> -> i16

diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 3df9fd9ac90d..b7e5c7d27ff4 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -485,31 +485,90 @@ static void emitStorageParameterAllocation(const AttrOrTypeDef &def,
   }
 }
 
+/// Builds a code block that initializes the attribute storage of 'def'.
+/// Attribute initialization is separated from Type initialization given that
+/// the Attribute also needs to initialize its self-type, which has multiple
+/// means of initialization.
+static std::string buildAttributeStorageParamInitializer(
+    const AttrOrTypeDef &def, ArrayRef<AttrOrTypeParameter> parameters) {
+  std::string paramInitializer;
+  llvm::raw_string_ostream paramOS(paramInitializer);
+  paramOS << "::mlir::AttributeStorage(";
+
+  // If this is an attribute, we need to check for value type initialization.
+  Optional<size_t> selfParamIndex;
+  for (auto it : llvm::enumerate(parameters)) {
+    const auto *selfParam = dyn_cast<AttributeSelfTypeParameter>(&it.value());
+    if (!selfParam)
+      continue;
+    if (selfParamIndex) {
+      llvm::PrintFatalError(def.getLoc(),
+                            "Only one attribute parameter can be marked as "
+                            "AttributeSelfTypeParameter");
+    }
+    paramOS << selfParam->getName();
+    selfParamIndex = it.index();
+  }
+
+  // If we didn't find a self param, but the def has a type builder we use that
+  // to construct the type.
+  if (!selfParamIndex) {
+    const AttrDef &attrDef = cast<AttrDef>(def);
+    if (Optional<StringRef> typeBuilder = attrDef.getTypeBuilder()) {
+      FmtContext fmtContext;
+      for (const AttrOrTypeParameter &param : parameters)
+        fmtContext.addSubst(("_" + param.getName()).str(), param.getName());
+      paramOS << tgfmt(*typeBuilder, &fmtContext);
+    }
+  }
+  paramOS << ")";
+
+  // Append the parameters to the initializer.
+  for (auto it : llvm::enumerate(parameters))
+    if (it.index() != selfParamIndex)
+      paramOS << llvm::formatv(", {0}({0})", it.value().getName());
+
+  return paramOS.str();
+}
+
 void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
-  SmallVector<AttrOrTypeParameter, 4> parameters;
-  def.getParameters(parameters);
+  SmallVector<AttrOrTypeParameter, 4> params;
+  def.getParameters(params);
 
-  // Collect the parameter names and types.
-  auto parameterNames =
-      map_range(parameters, [](AttrOrTypeParameter parameter) {
-        return parameter.getName();
-      });
+  // Collect the parameter types.
   auto parameterTypes =
-      map_range(parameters, [](AttrOrTypeParameter parameter) {
+      llvm::map_range(params, [](const AttrOrTypeParameter &parameter) {
         return parameter.getCppType();
       });
-  auto parameterList = join(parameterNames, ", ");
-  auto parameterTypeList = join(parameterTypes, ", ");
+  std::string parameterTypeList = llvm::join(parameterTypes, ", ");
+
+  // Collect the parameter initializer.
+  std::string paramInitializer;
+  if (isAttrGenerator) {
+    paramInitializer = buildAttributeStorageParamInitializer(def, params);
+
+  } else {
+    llvm::raw_string_ostream initOS(paramInitializer);
+    llvm::interleaveComma(params, initOS, [&](const AttrOrTypeParameter &it) {
+      initOS << llvm::formatv("{0}({0})", it.getName());
+    });
+  }
+
+  // Construct the parameter list that is used when a concrete instance of the
+  // storage exists.
+  auto nonStaticParameterNames = llvm::map_range(params, [](const auto &param) {
+    return isa<AttributeSelfTypeParameter>(param) ? "getType()"
+                                                  : param.getName();
+  });
 
   // 1) Emit most of the storage class up until the hashKey body.
   os << formatv(
       defStorageClassBeginStr, def.getStorageNamespace(),
       def.getStorageClassName(),
       ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNamePairs,
-                          parameters, /*prependComma=*/false),
-      ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNameInitializer,
-                          parameters, /*prependComma=*/false),
-      parameterList, parameterTypeList, valueType);
+                          params, /*prependComma=*/false),
+      paramInitializer, llvm::join(nonStaticParameterNames, ", "),
+      parameterTypeList, valueType);
 
   // 2) Emit the haskKey method.
   os << "  static ::llvm::hash_code hashKey(const KeyTy &key) {\n";
@@ -517,7 +576,7 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
   // Extract each parameter from the key.
   os << "      return ::llvm::hash_combine(";
   llvm::interleaveComma(
-      llvm::seq<unsigned>(0, parameters.size()), os,
+      llvm::seq<unsigned>(0, params.size()), os,
       [&](unsigned it) { os << "std::get<" << it << ">(key)"; });
   os << ");\n    }\n";
 
@@ -535,9 +594,9 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
     // First, unbox the parameters.
     os << formatv(defStorageClassConstructorBeginStr, def.getStorageClassName(),
                   valueType);
-    for (unsigned i = 0, e = parameters.size(); i < e; ++i) {
+    for (unsigned i = 0, e = params.size(); i < e; ++i) {
       os << formatv("      auto {0} = std::get<{1}>(key);\n",
-                    parameters[i].getName(), i);
+                    params[i].getName(), i);
     }
 
     // Second, reassign the parameter variables with allocation code, if it's
@@ -545,14 +604,18 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
     emitStorageParameterAllocation(def, os);
 
     // Last, return an allocated copy.
+    auto parameterNames = llvm::map_range(
+        params, [](const auto &param) { return param.getName(); });
     os << formatv(defStorageClassConstructorEndStr, def.getStorageClassName(),
-                  parameterList);
+                  llvm::join(parameterNames, ", "));
   }
 
   // 4) Emit the parameters as storage class members.
-  for (auto parameter : parameters) {
-    os << "      " << parameter.getCppType() << " " << parameter.getName()
-       << ";\n";
+  for (const AttrOrTypeParameter &parameter : params) {
+    // Attribute value types are not stored as fields in the storage.
+    if (!isa<AttributeSelfTypeParameter>(parameter))
+      os << "      " << parameter.getCppType() << " " << parameter.getName()
+         << ";\n";
   }
   os << "  };\n";
 
@@ -708,10 +771,14 @@ void DefGenerator::emitDefDef(const AttrOrTypeDef &def) {
     // Otherwise, let the user define the exact accessor definition.
     if (def.genAccessors() && def.genStorageClass()) {
       for (const AttrOrTypeParameter &parameter : parameters) {
+        StringRef paramStorageName = isa<AttributeSelfTypeParameter>(parameter)
+                                         ? "getType()"
+                                         : parameter.getName();
+
         SmallString<16> name = parameter.getName();
         name[0] = llvm::toUpper(name[0]);
         os << formatv("{0} {3}::get{1}() const {{ return getImpl()->{2}; }\n",
-                      parameter.getCppType(), name, parameter.getName(),
+                      parameter.getCppType(), name, paramStorageName,
                       def.getCppClassName());
       }
     }


        


More information about the Mlir-commits mailing list