[Mlir-commits] [mlir] 425e11e - [mlir][AttrTypeDefGen] Add support for custom parameter comparators
River Riddle
llvmlistbot at llvm.org
Tue Mar 16 16:39:20 PDT 2021
Author: River Riddle
Date: 2021-03-16T16:31:53-07:00
New Revision: 425e11eea1de022bd9ea099970b451f77b0a4fca
URL: https://github.com/llvm/llvm-project/commit/425e11eea1de022bd9ea099970b451f77b0a4fca
DIFF: https://github.com/llvm/llvm-project/commit/425e11eea1de022bd9ea099970b451f77b0a4fca.diff
LOG: [mlir][AttrTypeDefGen] Add support for custom parameter comparators
Some parameters to attributes and types rely on special comparison routines other than operator== to ensure equality. This revision adds support for those parameters by allowing them to specify a `comparator` code block that determines if `$_lhs` and `$_rhs` are equal. One example of such a parameter is APFloat, which requires `bitwiseIsEqual` for bitwise comparison (which we want for attribute equality).
Differential Revision: https://reviews.llvm.org/D98473
Added:
Modified:
mlir/docs/OpDefinitions.md
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/AttrOrTypeDef.h
mlir/lib/TableGen/AttrOrTypeDef.cpp
mlir/test/mlir-tblgen/attrdefs.td
mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 4128543974b3..63b727ae428b 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -1441,8 +1441,9 @@ need allocation in the storage constructor, there are two options:
### TypeParameter tablegen class
This is used to further specify attributes about each of the types parameters.
-It includes documentation (`summary` and `syntax`), the C++ type to use, and a
-custom allocator to use in the storage constructor method.
+It includes documentation (`summary` and `syntax`), the C++ type to use, a
+custom allocator to use in the storage constructor method, and a custom
+comparator to decide if two instances of the parameter type are equal.
```tablegen
// DO NOT DO THIS!
@@ -1472,6 +1473,11 @@ The `allocator` code block has the following substitutions:
- `$_allocator` is the TypeStorageAllocator in which to allocate objects.
- `$_dst` is the variable in which to place the allocated data.
+The `comparator` code block has the following substitutions:
+
+- `$_lhs` is an instance of the parameter type.
+- `$_rhs` is an instance of the parameter type.
+
MLIR includes several specialized classes for common situations:
- `StringRefParameter<descriptionOfParam>` for StringRefs.
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 5a7037af63d2..819badc8b0f4 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2673,6 +2673,8 @@ class TypeDef<Dialect dialect, string name,
class AttrOrTypeParameter<string type, string desc> {
// Custom memory allocation code for storage constructor.
code allocator = ?;
+ // Custom comparator used to compare two instances for equality.
+ code comparator = ?;
// The C++ type of this parameter.
string cppType = type;
// One-line human-readable description of the argument.
@@ -2689,6 +2691,12 @@ class StringRefParameter<string desc = ""> :
let allocator = [{$_dst = $_allocator.copyInto($_self);}];
}
+// For APFloats, which require comparison.
+class APFloatParameter<string desc> :
+ AttrOrTypeParameter<"::llvm::APFloat", desc> {
+ let comparator = "$_lhs.bitwiseIsEqual($_rhs)";
+}
+
// For standard ArrayRefs, which require allocation.
class ArrayRefParameter<string arrayOf, string desc = ""> :
AttrOrTypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> {
diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
index 84a5a04c909c..5271fbae0eea 100644
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -183,6 +183,9 @@ class AttrOrTypeParameter {
// If specified, get the custom allocator code for this parameter.
Optional<StringRef> getAllocator() const;
+ // If specified, get the custom comparator code for this parameter.
+ Optional<StringRef> getComparator() const;
+
// Get the C++ type of this parameter.
StringRef getCppType() const;
diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index 037dc4d40cc5..eea03015d329 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -177,22 +177,18 @@ Optional<StringRef> AttrOrTypeParameter::getAllocator() const {
llvm::Init *parameterType = def->getArg(index);
if (isa<llvm::StringInit>(parameterType))
return Optional<StringRef>();
+ if (auto *param = dyn_cast<llvm::DefInit>(parameterType))
+ return param->getDef()->getValueAsOptionalString("allocator");
+ llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
+ "defs which inherit from AttrOrTypeParameter\n");
+}
- if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
- llvm::RecordVal *code = param->getDef()->getValue("allocator");
- if (!code)
- return Optional<StringRef>();
- if (llvm::StringInit *ci = dyn_cast<llvm::StringInit>(code->getValue()))
- return ci->getValue();
- if (isa<llvm::UnsetInit>(code->getValue()))
- return Optional<StringRef>();
-
- llvm::PrintFatalError(
- param->getDef()->getLoc(),
- "Record `" + def->getArgName(index)->getValue() +
- "', field `printer' does not have a code initializer!");
- }
-
+Optional<StringRef> AttrOrTypeParameter::getComparator() const {
+ llvm::Init *parameterType = def->getArg(index);
+ if (isa<llvm::StringInit>(parameterType))
+ return Optional<StringRef>();
+ if (auto *param = dyn_cast<llvm::DefInit>(parameterType))
+ return param->getDef()->getValueAsOptionalString("comparator");
llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
"defs which inherit from AttrOrTypeParameter\n");
}
diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td
index 34cdcab7246e..252b9175b05d 100644
--- a/mlir/test/mlir-tblgen/attrdefs.td
+++ b/mlir/test/mlir-tblgen/attrdefs.td
@@ -53,7 +53,7 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
ins
"int":$widthOfSomething,
"::mlir::test::SimpleTypeA": $exampleTdType,
- "SomeCppStruct": $exampleCppType,
+ APFloatParameter<"">: $apFloat,
ArrayRefParameter<"int", "Matrix dimensions">:$dims,
AttributeSelfTypeParameter<"">:$inner
);
@@ -61,8 +61,8 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
let genVerifyDecl = 1;
// DECL-LABEL: class CompoundAAttr : public ::mlir::Attribute
-// DECL: static CompoundAAttr getChecked(llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
-// DECL: static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
+// DECL: static CompoundAAttr getChecked(llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, ::llvm::APFloat apFloat, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
+// DECL: static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, ::llvm::APFloat apFloat, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
// DECL: static constexpr ::llvm::StringLiteral getMnemonic() {
// DECL: return ::llvm::StringLiteral("cmpnd_a");
// DECL: }
@@ -71,7 +71,7 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
// DECL: void print(::mlir::DialectAsmPrinter &printer) const;
// DECL: int getWidthOfSomething() const;
// DECL: ::mlir::test::SimpleTypeA getExampleTdType() const;
-// DECL: SomeCppStruct getExampleCppType() const;
+// DECL: ::llvm::APFloat getApFloat() const;
// Check that AttributeSelfTypeParameter is handled properly.
// DEF-LABEL: struct CompoundAAttrStorage
@@ -79,11 +79,21 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
// DEF-NEXT: : ::mlir::AttributeStorage(inner),
// DEF: bool operator==(const KeyTy &key) const {
-// DEF-NEXT: return key == KeyTy(widthOfSomething, exampleTdType, exampleCppType, dims, getType());
+// DEF-NEXT: if (!(widthOfSomething == std::get<0>(key)))
+// DEF-NEXT: return false;
+// DEF-NEXT: if (!(exampleTdType == std::get<1>(key)))
+// DEF-NEXT: return false;
+// DEF-NEXT: if (!(apFloat.bitwiseIsEqual(std::get<2>(key))))
+// DEF-NEXT: return false;
+// DEF-NEXT: if (!(dims == std::get<3>(key)))
+// DEF-NEXT: return false;
+// DEF-NEXT: if (!(getType() == std::get<4>(key)))
+// DEF-NEXT: return false;
+// DEF-NEXT: return true;
// DEF: static CompoundAAttrStorage *construct
// DEF: return new (allocator.allocate<CompoundAAttrStorage>())
-// DEF-NEXT: CompoundAAttrStorage(widthOfSomething, exampleTdType, exampleCppType, dims, inner);
+// DEF-NEXT: CompoundAAttrStorage(widthOfSomething, exampleTdType, apFloat, dims, inner);
// DEF: ::mlir::Type CompoundAAttr::getInner() const { return getImpl()->getType(); }
}
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 2c3caaeeea25..636d4f8b51ef 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -432,22 +432,16 @@ static ::mlir::LogicalResult generated{0}Printer(
/// {1}: Storage class c++ name.
/// {2}: Parameters parameters.
/// {3}: Parameter initializer string.
-/// {4}: Parameter name list.
-/// {5}: Parameter types.
-/// {6}: The name of the base value type, e.g. Attribute or Type.
+/// {4}: Parameter types.
+/// {5}: The name of the base value type, e.g. Attribute or Type.
static const char *const defStorageClassBeginStr = R"(
namespace {0} {{
- struct {1} : public ::mlir::{6}Storage {{
+ struct {1} : public ::mlir::{5}Storage {{
{1} ({2})
: {3} {{ }
/// The hash key is a tuple of the parameter types.
- using KeyTy = std::tuple<{5}>;
-
- /// Define the comparison function for the key type.
- bool operator==(const KeyTy &key) const {{
- return key == KeyTy({4});
- }
+ using KeyTy = std::tuple<{4}>;
)";
/// The storage class' constructor template.
@@ -555,23 +549,34 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
});
}
- // Construct the parameter list that is used when a concrete instance of the
- // storage exists.
- auto nonStaticParameterNames = llvm::map_range(params, [](const auto ¶m) {
- return isa<AttributeSelfTypeParameter>(param) ? "getType()"
- : param.getName();
- });
-
- // 1) Emit most of the storage class up until the hashKey body.
+ // * Emit most of the storage class up until the hashKey body.
os << formatv(
defStorageClassBeginStr, def.getStorageNamespace(),
def.getStorageClassName(),
ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNamePairs,
params, /*prependComma=*/false),
- paramInitializer, llvm::join(nonStaticParameterNames, ", "),
- parameterTypeList, valueType);
+ paramInitializer, parameterTypeList, valueType);
+
+ // * Emit the comparison method.
+ os << " bool operator==(const KeyTy &key) const {\n";
+ for (auto it : llvm::enumerate(params)) {
+ os << " if (!(";
+
+ // Build the comparator context.
+ bool isSelfType = isa<AttributeSelfTypeParameter>(it.value());
+ FmtContext context;
+ context.addSubst("_lhs", isSelfType ? "getType()" : it.value().getName())
+ .addSubst("_rhs", "std::get<" + Twine(it.index()) + ">(key)");
+
+ // Use the parameter specified comparator if possible, otherwise default to
+ // operator==.
+ Optional<StringRef> comparator = it.value().getComparator();
+ os << tgfmt(comparator ? *comparator : "$_lhs == $_rhs", &context);
+ os << "))\n return false;\n";
+ }
+ os << " return true;\n }\n";
- // 2) Emit the haskKey method.
+ // * Emit the haskKey method.
os << " static ::llvm::hash_code hashKey(const KeyTy &key) {\n";
// Extract each parameter from the key.
@@ -581,7 +586,7 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
[&](unsigned it) { os << "std::get<" << it << ">(key)"; });
os << ");\n }\n";
- // 3) Emit the construct method.
+ // * Emit the construct method.
// If user wants to build the storage constructor themselves, declare it
// here and then they can write the definition elsewhere.
@@ -611,7 +616,7 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
llvm::join(parameterNames, ", "));
}
- // 4) Emit the parameters as storage class members.
+ // * Emit the parameters as storage class members.
for (const AttrOrTypeParameter ¶meter : params) {
// Attribute value types are not stored as fields in the storage.
if (!isa<AttributeSelfTypeParameter>(parameter))
More information about the Mlir-commits
mailing list