[Mlir-commits] [mlir] f402e68 - [MLIR] ODS TypeDefs: getChecked() and internal enhancements

John Demme llvmlistbot at llvm.org
Sun Oct 18 18:10:54 PDT 2020


Author: John Demme
Date: 2020-10-19T01:10:05Z
New Revision: f402e682d0ef5598eeffc9a21a691b03e602ff58

URL: https://github.com/llvm/llvm-project/commit/f402e682d0ef5598eeffc9a21a691b03e602ff58
DIFF: https://github.com/llvm/llvm-project/commit/f402e682d0ef5598eeffc9a21a691b03e602ff58.diff

LOG: [MLIR] ODS TypeDefs: getChecked() and internal enhancements

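Have the ODS TypeDef generator emit the getChecked() definition, and change
the declared return type of getChecked() to ::mlir::Type. Also add a
`JustParams` format to TypeParamCommaFormatter, replace the
`TypeNamePairsPrependComma` format with a `prependComma` constructor flag,
and refactor the callers around these changes.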

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D89438

Added: 
    

Modified: 
    mlir/test/lib/Dialect/Test/TestTypeDefs.td
    mlir/test/mlir-tblgen/typedefs.td
    mlir/tools/mlir-tblgen/TypeDefGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index cfab6985e87a..d755d91bb5d6 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -75,7 +75,8 @@ def IntegerType : Test_Type<"TestInteger"> {
     int width;
     if ($_parser.parseInteger(width)) return Type();
     if ($_parser.parseGreater()) return Type();
-    return get(ctxt, signedness, width);
+    Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc());
+    return getChecked(loc, signedness, width);
   }];
 
   // Any extra code one wants in the type's class declaration.

diff  --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td
index 5ba3fbcc7115..4fc7f3282bd7 100644
--- a/mlir/test/mlir-tblgen/typedefs.td
+++ b/mlir/test/mlir-tblgen/typedefs.td
@@ -51,7 +51,7 @@ def B_CompoundTypeA : TestType<"CompoundA"> {
 
 // DECL-LABEL: class CompoundAType: public ::mlir::Type
 // DECL: static ::mlir::LogicalResult verifyConstructionInvariants(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims);
-// DECL: static CompoundAType getChecked(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims);
+// DECL: static ::mlir::Type getChecked(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims);
 // DECL: static ::llvm::StringRef getMnemonic() { return "cmpnd_a"; }
 // DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
 // DECL: void print(::mlir::DialectAsmPrinter& printer) const;

diff  --git a/mlir/tools/mlir-tblgen/TypeDefGen.cpp b/mlir/tools/mlir-tblgen/TypeDefGen.cpp
index 6bf6e04ef2cf..4473f629f3f1 100644
--- a/mlir/tools/mlir-tblgen/TypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/TypeDefGen.cpp
@@ -78,25 +78,25 @@ class TypeParamCommaFormatter : public llvm::detail::format_adapter {
     /// [...]".
     TypeNamePairs,
 
-    /// Emit ", parameter1Type parameter1Name, parameter2Type parameter2Name,
-    /// [...]".
-    TypeNamePairsPrependComma,
-
     /// Emit "parameter1(parameter1), parameter2(parameter2), [...]".
-    TypeNameInitializer
+    TypeNameInitializer,
+
+    /// Emit "param1Name, param2Name, [...]".
+    JustParams,
   };
 
-  TypeParamCommaFormatter(EmitFormat emitFormat, ArrayRef<TypeParameter> params)
-      : emitFormat(emitFormat), params(params) {}
+  TypeParamCommaFormatter(EmitFormat emitFormat, ArrayRef<TypeParameter> params,
+                          bool prependComma = true)
+      : emitFormat(emitFormat), params(params), prependComma(prependComma) {}
 
   /// llvm::formatv will call this function when using an instance as a
   /// replacement value.
   void format(raw_ostream &os, StringRef options) override {
-    if (params.size() && emitFormat == EmitFormat::TypeNamePairsPrependComma)
+    if (params.size() && prependComma)
       os << ", ";
+
     switch (emitFormat) {
     case EmitFormat::TypeNamePairs:
-    case EmitFormat::TypeNamePairsPrependComma:
       interleaveComma(params, os,
                       [&](const TypeParameter &p) { emitTypeNamePair(p, os); });
       break;
@@ -105,6 +105,10 @@ class TypeParamCommaFormatter : public llvm::detail::format_adapter {
         emitTypeNameInitializer(p, os);
       });
       break;
+    case EmitFormat::JustParams:
+      interleaveComma(params, os,
+                      [&](const TypeParameter &p) { os << p.getName(); });
+      break;
     }
   }
 
@@ -120,6 +124,7 @@ class TypeParamCommaFormatter : public llvm::detail::format_adapter {
 
   EmitFormat emitFormat;
   ArrayRef<TypeParameter> params;
+  bool prependComma;
 };
 
 } // end anonymous namespace
@@ -168,10 +173,9 @@ static const char *const typeDefParsePrint = R"(
 /// The code block for the verifyConstructionInvariants and getChecked.
 ///
 /// {0}: List of parameters, parameters style.
-/// {1}: C++ type class name.
 static const char *const typeDefDeclVerifyStr = R"(
     static ::mlir::LogicalResult verifyConstructionInvariants(Location loc{0});
-    static {1} getChecked(Location loc{0});
+    static ::mlir::Type getChecked(Location loc{0});
 )";
 
 /// Generate the declaration for the given typeDef class.
@@ -194,14 +198,13 @@ static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
     os << *extraDecl << "\n";
 
   TypeParamCommaFormatter emitTypeNamePairsAfterComma(
-      TypeParamCommaFormatter::EmitFormat::TypeNamePairsPrependComma, params);
+      TypeParamCommaFormatter::EmitFormat::TypeNamePairs, params);
   os << llvm::formatv("    static {0} get(::mlir::MLIRContext* ctxt{1});\n",
                       typeDef.getCppClassName(), emitTypeNamePairsAfterComma);
 
   // Emit the verify invariants declaration.
   if (typeDef.genVerifyInvariantsDecl())
-    os << llvm::formatv(typeDefDeclVerifyStr, emitTypeNamePairsAfterComma,
-                        typeDef.getCppClassName());
+    os << llvm::formatv(typeDefDeclVerifyStr, emitTypeNamePairsAfterComma);
 
   // Emit the mnenomic, if specified.
   if (auto mnenomic = typeDef.getMnemonic()) {
@@ -317,6 +320,17 @@ static const char *const typeDefStorageClassConstructorReturn = R"(
     }
 )";
 
+/// Definition of getChecked(); it just forwards to Base::getChecked.
+///
+/// {0}: ", type name" for every parameter (TypeNamePairs); empty if none.
+/// {1}: C++ type class name.
+/// {2}: ", name" for every parameter (JustParams); empty if none.
+static const char *const typeDefDefGetCheckeStr = R"(
+    ::mlir::Type {1}::getChecked(Location loc{0}) {{
+      return Base::getChecked(loc{2});
+    }
+)";
+
 /// Use tgfmt to emit custom allocation code for each parameter, if necessary.
 static void emitParameterAllocationCode(TypeDef &typeDef, raw_ostream &os) {
   SmallVector<TypeParameter, 4> parameters;
@@ -355,27 +369,28 @@ static void emitStorageClass(TypeDef typeDef, raw_ostream &os) {
   auto parameterTypeList = join(parameterTypes, ", ");
 
   // 1) Emit most of the storage class up until the hashKey body.
-  os << formatv(
-      typeDefStorageClassBegin, typeDef.getStorageNamespace(),
-      typeDef.getStorageClassName(),
-      TypeParamCommaFormatter(
-          TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
-      TypeParamCommaFormatter(
-          TypeParamCommaFormatter::EmitFormat::TypeNameInitializer, parameters),
-      parameterList, parameterTypeList);
+  os << formatv(typeDefStorageClassBegin, typeDef.getStorageNamespace(),
+                typeDef.getStorageClassName(),
+                TypeParamCommaFormatter(
+                    TypeParamCommaFormatter::EmitFormat::TypeNamePairs,
+                    parameters, /*prependComma=*/false),
+                TypeParamCommaFormatter(
+                    TypeParamCommaFormatter::EmitFormat::TypeNameInitializer,
+                    parameters, /*prependComma=*/false),
+                parameterList, parameterTypeList);
 
   // 2) Emit the haskKey method.
   os << "  static ::llvm::hash_code hashKey(const KeyTy &key) {\n";
   // Extract each parameter from the key.
   for (size_t i = 0, e = parameters.size(); i < e; ++i)
-    os << formatv("      const auto &{0} = std::get<{1}>(key);\n",
-                  parameters[i].getName(), i);
+    os << llvm::formatv("      const auto &{0} = std::get<{1}>(key);\n",
+                        parameters[i].getName(), i);
   // Then combine them all. This requires all the parameters types to have a
   // hash_value defined.
-  os << "      return ::llvm::hash_combine(";
-  interleaveComma(parameterNames, os);
-  os << ");\n";
-  os << "    }\n";
+  os << llvm::formatv(
+      "      return ::llvm::hash_combine({0});\n    }\n",
+      TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
+                              parameters, /* prependComma */ false));
 
   // 3) Emit the construct method.
   if (typeDef.hasStorageCustomConstructor())
@@ -462,14 +477,12 @@ static void emitTypeDefDef(TypeDef typeDef, raw_ostream &os) {
 
   os << llvm::formatv(
       "{0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n"
-      "  return Base::get(ctxt",
+      "  return Base::get(ctxt{2});\n}\n",
       typeDef.getCppClassName(),
       TypeParamCommaFormatter(
-          TypeParamCommaFormatter::EmitFormat::TypeNamePairsPrependComma,
-          parameters));
-  for (TypeParameter &param : parameters)
-    os << ", " << param.getName();
-  os << ");\n}\n";
+          TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
+      TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
+                              parameters));
 
   // Emit the parameter accessors.
   if (typeDef.genAccessors())
@@ -481,6 +494,17 @@ static void emitTypeDefDef(TypeDef typeDef, raw_ostream &os) {
                     typeDef.getCppClassName());
     }
 
+  // Generate getChecked() method.
+  if (typeDef.genVerifyInvariantsDecl()) {
+    os << llvm::formatv(
+        typeDefDefGetCheckeStr,
+        TypeParamCommaFormatter(
+            TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
+        typeDef.getCppClassName(),
+        TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
+                                parameters));
+  }
+
   // If mnemonic is specified maybe print definitions for the parser and printer
   // code, if they're specified.
   if (typeDef.getMnemonic())


        


More information about the Mlir-commits mailing list