[Mlir-commits] [mlir] 425e11e - [mlir][AttrTypeDefGen] Add support for custom parameter comparators

River Riddle llvmlistbot at llvm.org
Tue Mar 16 16:39:20 PDT 2021


Author: River Riddle
Date: 2021-03-16T16:31:53-07:00
New Revision: 425e11eea1de022bd9ea099970b451f77b0a4fca

URL: https://github.com/llvm/llvm-project/commit/425e11eea1de022bd9ea099970b451f77b0a4fca
DIFF: https://github.com/llvm/llvm-project/commit/425e11eea1de022bd9ea099970b451f77b0a4fca.diff

LOG: [mlir][AttrTypeDefGen] Add support for custom parameter comparators

Some parameters to attributes and types rely on special comparison routines other than operator== to ensure equality. This revision adds support for those parameters by allowing them to specify a `comparator` code block that determines if `$_lhs` and `$_rhs` are equal. An example of one of these parameters is APFloat, which requires `bitwiseIsEqual` for bitwise comparison (which we want for attribute equality).

Differential Revision: https://reviews.llvm.org/D98473

Added: 
    

Modified: 
    mlir/docs/OpDefinitions.md
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/TableGen/AttrOrTypeDef.h
    mlir/lib/TableGen/AttrOrTypeDef.cpp
    mlir/test/mlir-tblgen/attrdefs.td
    mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 4128543974b3..63b727ae428b 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -1441,8 +1441,9 @@ need allocation in the storage constructor, there are two options:
 ### TypeParameter tablegen class
 
 This is used to further specify attributes about each of the types parameters.
-It includes documentation (`summary` and `syntax`), the C++ type to use, and a
-custom allocator to use in the storage constructor method.
+It includes documentation (`summary` and `syntax`), the C++ type to use, a
+custom allocator to use in the storage constructor method, and a custom
+comparator to decide if two instances of the parameter type are equal.
 
 ```tablegen
 // DO NOT DO THIS!
@@ -1472,6 +1473,11 @@ The `allocator` code block has the following substitutions:
 -   `$_allocator` is the TypeStorageAllocator in which to allocate objects.
 -   `$_dst` is the variable in which to place the allocated data.
 
+The `comparator` code block has the following substitutions:
+
+-   `$_lhs` is an instance of the parameter type.
+-   `$_rhs` is an instance of the parameter type.
+
 MLIR includes several specialized classes for common situations:
 
 -   `StringRefParameter<descriptionOfParam>` for StringRefs.

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 5a7037af63d2..819badc8b0f4 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2673,6 +2673,8 @@ class TypeDef<Dialect dialect, string name,
 class AttrOrTypeParameter<string type, string desc> {
   // Custom memory allocation code for storage constructor.
   code allocator = ?;
+  // Custom comparator used to compare two instances for equality.
+  code comparator = ?;
   // The C++ type of this parameter.
   string cppType = type;
   // One-line human-readable description of the argument.
@@ -2689,6 +2691,12 @@ class StringRefParameter<string desc = ""> :
   let allocator = [{$_dst = $_allocator.copyInto($_self);}];
 }
 
+// For APFloats, which require comparison.
+class APFloatParameter<string desc> :
+    AttrOrTypeParameter<"::llvm::APFloat", desc> {
+  let comparator = "$_lhs.bitwiseIsEqual($_rhs)";
+}
+
 // For standard ArrayRefs, which require allocation.
 class ArrayRefParameter<string arrayOf, string desc = ""> :
     AttrOrTypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> {

diff  --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
index 84a5a04c909c..5271fbae0eea 100644
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -183,6 +183,9 @@ class AttrOrTypeParameter {
   // If specified, get the custom allocator code for this parameter.
   Optional<StringRef> getAllocator() const;
 
+  // If specified, get the custom comparator code for this parameter.
+  Optional<StringRef> getComparator() const;
+
   // Get the C++ type of this parameter.
   StringRef getCppType() const;
 

diff  --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index 037dc4d40cc5..eea03015d329 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -177,22 +177,18 @@ Optional<StringRef> AttrOrTypeParameter::getAllocator() const {
   llvm::Init *parameterType = def->getArg(index);
   if (isa<llvm::StringInit>(parameterType))
     return Optional<StringRef>();
+  if (auto *param = dyn_cast<llvm::DefInit>(parameterType))
+    return param->getDef()->getValueAsOptionalString("allocator");
+  llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
+                        "defs which inherit from AttrOrTypeParameter\n");
+}
 
-  if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
-    llvm::RecordVal *code = param->getDef()->getValue("allocator");
-    if (!code)
-      return Optional<StringRef>();
-    if (llvm::StringInit *ci = dyn_cast<llvm::StringInit>(code->getValue()))
-      return ci->getValue();
-    if (isa<llvm::UnsetInit>(code->getValue()))
-      return Optional<StringRef>();
-
-    llvm::PrintFatalError(
-        param->getDef()->getLoc(),
-        "Record `" + def->getArgName(index)->getValue() +
-            "', field `printer' does not have a code initializer!");
-  }
-
+Optional<StringRef> AttrOrTypeParameter::getComparator() const {
+  llvm::Init *parameterType = def->getArg(index);
+  if (isa<llvm::StringInit>(parameterType))
+    return Optional<StringRef>();
+  if (auto *param = dyn_cast<llvm::DefInit>(parameterType))
+    return param->getDef()->getValueAsOptionalString("comparator");
   llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
                         "defs which inherit from AttrOrTypeParameter\n");
 }

