[Mlir-commits] [mlir] c8a3f56 - Decouple registering passes from specifying argument/description
Mehdi Amini
llvmlistbot at llvm.org
Wed Jun 16 16:42:05 PDT 2021
Author: Mehdi Amini
Date: 2021-06-16T23:41:50Z
New Revision: c8a3f561ebfd6a5fd6c3efb65944760c7a1a446f
URL: https://github.com/llvm/llvm-project/commit/c8a3f561ebfd6a5fd6c3efb65944760c7a1a446f
DIFF: https://github.com/llvm/llvm-project/commit/c8a3f561ebfd6a5fd6c3efb65944760c7a1a446f.diff
LOG: Decouple registering passes from specifying argument/description
This patch changes the (not recommended) static registration API from:
static PassRegistration<MyPass> reg("my-pass", "My Pass Description.");
to:
static PassRegistration<MyPass> reg;
And the explicit registration from:
registerPass("my-pass", "My Pass Description.",
[] { return createMyPass(); });
To:
registerPass([] { return createMyPass(); });
It is expected that Pass implementations override the getArgument() method
(and, for the pass summary, the getDescription() method) instead. This ensures
that pipeline descriptions can be printed and parsed back.
Differential Revision: https://reviews.llvm.org/D104421
Added:
Modified:
mlir/docs/PassManagement.md
mlir/include/mlir/Pass/Pass.h
mlir/include/mlir/Pass/PassRegistry.h
mlir/include/mlir/Transforms/Passes.td
mlir/lib/Pass/PassRegistry.cpp
mlir/test/Transforms/print-op-graph.mlir
mlir/test/python/pass_manager.py
mlir/tools/mlir-tblgen/PassGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md
index 5772ee4b2d745..16e3167e64700 100644
--- a/mlir/docs/PassManagement.md
+++ b/mlir/docs/PassManagement.md
@@ -86,8 +86,7 @@ struct MyFunctionPass : public PassWrapper<MyFunctionPass,
/// Register this pass so that it can be built via from a textual pass pipeline.
/// (Pass registration is discussed more below)
void registerMyPass() {
- PassRegistration<MyFunctionPass>(
- "flag-name-to-invoke-pass-via-mlir-opt", "Pass description here");
+ PassRegistration<MyFunctionPass>();
}
```
@@ -503,7 +502,15 @@ struct MyPass ... {
/// ensure that the options are initialized properly.
MyPass() = default;
MyPass(const MyPass& pass) {}
-
+ StringRef getArgument() const final {
+ // This is the argument used to refer to the pass in
+ // the textual format (on the commandline for example).
+ return "argument";
+ }
+ StringRef getDescription() const final {
+ // This is a brief description of the pass.
+ return "description";
+ }
/// Define the statistic to track during the execution of MyPass.
Statistic exampleStat{this, "exampleStat", "An example statistic"};
@@ -562,21 +569,22 @@ example registration is shown below:
```c++
void registerMyPass() {
- PassRegistration<MyPass>("argument", "description");
+ PassRegistration<MyPass>();
}
```
* `MyPass` is the name of the derived pass class.
-* "argument" is the argument used to refer to the pass in the textual format.
-* "description" is a brief description of the pass.
+* The pass `getArgument()` method is used to get the identifier that will be
+ used to refer to the pass.
+* The pass `getDescription()` method provides a short summary describing the
+ pass.
For passes that cannot be default-constructed, `PassRegistration` accepts an
-optional third argument that takes a callback to create the pass:
+optional argument that takes a callback to create the pass:
```c++
void registerMyPass() {
PassRegistration<MyParametricPass>(
- "argument", "description",
[]() -> std::unique_ptr<Pass> {
std::unique_ptr<Pass> p = std::make_unique<MyParametricPass>(/*options*/);
/*... non-trivial-logic to configure the pass ...*/;
@@ -710,7 +718,7 @@ std::unique_ptr<Pass> foo::createMyPass() {
/// Register this pass.
void foo::registerMyPass() {
- PassRegistration<MyPass>("my-pass", "My pass summary");
+ PassRegistration<MyPass>();
}
```
diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index 67c695467706a..da91ef303cd1d 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -73,10 +73,14 @@ class Pass {
/// register the Affine dialect but does not need to register Linalg.
virtual void getDependentDialects(DialectRegistry &registry) const {}
- /// Returns the command line argument used when registering this pass. Return
+ /// Return the command line argument used when registering this pass. Return
/// an empty string if one does not exist.
virtual StringRef getArgument() const { return ""; }
+ /// Return the command line description used when registering this pass.
+ /// Return an empty string if one does not exist.
+ virtual StringRef getDescription() const { return ""; }
+
/// Returns the name of the operation that this pass operates on, or None if
/// this is a generic OperationPass.
Optional<StringRef> getOpName() const { return opName; }
diff --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h
index d03aaf8dfd25f..449889b917812 100644
--- a/mlir/include/mlir/Pass/PassRegistry.h
+++ b/mlir/include/mlir/Pass/PassRegistry.h
@@ -125,20 +125,33 @@ void registerPassPipeline(
/// Register a specific dialect pass allocator function with the system,
/// typically used through the PassRegistration template.
+/// Deprecated: please use the alternate version below.
void registerPass(StringRef arg, StringRef description,
const PassAllocatorFunction &function);
+/// Register a specific dialect pass allocator function with the system,
+/// typically used through the PassRegistration template.
+void registerPass(const PassAllocatorFunction &function);
+
/// PassRegistration provides a global initializer that registers a Pass
-/// allocation routine for a concrete pass instance. The third argument is
+/// allocation routine for a concrete pass instance. The argument is
/// optional and provides a callback to construct a pass that does not have
/// a default constructor.
///
/// Usage:
///
/// /// At namespace scope.
-/// static PassRegistration<MyPass> reg("my-pass", "My Pass Description.");
+/// static PassRegistration<MyPass> reg;
///
template <typename ConcretePass> struct PassRegistration {
+ PassRegistration(const PassAllocatorFunction &constructor) {
+ registerPass(constructor);
+ }
+ PassRegistration()
+ : PassRegistration([] { return std::make_unique<ConcretePass>(); }) {}
+
+ /// Constructor below are deprecated.
+
PassRegistration(StringRef arg, StringRef description,
const PassAllocatorFunction &constructor) {
registerPass(arg, description, constructor);
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 52a822c26b054..ecd60de1104ed 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -622,11 +622,6 @@ def PrintOpStats : Pass<"print-op-stats"> {
let constructor = "mlir::createPrintOpStatsPass()";
}
-def PrintOp : Pass<"print-op-graph", "ModuleOp"> {
- let summary = "Print op graph per-Region";
- let constructor = "mlir::createPrintOpGraphPass()";
-}
-
def SCCP : Pass<"sccp"> {
let summary = "Sparse Conditional Constant Propagation";
let description = [{
diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp
index 2c690a2659ac3..7f002ac0186ec 100644
--- a/mlir/lib/Pass/PassRegistry.cpp
+++ b/mlir/lib/Pass/PassRegistry.cpp
@@ -122,6 +122,15 @@ void mlir::registerPass(StringRef arg, StringRef description,
}
}
+void mlir::registerPass(const PassAllocatorFunction &function) {
+ std::unique_ptr<Pass> pass = function();
+ StringRef arg = pass->getArgument();
+ if (arg.empty())
+ llvm::report_fatal_error(
+ "Trying to register a pass that does not override `getArgument()`");
+ registerPass(arg, pass->getDescription(), function);
+}
+
/// Returns the pass info for the specified pass argument or null if unknown.
const PassInfo *mlir::Pass::lookupPassInfo(StringRef passArg) {
auto it = passRegistry->find(passArg);
diff --git a/mlir/test/Transforms/print-op-graph.mlir b/mlir/test/Transforms/print-op-graph.mlir
index 8ab60508b9607..4a5ac380632e1 100644
--- a/mlir/test/Transforms/print-op-graph.mlir
+++ b/mlir/test/Transforms/print-op-graph.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -mlir-elide-elementsattrs-if-larger=2 -print-op-graph %s -o %t 2>&1 | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -mlir-elide-elementsattrs-if-larger=2 -view-op-graph %s -o %t 2>&1 | FileCheck %s
// CHECK-LABEL: digraph "merge_blocks"
// CHECK{LITERAL}: value: [[...]] : tensor\<2x2xi32\>}
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index 35e6d980c9f39..4d6432cbece28 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -71,10 +71,10 @@ def testParseFail():
def testInvalidNesting():
with Context():
try:
- pm = PassManager.parse("func(print-op-graph)")
+ pm = PassManager.parse("func(view-op-graph)")
except ValueError as e:
# CHECK: Can't add pass 'ViewOpGraphPass' restricted to 'module' on a PassManager intended to run on 'func', did you intend to nest?
- # CHECK: ValueError exception: invalid pass pipeline 'func(print-op-graph)'.
+ # CHECK: ValueError exception: invalid pass pipeline 'func(view-op-graph)'.
log("ValueError exception:", e)
else:
log("Exception not produced")
diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp
index a33297e8fdfd7..8f3a19daaa5fb 100644
--- a/mlir/tools/mlir-tblgen/PassGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassGen.cpp
@@ -56,6 +56,8 @@ class {0}Base : public {1} {
}
::llvm::StringRef getArgument() const override { return "{2}"; }
+ ::llvm::StringRef getDescription() const override { return "{3}"; }
+
/// Returns the derived pass name.
static constexpr ::llvm::StringLiteral getPassName() {
return ::llvm::StringLiteral("{0}");
@@ -74,7 +76,7 @@ class {0}Base : public {1} {
/// Return the dialect that must be loaded in the context before this pass.
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
- {3}
+ {4}
}
protected:
@@ -122,7 +124,8 @@ static void emitPassDecl(const Pass &pass, raw_ostream &os) {
dependentDialect);
}
os << llvm::formatv(passDeclBegin, defName, pass.getBaseClass(),
- pass.getArgument(), dependentDialectRegistrations);
+ pass.getArgument(), pass.getSummary(),
+ dependentDialectRegistrations);
emitPassOptionDecls(pass, os);
emitPassStatisticDecls(pass, os);
os << "};\n";
@@ -154,8 +157,8 @@ const char *const passRegistrationCode = R"(
//===----------------------------------------------------------------------===//
inline void register{0}Pass() {{
- ::mlir::registerPass("{1}", "{2}", []() -> std::unique_ptr<::mlir::Pass> {{
- return {3};
+ ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
+ return {1};
});
}
)";
@@ -175,7 +178,6 @@ static void emitRegistration(ArrayRef<Pass> passes, raw_ostream &os) {
os << "#ifdef GEN_PASS_REGISTRATION\n";
for (const Pass &pass : passes) {
os << llvm::formatv(passRegistrationCode, pass.getDef()->getName(),
- pass.getArgument(), pass.getSummary(),
pass.getConstructor());
}
More information about the Mlir-commits
mailing list