[Mlir-commits] [mlir] c8a3f56 - Decouple registering passes from specifying argument/description

Mehdi Amini llvmlistbot at llvm.org
Wed Jun 16 16:42:05 PDT 2021


Author: Mehdi Amini
Date: 2021-06-16T23:41:50Z
New Revision: c8a3f561ebfd6a5fd6c3efb65944760c7a1a446f

URL: https://github.com/llvm/llvm-project/commit/c8a3f561ebfd6a5fd6c3efb65944760c7a1a446f
DIFF: https://github.com/llvm/llvm-project/commit/c8a3f561ebfd6a5fd6c3efb65944760c7a1a446f.diff

LOG: Decouple registering passes from specifying argument/description

This patch changes the (not recommended) static registration API from:

 static PassRegistration<MyPass> reg("my-pass", "My Pass Description.");

to:

 static PassRegistration<MyPass> reg;

And the explicit registration from:

  registerPass("my-pass", "My Pass Description.",
               [] { return createMyPass(); });

To:

  registerPass([] { return createMyPass(); });

Pass implementations are now expected to override the getArgument() method
(and, optionally, the new getDescription() method) instead. This ensures that a
pipeline description can be printed and parsed back.

Differential Revision: https://reviews.llvm.org/D104421

Added: 
    

Modified: 
    mlir/docs/PassManagement.md
    mlir/include/mlir/Pass/Pass.h
    mlir/include/mlir/Pass/PassRegistry.h
    mlir/include/mlir/Transforms/Passes.td
    mlir/lib/Pass/PassRegistry.cpp
    mlir/test/Transforms/print-op-graph.mlir
    mlir/test/python/pass_manager.py
    mlir/tools/mlir-tblgen/PassGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md
