[Mlir-commits] [mlir] f402e68 - [MLIR] ODS TypeDefs: getChecked() and internal enhancements
John Demme
llvmlistbot at llvm.org
Sun Oct 18 18:10:54 PDT 2020
Author: John Demme
Date: 2020-10-19T01:10:05Z
New Revision: f402e682d0ef5598eeffc9a21a691b03e602ff58
URL: https://github.com/llvm/llvm-project/commit/f402e682d0ef5598eeffc9a21a691b03e602ff58
DIFF: https://github.com/llvm/llvm-project/commit/f402e682d0ef5598eeffc9a21a691b03e602ff58.diff
LOG: [MLIR] ODS TypeDefs: getChecked() and internal enhancements
Have the ODS TypeDef generator write the getChecked() definition.
Also add to TypeParamCommaFormatter a `JustParams` format and
refactor around that.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D89438
Added:
Modified:
mlir/test/lib/Dialect/Test/TestTypeDefs.td
mlir/test/mlir-tblgen/typedefs.td
mlir/tools/mlir-tblgen/TypeDefGen.cpp
Removed:
################################################################################
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index cfab6985e87a..d755d91bb5d6 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -75,7 +75,8 @@ def IntegerType : Test_Type<"TestInteger"> {
int width;
if ($_parser.parseInteger(width)) return Type();
if ($_parser.parseGreater()) return Type();
- return get(ctxt, signedness, width);
+ Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc());
+ return getChecked(loc, signedness, width);
}];
// Any extra code one wants in the type's class declaration.
diff --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td
index 5ba3fbcc7115..4fc7f3282bd7 100644
--- a/mlir/test/mlir-tblgen/typedefs.td
+++ b/mlir/test/mlir-tblgen/typedefs.td
@@ -51,7 +51,7 @@ def B_CompoundTypeA : TestType<"CompoundA"> {
// DECL-LABEL: class CompoundAType: public ::mlir::Type
// DECL: static ::mlir::LogicalResult verifyConstructionInvariants(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims);
-// DECL: static CompoundAType getChecked(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims);
+// DECL: static ::mlir::Type getChecked(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims);
// DECL: static ::llvm::StringRef getMnemonic() { return "cmpnd_a"; }
// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
// DECL: void print(::mlir::DialectAsmPrinter& printer) const;
diff --git a/mlir/tools/mlir-tblgen/TypeDefGen.cpp b/mlir/tools/mlir-tblgen/TypeDefGen.cpp
index 6bf6e04ef2cf..4473f629f3f1 100644
--- a/mlir/tools/mlir-tblgen/TypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/TypeDefGen.cpp
@@ -78,25 +78,25 @@ class TypeParamCommaFormatter : public llvm::detail::format_adapter {
/// [...]".
TypeNamePairs,
- /// Emit ", parameter1Type parameter1Name, parameter2Type parameter2Name,
- /// [...]".
- TypeNamePairsPrependComma,
-
/// Emit "parameter1(parameter1), parameter2(parameter2), [...]".
- TypeNameInitializer
+ TypeNameInitializer,
+
+ /// Emit "param1Name, param2Name, [...]".
+ JustParams,
};
- TypeParamCommaFormatter(EmitFormat emitFormat, ArrayRef<TypeParameter> params)
- : emitFormat(emitFormat), params(params) {}
+ TypeParamCommaFormatter(EmitFormat emitFormat, ArrayRef<TypeParameter> params,
+ bool prependComma = true)
+ : emitFormat(emitFormat), params(params), prependComma(prependComma) {}
/// llvm::formatv will call this function when using an instance as a
/// replacement value.
void format(raw_ostream &os, StringRef options) override {
- if (params.size() && emitFormat == EmitFormat::TypeNamePairsPrependComma)
+ if (params.size() && prependComma)
os << ", ";
+
switch (emitFormat) {
case EmitFormat::TypeNamePairs:
- case EmitFormat::TypeNamePairsPrependComma:
interleaveComma(params, os,
[&](const TypeParameter &p) { emitTypeNamePair(p, os); });
break;
@@ -105,6 +105,10 @@ class TypeParamCommaFormatter : public llvm::detail::format_adapter {
emitTypeNameInitializer(p, os);
});
break;
+ case EmitFormat::JustParams:
+ interleaveComma(params, os,
+ [&](const TypeParameter &p) { os << p.getName(); });
+ break;
}
}
@@ -120,6 +124,7 @@ class TypeParamCommaFormatter : public llvm::detail::format_adapter {
EmitFormat emitFormat;
ArrayRef<TypeParameter> params;
+ bool prependComma;
};
} // end anonymous namespace
@@ -168,10 +173,9 @@ static const char *const typeDefParsePrint = R"(
/// The code block for the verifyConstructionInvariants and getChecked.
///
/// {0}: List of parameters, parameters style.
-/// {1}: C++ type class name.
static const char *const typeDefDeclVerifyStr = R"(
static ::mlir::LogicalResult verifyConstructionInvariants(Location loc{0});
- static {1} getChecked(Location loc{0});
+ static ::mlir::Type getChecked(Location loc{0});
)";
/// Generate the declaration for the given typeDef class.
@@ -194,14 +198,13 @@ static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
os << *extraDecl << "\n";
TypeParamCommaFormatter emitTypeNamePairsAfterComma(
- TypeParamCommaFormatter::EmitFormat::TypeNamePairsPrependComma, params);
+ TypeParamCommaFormatter::EmitFormat::TypeNamePairs, params);
os << llvm::formatv(" static {0} get(::mlir::MLIRContext* ctxt{1});\n",
typeDef.getCppClassName(), emitTypeNamePairsAfterComma);
// Emit the verify invariants declaration.
if (typeDef.genVerifyInvariantsDecl())
- os << llvm::formatv(typeDefDeclVerifyStr, emitTypeNamePairsAfterComma,
- typeDef.getCppClassName());
+ os << llvm::formatv(typeDefDeclVerifyStr, emitTypeNamePairsAfterComma);
// Emit the mnenomic, if specified.
if (auto mnenomic = typeDef.getMnemonic()) {
@@ -317,6 +320,17 @@ static const char *const typeDefStorageClassConstructorReturn = R"(
}
)";
+/// The code block for the getChecked definition.
+///
+/// {0}: List of parameters, parameters style.
+/// {1}: C++ type class name.
+/// {2}: Comma separated list of parameter names.
+static const char *const typeDefDefGetCheckeStr = R"(
+ ::mlir::Type {1}::getChecked(Location loc{0}) {{
+ return Base::getChecked(loc{2});
+ }
+)";
+
/// Use tgfmt to emit custom allocation code for each parameter, if necessary.
static void emitParameterAllocationCode(TypeDef &typeDef, raw_ostream &os) {
SmallVector<TypeParameter, 4> parameters;
@@ -355,27 +369,28 @@ static void emitStorageClass(TypeDef typeDef, raw_ostream &os) {
auto parameterTypeList = join(parameterTypes, ", ");
// 1) Emit most of the storage class up until the hashKey body.
- os << formatv(
- typeDefStorageClassBegin, typeDef.getStorageNamespace(),
- typeDef.getStorageClassName(),
- TypeParamCommaFormatter(
- TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
- TypeParamCommaFormatter(
- TypeParamCommaFormatter::EmitFormat::TypeNameInitializer, parameters),
- parameterList, parameterTypeList);
+ os << formatv(typeDefStorageClassBegin, typeDef.getStorageNamespace(),
+ typeDef.getStorageClassName(),
+ TypeParamCommaFormatter(
+ TypeParamCommaFormatter::EmitFormat::TypeNamePairs,
+ parameters, /*prependComma=*/false),
+ TypeParamCommaFormatter(
+ TypeParamCommaFormatter::EmitFormat::TypeNameInitializer,
+ parameters, /*prependComma=*/false),
+ parameterList, parameterTypeList);
// 2) Emit the haskKey method.
os << " static ::llvm::hash_code hashKey(const KeyTy &key) {\n";
// Extract each parameter from the key.
for (size_t i = 0, e = parameters.size(); i < e; ++i)
- os << formatv(" const auto &{0} = std::get<{1}>(key);\n",
- parameters[i].getName(), i);
+ os << llvm::formatv(" const auto &{0} = std::get<{1}>(key);\n",
+ parameters[i].getName(), i);
// Then combine them all. This requires all the parameters types to have a
// hash_value defined.
- os << " return ::llvm::hash_combine(";
- interleaveComma(parameterNames, os);
- os << ");\n";
- os << " }\n";
+ os << llvm::formatv(
+ " return ::llvm::hash_combine({0});\n }\n",
+ TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
+ parameters, /* prependComma */ false));
// 3) Emit the construct method.
if (typeDef.hasStorageCustomConstructor())
@@ -462,14 +477,12 @@ static void emitTypeDefDef(TypeDef typeDef, raw_ostream &os) {
os << llvm::formatv(
"{0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n"
- " return Base::get(ctxt",
+ " return Base::get(ctxt{2});\n}\n",
typeDef.getCppClassName(),
TypeParamCommaFormatter(
- TypeParamCommaFormatter::EmitFormat::TypeNamePairsPrependComma,
- parameters));
- for (TypeParameter &param : parameters)
- os << ", " << param.getName();
- os << ");\n}\n";
+ TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
+ TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
+ parameters));
// Emit the parameter accessors.
if (typeDef.genAccessors())
@@ -481,6 +494,17 @@ static void emitTypeDefDef(TypeDef typeDef, raw_ostream &os) {
typeDef.getCppClassName());
}
+ // Generate getChecked() method.
+ if (typeDef.genVerifyInvariantsDecl()) {
+ os << llvm::formatv(
+ typeDefDefGetCheckeStr,
+ TypeParamCommaFormatter(
+ TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
+ typeDef.getCppClassName(),
+ TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
+ parameters));
+ }
+
// If mnemonic is specified maybe print definitions for the parser and printer
// code, if they're specified.
if (typeDef.getMnemonic())
More information about the Mlir-commits
mailing list