diff  --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td
index 34cdcab7246e..252b9175b05d 100644
--- a/mlir/test/mlir-tblgen/attrdefs.td
+++ b/mlir/test/mlir-tblgen/attrdefs.td
@@ -53,7 +53,7 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
       ins
       "int":$widthOfSomething,
       "::mlir::test::SimpleTypeA": $exampleTdType,
-      "SomeCppStruct": $exampleCppType,
+      APFloatParameter<"">: $apFloat,
       ArrayRefParameter<"int", "Matrix dimensions">:$dims,
       AttributeSelfTypeParameter<"">:$inner
   );
@@ -61,8 +61,8 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
   let genVerifyDecl = 1;
 
 // DECL-LABEL: class CompoundAAttr : public ::mlir::Attribute
-// DECL: static CompoundAAttr getChecked(llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
-// DECL: static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
+// DECL: static CompoundAAttr getChecked(llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, ::llvm::APFloat apFloat, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
+// DECL: static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, ::llvm::APFloat apFloat, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
 // DECL: static constexpr ::llvm::StringLiteral getMnemonic() {
 // DECL:   return ::llvm::StringLiteral("cmpnd_a");
 // DECL: }
@@ -71,7 +71,7 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
 // DECL: void print(::mlir::DialectAsmPrinter &printer) const;
 // DECL: int getWidthOfSomething() const;
 // DECL: ::mlir::test::SimpleTypeA getExampleTdType() const;
-// DECL: SomeCppStruct getExampleCppType() const;
+// DECL: ::llvm::APFloat getApFloat() const;
 
 // Check that AttributeSelfTypeParameter is handled properly.
 // DEF-LABEL: struct CompoundAAttrStorage
@@ -79,11 +79,21 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
 // DEF-NEXT: : ::mlir::AttributeStorage(inner),
 
 // DEF: bool operator==(const KeyTy &key) const {
-// DEF-NEXT: return key == KeyTy(widthOfSomething, exampleTdType, exampleCppType, dims, getType());
+// DEF-NEXT: if (!(widthOfSomething == std::get<0>(key)))
+// DEF-NEXT:   return false;
+// DEF-NEXT: if (!(exampleTdType == std::get<1>(key)))
+// DEF-NEXT:   return false;
+// DEF-NEXT: if (!(apFloat.bitwiseIsEqual(std::get<2>(key))))
+// DEF-NEXT:   return false;
+// DEF-NEXT: if (!(dims == std::get<3>(key)))
+// DEF-NEXT:   return false;
+// DEF-NEXT: if (!(getType() == std::get<4>(key)))
+// DEF-NEXT:   return false;
+// DEF-NEXT: return true;
 
 // DEF: static CompoundAAttrStorage *construct
 // DEF: return new (allocator.allocate<CompoundAAttrStorage>())
