[llvm-branch-commits] [mlir] 41d919a - [mlir][TypeDefGen] Remove the need to define parser/printer for singleton types

River Riddle via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Jan 6 15:05:16 PST 2021


Author: River Riddle
Date: 2021-01-06T15:00:14-08:00
New Revision: 41d919aa29468ac072755b8449b8a38ff26f6979

URL: https://github.com/llvm/llvm-project/commit/41d919aa29468ac072755b8449b8a38ff26f6979
DIFF: https://github.com/llvm/llvm-project/commit/41d919aa29468ac072755b8449b8a38ff26f6979.diff

LOG: [mlir][TypeDefGen] Remove the need to define parser/printer for singleton types

This allows singleton types (types with no parameters) that do not define an
explicit parser/printer to simply use their mnemonic as the assembly format,
removing the need for these types to provide the `parser`/`printer` fields.

Differential Revision: https://reviews.llvm.org/D94194

Added: 
    

Modified: 
    mlir/test/lib/Dialect/Test/TestTypeDefs.td
    mlir/tools/mlir-tblgen/TypeDefGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 75fffa11cb21..80927dff62c2 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -21,9 +21,6 @@ class Test_Type<string name> : TypeDef<Test_Dialect, name> { }
 
 def SimpleTypeA : Test_Type<"SimpleA"> {
   let mnemonic = "smpla";
-
-  let printer = [{ $_printer << "smpla"; }];
-  let parser = [{ return get($_ctxt); }];
 }
 
 // A more complex parameterized type.

diff  --git a/mlir/tools/mlir-tblgen/TypeDefGen.cpp b/mlir/tools/mlir-tblgen/TypeDefGen.cpp
index 8fdb5f4feeaf..20168168bc8d 100644
--- a/mlir/tools/mlir-tblgen/TypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/TypeDefGen.cpp
@@ -537,12 +537,21 @@ static void emitParsePrintDispatch(ArrayRef<TypeDef> types, raw_ostream &os) {
   os << "static ::mlir::Type generatedTypeParser(::mlir::MLIRContext* "
         "ctxt, "
         "::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic) {\n";
-  for (const TypeDef &type : types)
-    if (type.getMnemonic())
+  for (const TypeDef &type : types) {
+    if (type.getMnemonic()) {
       os << formatv("  if (mnemonic == {0}::{1}::getMnemonic()) return "
-                    "{0}::{1}::parse(ctxt, parser);\n",
+                    "{0}::{1}::",
                     type.getDialect().getCppNamespace(),
                     type.getCppClassName());
+
+      // If the type has no parameters and no parser code, just invoke a normal
+      // `get`.
+      if (type.getNumParameters() == 0 && !type.getParserCode())
+        os << "get(ctxt);\n";
+      else
+        os << "parse(ctxt, parser);\n";
+    }
+  }
   os << "  return ::mlir::Type();\n";
   os << "}\n\n";
 
@@ -551,17 +560,26 @@ static void emitParsePrintDispatch(ArrayRef<TypeDef> types, raw_ostream &os) {
   os << "static ::mlir::LogicalResult generatedTypePrinter(::mlir::Type "
         "type, "
         "::mlir::DialectAsmPrinter& printer) {\n"
-     << "  ::mlir::LogicalResult found = ::mlir::success();\n"
-     << "  ::llvm::TypeSwitch<::mlir::Type>(type)\n";
-  for (const TypeDef &type : types)
-    if (type.getMnemonic())
-      os << formatv("    .Case<{0}::{1}>([&](::mlir::Type t) {{ "
-                    "t.dyn_cast<{0}::{1}>().print(printer); })\n",
-                    type.getDialect().getCppNamespace(),
-                    type.getCppClassName());
-  os << "    .Default([&found](::mlir::Type) { found = ::mlir::failure(); "
-        "});\n"
-     << "  return found;\n"
+     << "  return ::llvm::TypeSwitch<::mlir::Type, "
+        "::mlir::LogicalResult>(type)\n";
+  for (const TypeDef &type : types) {
+    if (Optional<StringRef> mnemonic = type.getMnemonic()) {
+      StringRef cppNamespace = type.getDialect().getCppNamespace();
+      StringRef cppClassName = type.getCppClassName();
+      os << formatv("    .Case<{0}::{1}>([&]({0}::{1} t) {{\n      ",
+                    cppNamespace, cppClassName);
+
+      // If the type has no parameters and no printer code, just print the
+      // mnemonic.
+      if (type.getNumParameters() == 0 && !type.getPrinterCode())
+        os << formatv("printer << {0}::{1}::getMnemonic();", cppNamespace,
+                      cppClassName);
+      else
+        os << "t.print(printer);";
+      os << "\n      return ::mlir::success();\n    })\n";
+    }
+  }
+  os << "    .Default([](::mlir::Type) { return ::mlir::failure(); });\n"
      << "}\n\n";
 }
 


        


More information about the llvm-branch-commits mailing list