[Mlir-commits] [mlir] 32c5578 - [MLIR] Split autogenerated pass declarations & C++ controllable pass options

Michele Scuttari llvmlistbot at llvm.org
Wed Aug 24 01:01:42 PDT 2022


Author: Michele Scuttari
Date: 2022-08-24T10:01:08+02:00
New Revision: 32c5578bcddf92a94947b390b4d2862bbb624622

URL: https://github.com/llvm/llvm-project/commit/32c5578bcddf92a94947b390b4d2862bbb624622
DIFF: https://github.com/llvm/llvm-project/commit/32c5578bcddf92a94947b390b4d2862bbb624622.diff

LOG: [MLIR] Split autogenerated pass declarations & C++ controllable pass options

The pass tablegen backend has been reworked to remove the monolithic nature of the autogenerated declarations.
The public header of the passes can be generated with the -gen-pass-decls option. It contains the options structs and the registration functions. The options struct of each pass (together with the default constructor declarations, if no constructor is specified in tablegen) can be included individually by defining the GEN_PASS_DECL_PASSNAME macro. The registration declarations have been kept together and can still be included by defining the GEN_PASS_REGISTRATION macro.
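
Here is a minimal sketch of how the per-pass macros interact, assuming nothing beyond the standard library. The guarded blocks imitate the layout that -gen-pass-decls produces for a hypothetical group `Example` with two passes, `MyPass` and `OtherPass`. `OtherPass` is assumed to have a custom `constructor`. The blocks are pasted inline instead of being included from a `.h.inc` file. `Pass`, `registerPass`, and `passRegistry` are simplified hypothetical stand-ins, not the real `::mlir` API.

```c++
#include <cassert>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

// Simplified stand-ins for ::mlir::Pass and ::mlir::registerPass (hypothetical).
struct Pass {
  virtual ~Pass() = default;
};
using PassFactory = std::function<std::unique_ptr<Pass>()>;
inline std::vector<PassFactory> &passRegistry() {
  static std::vector<PassFactory> factories;
  return factories;
}
inline void registerPass(PassFactory factory) {
  passRegistry().push_back(std::move(factory));
}

// OtherPass is assumed to have a custom `constructor` in tablegen, so its
// factory is declared by hand, as the unit test does for its custom pass.
std::unique_ptr<Pass> createOtherPass();

// Opt in to MyPass's declarations and to the registration hooks only.
#define GEN_PASS_DECL_MYPASS
#define GEN_PASS_REGISTRATION

// ---- Imitation of a generated "Passes.h.inc" (pasted inline) ----
#ifdef GEN_PASS_DECL_MYPASS
struct MyPassOptions {
  bool option = true;
};
std::unique_ptr<Pass> createMyPass();
std::unique_ptr<Pass> createMyPass(const MyPassOptions &options);
#undef GEN_PASS_DECL_MYPASS
#endif // GEN_PASS_DECL_MYPASS

#ifdef GEN_PASS_DECL_OTHERPASS
// Skipped: GEN_PASS_DECL_OTHERPASS is not defined in this file.
struct OtherPassOptions {
  int level = 0;
};
#undef GEN_PASS_DECL_OTHERPASS
#endif // GEN_PASS_DECL_OTHERPASS

#ifdef GEN_PASS_REGISTRATION
inline void registerMyPass() {
  registerPass([]() -> std::unique_ptr<Pass> { return createMyPass(); });
}
inline void registerOtherPass() {
  registerPass([]() -> std::unique_ptr<Pass> { return createOtherPass(); });
}
// All registrations stay together, independent of the DECL macros.
inline void registerExamplePasses() {
  registerMyPass();
  registerOtherPass();
}
#undef GEN_PASS_REGISTRATION
#endif // GEN_PASS_REGISTRATION
// ---- End of imitation ----

// Hand-written implementations so that the sketch links.
struct MyPass : Pass {
  bool option = true;
};
struct OtherPass : Pass {};

std::unique_ptr<Pass> createMyPass() { return std::make_unique<MyPass>(); }
std::unique_ptr<Pass> createMyPass(const MyPassOptions &options) {
  auto pass = std::make_unique<MyPass>();
  pass->option = options.option;
  return pass;
}
std::unique_ptr<Pass> createOtherPass() { return std::make_unique<OtherPass>(); }
```

Only `MyPassOptions` is visible in this translation unit, but `registerExamplePasses()` still registers both passes. The registrations call the `create` functions, so in this sketch (as in the unit test, which defines every GEN_PASS_DECL_* macro first) those functions must already be declared when the GEN_PASS_REGISTRATION block is expanded.
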
The private code used for the pass implementation can be generated with the -gen-pass-defs option. This code is the pass base class and, if no constructor is specified in tablegen, the constructor definitions. As with the declarations file, the definitions for each pass can be enabled individually by defining the GEN_PASS_DEF_PASSNAME macro.
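
The private base class and the fallback factory can be sketched in the same way. The sketch assumes that `constructor` is empty. The declarations file then declares `createMyPass()` at namespace scope, and the definitions file defines it as a friend inside `MyPassBase`. This arrangement lets the generated body name `DerivedT` without knowing the hand-written class. `Pass` and its `getName()` are again simplified hypothetical stand-ins.

```c++
#include <cassert>
#include <memory>
#include <string>

// Simplified stand-in for ::mlir::Pass (hypothetical).
struct Pass {
  virtual ~Pass() = default;
  virtual std::string getName() const = 0;
};

// What the declarations file would declare when `constructor` is empty.
std::unique_ptr<Pass> createMyPass();

// Imitation of the base class emitted under GEN_PASS_DEF_MYPASS.
template <typename DerivedT>
class MyPassBase : public Pass {
public:
  using Base = MyPassBase;
  std::string getName() const override { return "MyPass"; }

private:
  // Defined inside the template so that it can name DerivedT; the
  // namespace-scope declaration above makes it visible to ordinary calls.
  friend std::unique_ptr<Pass> createMyPass() {
    return std::make_unique<DerivedT>();
  }
};

// The hand-written pass: deriving from MyPassBase<MyPass> instantiates the
// base, which provides the definition of createMyPass().
struct MyPass : MyPassBase<MyPass> {};
```

One consequence of this pattern is that `MyPassBase` should be instantiated with a single derived class. A second instantiation would define `createMyPass()` again, which is ill-formed.
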
In addition, the pass base class now provides a constructor that accepts the aforementioned options struct and copies its fields into the actual pass options. Each pass can therefore be configured from C++ and not only through the command line.
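
The options constructor added to the base class can be sketched as follows. The option names are taken from the unit test's `TestPassWithOptions`, and the sketch makes two substitutions to stay free of LLVM dependencies. Plain members replace `Pass::Option`/`Pass::ListOption`, and `std::vector` replaces `::llvm::ArrayRef`. The derived class inherits the constructor, just as the unit test does with `using TestPassWithOptionsBase::TestPassWithOptionsBase;`.

```c++
#include <cassert>
#include <cstdint>
#include <memory>
#include <vector>

// Simplified stand-in for ::mlir::Pass (hypothetical).
struct Pass {
  virtual ~Pass() = default;
};

// Imitation of the options struct emitted under GEN_PASS_DECL_MYPASS.
// The generator uses ::llvm::ArrayRef<T> for list options; std::vector is
// used here only to avoid an LLVM dependency.
struct MyPassOptions {
  unsigned testOption = 0;
  std::vector<int64_t> testListOption;
};

std::unique_ptr<Pass> createMyPass();
std::unique_ptr<Pass> createMyPass(const MyPassOptions &options);

// Imitation of the base class emitted under GEN_PASS_DEF_MYPASS.
template <typename DerivedT>
class MyPassBase : public Pass {
public:
  MyPassBase() = default;

  // Copy every field of the options struct into the matching pass option.
  MyPassBase(const MyPassOptions &options) : MyPassBase() {
    testOption = options.testOption;
    testListOption = options.testListOption;
  }

protected:
  // Plain members standing in for Pass::Option / Pass::ListOption.
  unsigned testOption = 0;
  std::vector<int64_t> testListOption;

private:
  friend std::unique_ptr<Pass> createMyPass() {
    return std::make_unique<DerivedT>();
  }
  friend std::unique_ptr<Pass> createMyPass(const MyPassOptions &options) {
    return std::make_unique<DerivedT>(options);
  }
};

struct MyPass : MyPassBase<MyPass> {
  // Inherit the options constructor so make_unique<MyPass>(options) works.
  using MyPassBase::MyPassBase;

  unsigned getTestOption() const { return testOption; }
  const std::vector<int64_t> &getTestListOption() const {
    return testListOption;
  }
};
```

The base constructor first delegates to the default constructor and then assigns each field. A field left untouched in the options struct therefore keeps the struct's default, which the generator takes from the option's default value in tablegen.
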

Reviewed By: rriddle, mehdi_amini, Mogball, jpienaar

Differential Revision: https://reviews.llvm.org/D131839

Added: 
    

Modified: 
    mlir/docs/PassManagement.md
    mlir/include/mlir/Pass/PassBase.td
    mlir/lib/TableGen/Pass.cpp
    mlir/tools/mlir-tblgen/PassCAPIGen.cpp
    mlir/tools/mlir-tblgen/PassGen.cpp
    mlir/unittests/TableGen/CMakeLists.txt
    mlir/unittests/TableGen/PassGenTest.cpp
    mlir/unittests/TableGen/passes.td

Removed: 
    


################################################################################
diff  --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md
index 1c5e99fa724d2..a09ca328c4d02 100644
--- a/mlir/docs/PassManagement.md
+++ b/mlir/docs/PassManagement.md
@@ -828,16 +828,18 @@ def MyPass : Pass<"my-pass", "ModuleOp"> {
 }
 ```
 
-Using the `gen-pass-decls` generator, we can generate most of the boilerplate
-above automatically. This generator takes as an input a `-name` parameter, that
+Using the `gen-pass-decls` and `gen-pass-defs` generators, we can generate most
+of the boilerplate above automatically.
+
+The `gen-pass-decls` generator takes as an input a `-name` parameter, that
 provides a tag for the group of passes that are being generated. This generator
-produces two chunks of output:
+produces code with two purposes:
 
-The first is a code block for registering the declarative passes with the global
-registry. For each pass, the generator produces a `registerFooPass` where `Foo`
-is the name of the definition specified in tablegen. It also generates a
-`registerGroupPasses`, where `Group` is the tag provided via the `-name` input
-parameter, that registers all of the passes present.
+The first is to register the declared passes with the global registry. For
+each pass, the generator produces a `registerPassName` where
+`PassName` is the name of the definition specified in tablegen. It also
+generates a `registerGroupPasses`, where `Group` is the tag provided via the
+`-name` input parameter, that registers all of the passes present.
 
 ```c++
 // gen-pass-decls -name="Example"
@@ -850,19 +852,61 @@ void registerMyPasses() {
   registerExamplePasses();
 
   // Register `MyPass` specifically.
-  registerMyPassPass();
+  registerMyPass();
 }
 ```
 
-The second is a base class for each of the passes, containing most of the boiler
+The second is to provide a way to configure the pass options. These classes are
+named in the form of `MyPassOptions`, where `MyPass` is the name of the pass
+definition in tablegen. The configurable parameters reflect the options
+declared in the tablegen file. Differently from the registration hooks, these
+classes can be enabled on a per-pass basis by defining the
+`GEN_PASS_DECL_PASSNAME` macro, where `PASSNAME` is the uppercase version of
+the name specified in tablegen.
+
+```c++
+// .h.inc
+
+#ifdef GEN_PASS_DECL_MYPASS
+
+struct MyPassOptions {
+    bool option = true;
+    ::llvm::ArrayRef<int64_t> listOption;
+};
+
+#undef GEN_PASS_DECL_MYPASS
+#endif // GEN_PASS_DECL_MYPASS
+```
+
+If the `constructor` field has not been specified in the tablegen declaration,
+then autogenerated file will also contain the declarations of the default
+constructors.
+
+```c++
+// .h.inc
+
+#ifdef GEN_PASS_DECL_MYPASS
+...
+
+std::unique_ptr<::mlir::Pass> createMyPass();
+std::unique_ptr<::mlir::Pass> createMyPass(const MyPassOptions &options);
+
+#undef GEN_PASS_DECL_MYPASS
+#endif // GEN_PASS_DECL_MYPASS
+```
+
+The `gen-pass-defs` generator produces the definitions to be used for the pass
+implementation.
+
+It generates a base class for each of the passes, containing most of the boiler
 plate related to pass definitions. These classes are named in the form of
 `MyPassBase`, where `MyPass` is the name of the pass definition in tablegen. We
 can update the original C++ pass definition as so:
 
 ```c++
 /// Include the generated base pass class definitions.
-#define GEN_PASS_CLASSES
-#include "Passes.h.inc"
+#define GEN_PASS_DEF_MYPASS
+#include "Passes.cpp.inc"
 
 /// Define the main class as deriving from the generated base class.
 struct MyPass : MyPassBase<MyPass> {
@@ -874,13 +918,16 @@ struct MyPass : MyPassBase<MyPass> {
   /// The definitions of the options and statistics are now generated within
   /// the base class, but are accessible in the same way.
 };
-
-/// Expose this pass to the outside world.
-std::unique_ptr<Pass> foo::createMyPass() {
-  return std::make_unique<MyPass>();
-}
 ```
 
+Similarly to the previous generator, the definitions can be enabled on a
+per-pass basis by defining the appropriate preprocessor `GEN_PASS_DEF_PASSNAME`
+macro, with `PASSNAME` equal to the uppercase version of the name of the pass
+definition in tablegen.
+If the `constructor` field has not been specified in tablegen, then the default
+constructors are also defined and expect the name of the actual pass class to
+be equal to the name defined in tablegen.
+
 Using the `gen-pass-doc` generator, markdown documentation for each of the
 passes can be generated. See [Passes.md](Passes.md) for example output of real
 MLIR passes.

diff  --git a/mlir/include/mlir/Pass/PassBase.td b/mlir/include/mlir/Pass/PassBase.td
index 9b1903acc3019..e37f9735e2241 100644
--- a/mlir/include/mlir/Pass/PassBase.td
+++ b/mlir/include/mlir/Pass/PassBase.td
@@ -76,7 +76,10 @@ class PassBase<string passArg, string base> {
   string description = "";
 
   // A C++ constructor call to create an instance of this pass.
-  code constructor = [{}];
+  // If empty, the default constructor declarations and definitions
+  // 'createPassName()' and 'createPassName(const PassNameOptions &options)'
+  // will be generated and the former will be used for the pass instantiation.
+  code constructor = "";
 
   // A list of dialects this pass may produce entities in.
   list<string> dependentDialects = [];

diff  --git a/mlir/lib/TableGen/Pass.cpp b/mlir/lib/TableGen/Pass.cpp
index 84b3f01d1255c..e9c65e8fbd149 100644
--- a/mlir/lib/TableGen/Pass.cpp
+++ b/mlir/lib/TableGen/Pass.cpp
@@ -90,6 +90,7 @@ StringRef Pass::getDescription() const {
 StringRef Pass::getConstructor() const {
   return def->getValueAsString("constructor");
 }
+
 ArrayRef<StringRef> Pass::getDependentDialects() const {
   return dependentDialects;
 }

diff  --git a/mlir/tools/mlir-tblgen/PassCAPIGen.cpp b/mlir/tools/mlir-tblgen/PassCAPIGen.cpp
index 4fa1150957c5d..34368635c171f 100644
--- a/mlir/tools/mlir-tblgen/PassCAPIGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassCAPIGen.cpp
@@ -97,8 +97,15 @@ static bool emitCAPIImpl(const llvm::RecordKeeper &records, raw_ostream &os) {
   for (const auto *def : records.getAllDerivedDefinitions("PassBase")) {
     Pass pass(def);
     StringRef defName = pass.getDef()->getName();
-    os << llvm::formatv(passCreateDef, groupName, defName,
-                        pass.getConstructor());
+
+    std::string constructorCall;
+    if (StringRef constructor = pass.getConstructor(); !constructor.empty())
+      constructorCall = constructor.str();
+    else
+      constructorCall =
+          llvm::formatv("create{0}Pass()", pass.getDef()->getName()).str();
+
+    os << llvm::formatv(passCreateDef, groupName, defName, constructorCall);
   }
   return false;
 }

diff  --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp
index 1c734aba192a5..f4a9d5b9993bc 100644
--- a/mlir/tools/mlir-tblgen/PassGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassGen.cpp
@@ -27,6 +27,161 @@ static llvm::cl::opt<std::string>
     groupName("name", llvm::cl::desc("The name of this group of passes"),
               llvm::cl::cat(passGenCat));
 
+static void emitOldPassDecl(const Pass &pass, raw_ostream &os);
+
+/// Extract the list of passes from the TableGen records.
+static std::vector<Pass> getPasses(const llvm::RecordKeeper &recordKeeper) {
+  std::vector<Pass> passes;
+
+  for (const auto *def : recordKeeper.getAllDerivedDefinitions("PassBase"))
+    passes.emplace_back(def);
+
+  return passes;
+}
+
+const char *const passHeader = R"(
+//===----------------------------------------------------------------------===//
+// {0}
+//===----------------------------------------------------------------------===//
+)";
+
+//===----------------------------------------------------------------------===//
+// GEN: Pass registration generation
+//===----------------------------------------------------------------------===//
+
+/// The code snippet used to generate a pass registration.
+///
+/// {0}: The def name of the pass record.
+/// {1}: The pass constructor call.
+const char *const passRegistrationCode = R"(
+//===----------------------------------------------------------------------===//
+// {0} Registration
+//===----------------------------------------------------------------------===//
+
+inline void register{0}() {{
+  ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
+    return {1};
+  });
+}
+
+// Old registration code, kept for temporary backwards compatibility.
+inline void register{0}Pass() {{
+  ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
+    return {1};
+  });
+}
+)";
+
+/// The code snippet used to generate a function to register all passes in a
+/// group.
+///
+/// {0}: The name of the pass group.
+const char *const passGroupRegistrationCode = R"(
+//===----------------------------------------------------------------------===//
+// {0} Registration
+//===----------------------------------------------------------------------===//
+
+inline void register{0}Passes() {{
+)";
+
+/// Emits the definition of the struct to be used to control the pass options.
+static void emitPassOptionsStruct(const Pass &pass, raw_ostream &os) {
+  StringRef passName = pass.getDef()->getName();
+  ArrayRef<PassOption> options = pass.getOptions();
+
+  // Emit the struct only if the pass has at least one option.
+  if (options.empty())
+    return;
+
+  os << llvm::formatv("struct {0}Options {{\n", passName);
+
+  for (const PassOption &opt : options) {
+    std::string type = opt.getType().str();
+
+    if (opt.isListOption())
+      type = "::llvm::ArrayRef<" + type + ">";
+
+    os.indent(2) << llvm::formatv("{0} {1}", type, opt.getCppVariableName());
+
+    if (Optional<StringRef> defaultVal = opt.getDefaultValue())
+      os << " = " << defaultVal;
+
+    os << ";\n";
+  }
+
+  os << "};\n";
+}
+
+/// Emit the code to be included in the public header of the pass.
+static void emitPassDecls(const Pass &pass, raw_ostream &os) {
+  StringRef passName = pass.getDef()->getName();
+  std::string enableVarName = "GEN_PASS_DECL_" + passName.upper();
+
+  os << "#ifdef " << enableVarName << "\n";
+  os << llvm::formatv(passHeader, passName);
+
+  emitPassOptionsStruct(pass, os);
+
+  if (StringRef constructor = pass.getConstructor(); constructor.empty()) {
+    // Default constructor declaration.
+    os << "std::unique_ptr<::mlir::Pass> create" << passName << "();\n";
+
+    // Declaration of the constructor with options.
+    if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty())
+      os << llvm::formatv("std::unique_ptr<::mlir::Pass> create{0}(const "
+                          "{0}Options &options);\n",
+                          passName);
+  }
+
+  os << "#undef " << enableVarName << "\n";
+  os << "#endif // " << enableVarName << "\n";
+}
+
+/// Emit the code for registering each of the given passes with the global
+/// PassRegistry.
+static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) {
+  os << "#ifdef GEN_PASS_REGISTRATION\n";
+
+  for (const Pass &pass : passes) {
+    std::string constructorCall;
+    if (StringRef constructor = pass.getConstructor(); !constructor.empty())
+      constructorCall = constructor.str();
+    else
+      constructorCall =
+          llvm::formatv("create{0}()", pass.getDef()->getName()).str();
+
+    os << llvm::formatv(passRegistrationCode, pass.getDef()->getName(),
+                        constructorCall);
+  }
+
+  os << llvm::formatv(passGroupRegistrationCode, groupName);
+
+  for (const Pass &pass : passes)
+    os << "  register" << pass.getDef()->getName() << "();\n";
+
+  os << "}\n";
+  os << "#undef GEN_PASS_REGISTRATION\n";
+  os << "#endif // GEN_PASS_REGISTRATION\n";
+}
+
+static void emitDecls(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) {
+  std::vector<Pass> passes = getPasses(recordKeeper);
+  os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
+
+  for (const Pass &pass : passes)
+    emitPassDecls(pass, os);
+
+  emitRegistrations(passes, os);
+
+  // TODO drop old pass declarations
+  // Emit the old code until all the passes have switched to the new design.
+  os << "#ifdef GEN_PASS_CLASSES\n";
+  for (const Pass &pass : passes)
+    emitOldPassDecl(pass, os);
+  os << "#undef GEN_PASS_CLASSES\n";
+  os << "#endif // GEN_PASS_CLASSES\n";
+}
+
 //===----------------------------------------------------------------------===//
 // GEN: Pass base class generation
 //===----------------------------------------------------------------------===//
@@ -38,10 +193,6 @@ static llvm::cl::opt<std::string>
 /// {2): The command line argument for the pass.
 /// {3}: The dependent dialects registration.
 const char *const passDeclBegin = R"(
-//===----------------------------------------------------------------------===//
-// {0}
-//===----------------------------------------------------------------------===//
-
 template <typename DerivedT>
 class {0}Base : public {1} {
 public:
@@ -84,7 +235,6 @@ class {0}Base : public {1} {
   /// library.
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)
 
-protected:
 )";
 
 /// Registration for a single dependent dialect, to be inserted for each
@@ -93,6 +243,18 @@ const char *const dialectRegistrationTemplate = R"(
   registry.insert<{0}>();
 )";
 
+const char *const friendDefaultConstructorTemplate = R"(
+  friend std::unique_ptr<::mlir::Pass> create{0}() {{
+    return std::make_unique<DerivedT>();
+  }
+)";
+
+const char *const friendDefaultConstructorWithOptionsTemplate = R"(
+  friend std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options) {{
+    return std::make_unique<DerivedT>(options);
+  }
+)";
+
 /// Emit the declarations for each of the pass options.
 static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
   for (const PassOption &opt : pass.getOptions()) {
@@ -119,8 +281,14 @@ static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
   }
 }
 
-static void emitPassDecl(const Pass &pass, raw_ostream &os) {
-  StringRef defName = pass.getDef()->getName();
+/// Emit the code to be used in the implementation of the pass.
+static void emitPassDefs(const Pass &pass, raw_ostream &os) {
+  StringRef passName = pass.getDef()->getName();
+  std::string enableVarName = "GEN_PASS_DEF_" + passName.upper();
+
+  os << "#ifdef " << enableVarName << "\n";
+  os << llvm::formatv(passHeader, passName);
+
   std::string dependentDialectRegistrations;
   {
     llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
@@ -128,90 +296,129 @@ static void emitPassDecl(const Pass &pass, raw_ostream &os) {
       dialectsOs << llvm::formatv(dialectRegistrationTemplate,
                                   dependentDialect);
   }
-  os << llvm::formatv(passDeclBegin, defName, pass.getBaseClass(),
+
+  os << llvm::formatv(passDeclBegin, passName, pass.getBaseClass(),
                       pass.getArgument(), pass.getSummary(),
                       dependentDialectRegistrations);
+
+  if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty()) {
+    os.indent(2) << llvm::formatv(
+        "{0}Base(const {0}Options &options) : {0}Base() {{\n", passName);
+
+    for (const PassOption &opt : pass.getOptions())
+      os.indent(4) << llvm::formatv("{0} = options.{0};\n",
+                                    opt.getCppVariableName());
+
+    os.indent(2) << "}\n";
+  }
+
+  // Protected content
+  os << "protected:\n";
   emitPassOptionDecls(pass, os);
   emitPassStatisticDecls(pass, os);
+
+  // Private content
+  os << "private:\n";
+
+  if (pass.getConstructor().empty()) {
+    os << llvm::formatv(friendDefaultConstructorTemplate, passName);
+
+    if (!pass.getOptions().empty())
+      os << llvm::formatv(friendDefaultConstructorWithOptionsTemplate,
+                          passName);
+  }
+
   os << "};\n";
+
+  os << "#undef " << enableVarName << "\n";
+  os << "#endif // " << enableVarName << "\n";
 }
 
-/// Emit the code for registering each of the given passes with the global
-/// PassRegistry.
-static void emitPassDecls(ArrayRef<Pass> passes, raw_ostream &os) {
-  os << "#ifdef GEN_PASS_CLASSES\n";
+static void emitDefs(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) {
+  std::vector<Pass> passes = getPasses(recordKeeper);
+  os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
+
   for (const Pass &pass : passes)
-    emitPassDecl(pass, os);
-  os << "#undef GEN_PASS_CLASSES\n";
-  os << "#endif // GEN_PASS_CLASSES\n";
+    emitPassDefs(pass, os);
 }
 
-//===----------------------------------------------------------------------===//
-// GEN: Pass registration generation
-//===----------------------------------------------------------------------===//
+// TODO drop old pass declarations
+// The old pass base class is being kept until all the passes have switched to
+// the new decls/defs design.
+const char *const oldPassDeclBegin = R"(
+template <typename DerivedT>
+class {0}Base : public {1} {
+public:
+  using Base = {0}Base;
 
-/// The code snippet used to generate a pass registration.
-///
-/// {0}: The def name of the pass record.
-/// {1}: The pass constructor call.
-const char *const passRegistrationCode = R"(
-//===----------------------------------------------------------------------===//
-// {0} Registration
-//===----------------------------------------------------------------------===//
+  {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
+  {0}Base(const {0}Base &other) : {1}(other) {{}
 
-inline void register{0}Pass() {{
-  ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
-    return {1};
-  });
-}
-)";
+  /// Returns the command-line argument attached to this pass.
+  static constexpr ::llvm::StringLiteral getArgumentName() {
+    return ::llvm::StringLiteral("{2}");
+  }
+  ::llvm::StringRef getArgument() const override { return "{2}"; }
 
-/// The code snippet used to generate a function to register all passes in a
-/// group.
-///
-/// {0}: The name of the pass group.
-const char *const passGroupRegistrationCode = R"(
-//===----------------------------------------------------------------------===//
-// {0} Registration
-//===----------------------------------------------------------------------===//
+  ::llvm::StringRef getDescription() const override { return "{3}"; }
 
-inline void register{0}Passes() {{
-)";
+  /// Returns the derived pass name.
+  static constexpr ::llvm::StringLiteral getPassName() {
+    return ::llvm::StringLiteral("{0}");
+  }
+  ::llvm::StringRef getName() const override { return "{0}"; }
 
-/// Emit the code for registering each of the given passes with the global
-/// PassRegistry.
-static void emitRegistration(ArrayRef<Pass> passes, raw_ostream &os) {
-  os << "#ifdef GEN_PASS_REGISTRATION\n";
-  for (const Pass &pass : passes) {
-    os << llvm::formatv(passRegistrationCode, pass.getDef()->getName(),
-                        pass.getConstructor());
+  /// Support isa/dyn_cast functionality for the derived pass class.
+  static bool classof(const ::mlir::Pass *pass) {{
+    return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
   }
 
-  os << llvm::formatv(passGroupRegistrationCode, groupName);
-  for (const Pass &pass : passes)
-    os << "  register" << pass.getDef()->getName() << "Pass();\n";
-  os << "}\n";
-  os << "#undef GEN_PASS_REGISTRATION\n";
-  os << "#endif // GEN_PASS_REGISTRATION\n";
-}
+  /// A clone method to create a copy of this pass.
+  std::unique_ptr<::mlir::Pass> clonePass() const override {{
+    return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
+  }
 
-//===----------------------------------------------------------------------===//
-// GEN: Registration hooks
-//===----------------------------------------------------------------------===//
+  /// Return the dialect that must be loaded in the context before this pass.
+  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
+    {4}
+  }
 
-static void emitDecls(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) {
-  os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
-  std::vector<Pass> passes;
-  for (const auto *def : recordKeeper.getAllDerivedDefinitions("PassBase"))
-    passes.emplace_back(def);
+  /// Explicitly declare the TypeID for this class. We declare an explicit private
+  /// instantiation because Pass classes should only be visible by the current
+  /// library.
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)
 
-  emitPassDecls(passes, os);
-  emitRegistration(passes, os);
+protected:
+)";
+
+/// Emit a backward-compatible declaration of the pass base class.
+static void emitOldPassDecl(const Pass &pass, raw_ostream &os) {
+  StringRef defName = pass.getDef()->getName();
+  std::string dependentDialectRegistrations;
+  {
+    llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
+    for (StringRef dependentDialect : pass.getDependentDialects())
+      dialectsOs << llvm::formatv(dialectRegistrationTemplate,
+                                  dependentDialect);
+  }
+  os << llvm::formatv(oldPassDeclBegin, defName, pass.getBaseClass(),
+                      pass.getArgument(), pass.getSummary(),
+                      dependentDialectRegistrations);
+  emitPassOptionDecls(pass, os);
+  emitPassStatisticDecls(pass, os);
+  os << "};\n";
 }
 
 static mlir::GenRegistration
-    genRegister("gen-pass-decls", "Generate pass declarations",
+    genPassDecls("gen-pass-decls", "Generate pass declarations",
+                 [](const llvm::RecordKeeper &records, raw_ostream &os) {
+                   emitDecls(records, os);
+                   return false;
+                 });
+
+static mlir::GenRegistration
+    genPassDefs("gen-pass-defs", "Generate pass definitions",
                 [](const llvm::RecordKeeper &records, raw_ostream &os) {
-                  emitDecls(records, os);
+                  emitDefs(records, os);
                   return false;
                 });

diff  --git a/mlir/unittests/TableGen/CMakeLists.txt b/mlir/unittests/TableGen/CMakeLists.txt
index c51bda6e8d6cc..436c903e3d973 100644
--- a/mlir/unittests/TableGen/CMakeLists.txt
+++ b/mlir/unittests/TableGen/CMakeLists.txt
@@ -5,6 +5,7 @@ add_public_tablegen_target(MLIRTableGenEnumsIncGen)
 
 set(LLVM_TARGET_DEFINITIONS passes.td)
 mlir_tablegen(PassGenTest.h.inc -gen-pass-decls -name TableGenTest)
+mlir_tablegen(PassGenTest.cpp.inc -gen-pass-defs -name TableGenTest)
 add_public_tablegen_target(MLIRTableGenTestPassIncGen)
 
 add_mlir_unittest(MLIRTableGenTests

diff  --git a/mlir/unittests/TableGen/PassGenTest.cpp b/mlir/unittests/TableGen/PassGenTest.cpp
index 33bd1606051f4..0de960099ae2a 100644
--- a/mlir/unittests/TableGen/PassGenTest.cpp
+++ b/mlir/unittests/TableGen/PassGenTest.cpp
@@ -7,31 +7,36 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Pass/Pass.h"
+#include "llvm/ADT/STLExtras.h"
 
 #include "gmock/gmock.h"
 
-std::unique_ptr<mlir::Pass> createTestPass(int v = 0);
+std::unique_ptr<mlir::Pass> createTestPassWithCustomConstructor(int v = 0);
 
+#define GEN_PASS_DECL_TESTPASS
+#define GEN_PASS_DECL_TESTPASSWITHOPTIONS
+#define GEN_PASS_DECL_TESTPASSWITHCUSTOMCONSTRUCTOR
 #define GEN_PASS_REGISTRATION
 #include "PassGenTest.h.inc"
 
-#define GEN_PASS_CLASSES
-#include "PassGenTest.h.inc"
+#define GEN_PASS_DEF_TESTPASS
+#define GEN_PASS_DEF_TESTPASSWITHOPTIONS
+#define GEN_PASS_DEF_TESTPASSWITHCUSTOMCONSTRUCTOR
+#include "PassGenTest.cpp.inc"
 
 struct TestPass : public TestPassBase<TestPass> {
-  explicit TestPass(int v) : extraVal(v) {}
+  using TestPassBase::TestPassBase;
 
   void runOnOperation() override {}
 
   std::unique_ptr<mlir::Pass> clone() const {
     return TestPassBase<TestPass>::clone();
   }
-
-  int extraVal;
 };
 
-std::unique_ptr<mlir::Pass> createTestPass(int v) {
-  return std::make_unique<TestPass>(v);
+TEST(PassGenTest, defaultGeneratedConstructor) {
+  std::unique_ptr<mlir::Pass> pass = createTestPass();
+  EXPECT_TRUE(pass.get() != nullptr);
 }
 
 TEST(PassGenTest, PassClone) {
@@ -41,7 +46,74 @@ TEST(PassGenTest, PassClone) {
     return static_cast<const TestPass *>(pass.get());
   };
 
-  const auto origPass = createTestPass(10);
+  const auto origPass = createTestPass();
+  const auto clonePass = unwrap(origPass)->clone();
+
+  EXPECT_TRUE(clonePass.get() != nullptr);
+  EXPECT_TRUE(origPass.get() != clonePass.get());
+}
+
+struct TestPassWithOptions
+    : public TestPassWithOptionsBase<TestPassWithOptions> {
+  using TestPassWithOptionsBase::TestPassWithOptionsBase;
+
+  void runOnOperation() override {}
+
+  std::unique_ptr<mlir::Pass> clone() const {
+    return TestPassWithOptionsBase<TestPassWithOptions>::clone();
+  }
+
+  unsigned getTestOption() const { return testOption; }
+
+  llvm::ArrayRef<int64_t> getTestListOption() const { return testListOption; }
+};
+
+TEST(PassGenTest, PassOptions) {
+  mlir::MLIRContext context;
+
+  TestPassWithOptionsOptions options;
+  options.testOption = 57;
+
+  llvm::SmallVector<int64_t, 2> testListOption = {1, 2};
+  options.testListOption = testListOption;
+
+  const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) {
+    return static_cast<const TestPassWithOptions *>(pass.get());
+  };
+
+  const auto pass = createTestPassWithOptions(options);
+
+  EXPECT_EQ(unwrap(pass)->getTestOption(), 57);
+  EXPECT_EQ(unwrap(pass)->getTestListOption()[0], 1);
+  EXPECT_EQ(unwrap(pass)->getTestListOption()[1], 2);
+}
+
+struct TestPassWithCustomConstructor
+    : public TestPassWithCustomConstructorBase<TestPassWithCustomConstructor> {
+  explicit TestPassWithCustomConstructor(int v) : extraVal(v) {}
+
+  void runOnOperation() override {}
+
+  std::unique_ptr<mlir::Pass> clone() const {
+    return TestPassWithCustomConstructorBase<
+        TestPassWithCustomConstructor>::clone();
+  }
+
+  unsigned int extraVal = 23;
+};
+
+std::unique_ptr<mlir::Pass> createTestPassWithCustomConstructor(int v) {
+  return std::make_unique<TestPassWithCustomConstructor>(v);
+}
+
+TEST(PassGenTest, PassCloneWithCustomConstructor) {
+  mlir::MLIRContext context;
+
+  const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) {
+    return static_cast<const TestPassWithCustomConstructor *>(pass.get());
+  };
+
+  const auto origPass = createTestPassWithCustomConstructor(10);
   const auto clonePass = unwrap(origPass)->clone();
 
   EXPECT_EQ(unwrap(origPass)->extraVal, unwrap(clonePass)->extraVal);

diff  --git a/mlir/unittests/TableGen/passes.td b/mlir/unittests/TableGen/passes.td
index f730390ebc8eb..c1ef136a1b7e7 100644
--- a/mlir/unittests/TableGen/passes.td
+++ b/mlir/unittests/TableGen/passes.td
@@ -12,8 +12,20 @@ include "mlir/Rewrite/PassUtil.td"
 
 def TestPass : Pass<"test"> {
   let summary = "Test pass";
+}
+
+def TestPassWithOptions : Pass<"test"> {
+  let summary = "Test pass with options";
+
+  let options = [
+    Option<"testOption", "testOption", "unsigned", "0", "Test option">,
+    ListOption<"testListOption", "test-list-option", "int64_t",
+               "Test list option">
+  ];
+}
 
-  let constructor = "::createTestPass()";
+def TestPassWithCustomConstructor : Pass<"test"> {
+  let summary = "Test pass with custom constructor";
 
-  let options = RewritePassUtils.options;
+  let constructor = "::createTestPassWithCustomConstructor()";
 }