-// DEF-NEXT: CompoundAAttrStorage(widthOfSomething, exampleTdType, exampleCppType, dims, inner);
+// DEF-NEXT: CompoundAAttrStorage(widthOfSomething, exampleTdType, apFloat, dims, inner);
 
 // DEF: ::mlir::Type CompoundAAttr::getInner() const { return getImpl()->getType(); }
 }

diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 2c3caaeeea25..636d4f8b51ef 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -432,22 +432,16 @@ static ::mlir::LogicalResult generated{0}Printer(
 /// {1}: Storage class c++ name.
 /// {2}: Parameters parameters.
 /// {3}: Parameter initializer string.
-/// {4}: Parameter name list.
-/// {5}: Parameter types.
-/// {6}: The name of the base value type, e.g. Attribute or Type.
+/// {4}: Parameter types.
+/// {5}: The name of the base value type, e.g. Attribute or Type.
 static const char *const defStorageClassBeginStr = R"(
 namespace {0} {{
-  struct {1} : public ::mlir::{6}Storage {{
+  struct {1} : public ::mlir::{5}Storage {{
     {1} ({2})
       : {3} {{ }
 
     /// The hash key is a tuple of the parameter types.
-    using KeyTy = std::tuple<{5}>;
-
-    /// Define the comparison function for the key type.
-    bool operator==(const KeyTy &key) const {{
-      return key == KeyTy({4});
-    }
+    using KeyTy = std::tuple<{4}>;
 )";
 
 /// The storage class' constructor template.
@@ -555,23 +549,34 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
     });
   }
 
-  // Construct the parameter list that is used when a concrete instance of the
-  // storage exists.
-  auto nonStaticParameterNames = llvm::map_range(params, [](const auto &param) {
-    return isa<AttributeSelfTypeParameter>(param) ? "getType()"
-                                                  : param.getName();
-  });
-
-  // 1) Emit most of the storage class up until the hashKey body.
+  // * Emit most of the storage class up until the hashKey body.
   os << formatv(
       defStorageClassBeginStr, def.getStorageNamespace(),
       def.getStorageClassName(),
       ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNamePairs,
                           params, /*prependComma=*/false),
-      paramInitializer, llvm::join(nonStaticParameterNames, ", "),
-      parameterTypeList, valueType);
+      paramInitializer, parameterTypeList, valueType);
+
+  // * Emit the comparison method.
+  os << "  bool operator==(const KeyTy &key) const {\n";
+  for (auto it : llvm::enumerate(params)) {
+    os << "    if (!(";
+
+    // Build the comparator context.
+    bool isSelfType = isa<AttributeSelfTypeParameter>(it.value());
+    FmtContext context;
+    context.addSubst("_lhs", isSelfType ? "getType()" : it.value().getName())
+        .addSubst("_rhs", "std::get<" + Twine(it.index()) + ">(key)");
+
+    // Use the parameter specified comparator if possible, otherwise default to
+    // operator==.
+    Optional<StringRef> comparator = it.value().getComparator();
+    os << tgfmt(comparator ? *comparator : "$_lhs == $_rhs", &context);
+    os << "))\n      return false;\n";
+  }
+  os << "    return true;\n  }\n";
 
-  // 2) Emit the haskKey method.
+  // * Emit the haskKey method.
   os << "  static ::llvm::hash_code hashKey(const KeyTy &key) {\n";
 
   // Extract each parameter from the key.
@@ -581,7 +586,7 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
       [&](unsigned it) { os << "std::get<" << it << ">(key)"; });
   os << ");\n    }\n";
 
-  // 3) Emit the construct method.
+  // * Emit the construct method.
 
   // If user wants to build the storage constructor themselves, declare it
   // here and then they can write the definition elsewhere.
@@ -611,7 +616,7 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
                   llvm::join(parameterNames, ", "));
   }
 
-  // 4) Emit the parameters as storage class members.
+  // * Emit the parameters as storage class members.
   for (const AttrOrTypeParameter &parameter : params) {
     // Attribute value types are not stored as fields in the storage.
     if (!isa<AttributeSelfTypeParameter>(parameter))


        


More information about the Mlir-commits mailing list