index 5772ee4b2d745..16e3167e64700 100644
--- a/mlir/docs/PassManagement.md
+++ b/mlir/docs/PassManagement.md
@@ -86,8 +86,7 @@ struct MyFunctionPass : public PassWrapper<MyFunctionPass,
 /// Register this pass so that it can be built via from a textual pass pipeline.
 /// (Pass registration is discussed more below)
 void registerMyPass() {
-  PassRegistration<MyFunctionPass>(
-    "flag-name-to-invoke-pass-via-mlir-opt", "Pass description here");
+  PassRegistration<MyFunctionPass>();
 }
 ```
 
@@ -503,7 +502,15 @@ struct MyPass ... {
   /// ensure that the options are initialized properly.
   MyPass() = default;
   MyPass(const MyPass& pass) {}
-
+  StringRef getArgument() const final {
+    // This is the argument used to refer to the pass in
+    // the textual format (on the commandline for example).
+    return "argument";
+  }
+  StringRef getDescription() const final {
+    // This is a brief description of the pass.
+    return  "description";
+  }
   /// Define the statistic to track during the execution of MyPass.
   Statistic exampleStat{this, "exampleStat", "An example statistic"};
 
@@ -562,21 +569,22 @@ example registration is shown below:
 
 ```c++
 void registerMyPass() {
-  PassRegistration<MyPass>("argument", "description");
+  PassRegistration<MyPass>();
 }
 ```
 
 *   `MyPass` is the name of the derived pass class.
-*   "argument" is the argument used to refer to the pass in the textual format.
-*   "description" is a brief description of the pass.
+*   The pass `getArgument()` method is used to get the identifier that will be
+    used to refer to the pass.
+*   The pass `getDescription()` method provides a short summary describing the
+    pass.
 
 For passes that cannot be default-constructed, `PassRegistration` accepts an
-optional third argument that takes a callback to create the pass:
+optional argument that takes a callback to create the pass:
 
 ```c++
 void registerMyPass() {
   PassRegistration<MyParametricPass>(
-    "argument", "description",
     []() -> std::unique_ptr<Pass> {
       std::unique_ptr<Pass> p = std::make_unique<MyParametricPass>(/*options*/);
       /*... non-trivial-logic to configure the pass ...*/;
@@ -710,7 +718,7 @@ std::unique_ptr<Pass> foo::createMyPass() {
 
 /// Register this pass.
 void foo::registerMyPass() {
-  PassRegistration<MyPass>("my-pass", "My pass summary");
+  PassRegistration<MyPass>();
 }
 ```
 

diff  --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index 67c695467706a..da91ef303cd1d 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -73,10 +73,14 @@ class Pass {
   /// register the Affine dialect but does not need to register Linalg.
   virtual void getDependentDialects(DialectRegistry &registry) const {}
 
-  /// Returns the command line argument used when registering this pass. Return
+  /// Return the command line argument used when registering this pass. Return
   /// an empty string if one does not exist.
   virtual StringRef getArgument() const { return ""; }
 
+  /// Return the command line description used when registering this pass.
+  /// Return an empty string if one does not exist.
+  virtual StringRef getDescription() const { return ""; }
+
   /// Returns the name of the operation that this pass operates on, or None if
   /// this is a generic OperationPass.
   Optional<StringRef> getOpName() const { return opName; }

diff  --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h
index d03aaf8dfd25f..449889b917812 100644
--- a/mlir/include/mlir/Pass/PassRegistry.h
+++ b/mlir/include/mlir/Pass/PassRegistry.h
@@ -125,20 +125,33 @@ void registerPassPipeline(
 
 /// Register a specific dialect pass allocator function with the system,
 /// typically used through the PassRegistration template.
+/// Deprecated: please use the alternate version below.
 void registerPass(StringRef arg, StringRef description,
                   const PassAllocatorFunction &function);
 
+/// Register a specific dialect pass allocator function with the system,
+/// typically used through the PassRegistration template.
+void registerPass(const PassAllocatorFunction &function);
+
 /// PassRegistration provides a global initializer that registers a Pass
-/// allocation routine for a concrete pass instance. The third argument is
+/// allocation routine for a concrete pass instance. The argument is
 /// optional and provides a callback to construct a pass that does not have
 /// a default constructor.
 ///
 /// Usage:
 ///
 ///   /// At namespace scope.
-///   static PassRegistration<MyPass> reg("my-pass", "My Pass Description.");
+///   static PassRegistration<MyPass> reg;
 ///
 template <typename ConcretePass> struct PassRegistration {
+  PassRegistration(const PassAllocatorFunction &constructor) {
+    registerPass(constructor);
+  }
+  PassRegistration()
+      : PassRegistration([] { return std::make_unique<ConcretePass>(); }) {}
+
+  /// The constructors below are deprecated.
+
   PassRegistration(StringRef arg, StringRef description,
                    const PassAllocatorFunction &constructor) {
     registerPass(arg, description, constructor);

diff  --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 52a822c26b054..ecd60de1104ed 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -622,11 +622,6 @@ def PrintOpStats : Pass<"print-op-stats"> {
   let constructor = "mlir::createPrintOpStatsPass()";
 }
 
-def PrintOp : Pass<"print-op-graph", "ModuleOp"> {
-  let summary = "Print op graph per-Region";
-  let constructor = "mlir::createPrintOpGraphPass()";
-}
-
 def SCCP : Pass<"sccp"> {
   let summary = "Sparse Conditional Constant Propagation";
   let description = [{

diff  --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp
index 2c690a2659ac3..7f002ac0186ec 100644
--- a/mlir/lib/Pass/PassRegistry.cpp
+++ b/mlir/lib/Pass/PassRegistry.cpp
@@ -122,6 +122,15 @@ void mlir::registerPass(StringRef arg, StringRef description,
   }
 }
 
+void mlir::registerPass(const PassAllocatorFunction &function) {
+  std::unique_ptr<Pass> pass = function();
+  StringRef arg = pass->getArgument();
+  if (arg.empty())
+    llvm::report_fatal_error(
+        "Trying to register a pass that does not override `getArgument()`");
+  registerPass(arg, pass->getDescription(), function);
+}
+
 /// Returns the pass info for the specified pass argument or null if unknown.
 const PassInfo *mlir::Pass::lookupPassInfo(StringRef passArg) {
   auto it = passRegistry->find(passArg);

diff  --git a/mlir/test/Transforms/print-op-graph.mlir b/mlir/test/Transforms/print-op-graph.mlir
index 8ab60508b9607..4a5ac380632e1 100644
--- a/mlir/test/Transforms/print-op-graph.mlir
+++ b/mlir/test/Transforms/print-op-graph.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -mlir-elide-elementsattrs-if-larger=2 -print-op-graph %s -o %t 2>&1 | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -mlir-elide-elementsattrs-if-larger=2 -view-op-graph %s -o %t 2>&1 | FileCheck %s
 
 // CHECK-LABEL: digraph "merge_blocks"
 // CHECK{LITERAL}: value: [[...]] : tensor\<2x2xi32\>}

diff  --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index 35e6d980c9f39..4d6432cbece28 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -71,10 +71,10 @@ def testParseFail():
 def testInvalidNesting():
   with Context():
     try:
-      pm = PassManager.parse("func(print-op-graph)")
+      pm = PassManager.parse("func(view-op-graph)")
     except ValueError as e:
       # CHECK: Can't add pass 'ViewOpGraphPass' restricted to 'module' on a PassManager intended to run on 'func', did you intend to nest?
-      # CHECK: ValueError exception: invalid pass pipeline 'func(print-op-graph)'.
+      # CHECK: ValueError exception: invalid pass pipeline 'func(view-op-graph)'.
       log("ValueError exception:", e)
     else:
       log("Exception not produced")

diff  --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp
index a33297e8fdfd7..8f3a19daaa5fb 100644
--- a/mlir/tools/mlir-tblgen/PassGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassGen.cpp
@@ -56,6 +56,8 @@ class {0}Base : public {1} {
   }
   ::llvm::StringRef getArgument() const override { return "{2}"; }
 
+  ::llvm::StringRef getDescription() const override { return "{3}"; }
+
   /// Returns the derived pass name.
   static constexpr ::llvm::StringLiteral getPassName() {
     return ::llvm::StringLiteral("{0}");
@@ -74,7 +76,7 @@ class {0}Base : public {1} {
 
   /// Return the dialect that must be loaded in the context before this pass.
   void getDependentDialects(::mlir::DialectRegistry &registry) const override {
-    {3}
+    {4}
   }
 
 protected:
@@ -122,7 +124,8 @@ static void emitPassDecl(const Pass &pass, raw_ostream &os) {
                                   dependentDialect);
   }
   os << llvm::formatv(passDeclBegin, defName, pass.getBaseClass(),
-                      pass.getArgument(), dependentDialectRegistrations);
+                      pass.getArgument(), pass.getSummary(),
+                      dependentDialectRegistrations);
   emitPassOptionDecls(pass, os);
   emitPassStatisticDecls(pass, os);
   os << "};\n";
@@ -154,8 +157,8 @@ const char *const passRegistrationCode = R"(
 //===----------------------------------------------------------------------===//
 
 inline void register{0}Pass() {{
-  ::mlir::registerPass("{1}", "{2}", []() -> std::unique_ptr<::mlir::Pass> {{
-    return {3};
+  ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
+    return {1};
   });
 }
 )";
@@ -175,7 +178,6 @@ static void emitRegistration(ArrayRef<Pass> passes, raw_ostream &os) {
   os << "#ifdef GEN_PASS_REGISTRATION\n";
   for (const Pass &pass : passes) {
     os << llvm::formatv(passRegistrationCode, pass.getDef()->getName(),
-                        pass.getArgument(), pass.getSummary(),
                         pass.getConstructor());
   }
 


        


More information about the Mlir-commits mailing list