[Mlir-commits] [mlir] aeb4b1a - Add facilities to print/parse a pass pipeline through the C API
Mehdi Amini
llvmlistbot at llvm.org
Wed Nov 4 09:30:15 PST 2020
Author: Mehdi Amini
Date: 2020-11-04T17:29:49Z
New Revision: aeb4b1a9d8c9f9c4a4530cd3f2394b52c2187d51
URL: https://github.com/llvm/llvm-project/commit/aeb4b1a9d8c9f9c4a4530cd3f2394b52c2187d51
DIFF: https://github.com/llvm/llvm-project/commit/aeb4b1a9d8c9f9c4a4530cd3f2394b52c2187d51.diff
LOG: Add facilities to print/parse a pass pipeline through the C API
This also includes and exercises a registration function for individual
passes.
Differential Revision: https://reviews.llvm.org/D90728
Added:
Modified:
mlir/include/mlir-c/Pass.h
mlir/lib/CAPI/IR/Pass.cpp
mlir/test/CAPI/pass.c
mlir/tools/mlir-tblgen/PassCAPIGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index 4e5a4abe87bc..d95e7ccb74e3 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -53,6 +53,10 @@ MlirPassManager mlirPassManagerCreate(MlirContext ctx);
/** Destroy the provided PassManager. */
void mlirPassManagerDestroy(MlirPassManager passManager);
+/** Cast a top-level PassManager to a generic OpPassManager. */
+MlirOpPassManager
+mlirPassManagerGetAsOpPassManager(MlirPassManager passManager);
+
/** Run the provided `passManager` on the given `module`. */
MlirLogicalResult mlirPassManagerRun(MlirPassManager passManager,
MlirModule module);
@@ -83,6 +87,17 @@ void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass);
void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager,
MlirPass pass);
+/** Print a textual MLIR pass pipeline by sending chunks of the string
+ * representation to `callback`, forwarding `userData` with each call. The
+ * callback may be invoked several times with consecutive chunks. */
+void mlirPrintPassPipeline(MlirOpPassManager passManager,
+ MlirStringCallback callback, void *userData);
+
+/** Parse a textual MLIR pass pipeline and add it to the provided
+ * OpPassManager. Returns failure on error, which is printed to stderr. */
+MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
+ MlirStringRef pipeline);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index 7d35875d28a6..c67cd474abbe 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -28,6 +28,11 @@ void mlirPassManagerDestroy(MlirPassManager passManager) {
delete unwrap(passManager);
}
+MlirOpPassManager
+mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) {
+ return wrap(static_cast<OpPassManager *>(unwrap(passManager)));
+}
+
MlirLogicalResult mlirPassManagerRun(MlirPassManager passManager,
MlirModule module) {
return wrap(unwrap(passManager)->run(unwrap(module)));
@@ -51,3 +56,16 @@ void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager,
MlirPass pass) {
unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
}
+
+void mlirPrintPassPipeline(MlirOpPassManager passManager,
+ MlirStringCallback callback, void *userData) {
+ detail::CallbackOstream stream(callback, userData);
+ unwrap(passManager)->printAsTextualPipeline(stream);
+}
+
+MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
+ MlirStringRef pipeline) {
+ // TODO: errors are sent to llvm::errs() at the moment, we should pass in a
+ // stream and redirect to a diagnostic.
+ return wrap(mlir::parsePassPipeline(unwrap(pipeline), *unwrap(passManager)));
+}
diff --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c
index ff67eda68cc9..686c7ade44f6 100644
--- a/mlir/test/CAPI/pass.c
+++ b/mlir/test/CAPI/pass.c
@@ -33,8 +33,10 @@ void testRunPassOnModule() {
" return %res : i32 \n"
"}");
// clang-format on
- if (mlirModuleIsNull(module))
+ if (mlirModuleIsNull(module)) {
+ fprintf(stderr, "Unexpected failure parsing module.\n");
exit(EXIT_FAILURE);
+ }
// Run the print-op-stats pass on the top-level module:
// CHECK-LABEL: Operations encountered:
@@ -47,8 +49,10 @@ void testRunPassOnModule() {
MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
mlirPassManagerAddOwnedPass(pm, printOpStatPass);
MlirLogicalResult success = mlirPassManagerRun(pm, module);
- if (mlirLogicalResultIsFailure(success))
+ if (mlirLogicalResultIsFailure(success)) {
+ fprintf(stderr, "Unexpected failure running pass manager.\n");
exit(EXIT_FAILURE);
+ }
mlirPassManagerDestroy(pm);
}
mlirModuleDestroy(module);
@@ -117,8 +121,84 @@ void testRunPassOnNestedModule() {
mlirContextDestroy(ctx);
}
+static void printToStderr(const char *str, intptr_t len, void *userData) {
+ (void)userData;
+ fwrite(str, 1, len, stderr);
+}
+
+void testPrintPassPipeline() {
+ MlirContext ctx = mlirContextCreate();
+ MlirPassManager pm = mlirPassManagerCreate(ctx);
+ // Populate the pass-manager
+ MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder(
+ pm, mlirStringRefCreateFromCString("module"));
+ MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder(
+ nestedModulePm, mlirStringRefCreateFromCString("func"));
+ MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
+ mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
+
+ // Print the top level pass manager
+ // CHECK: Top-level: module(func(print-op-stats))
+ fprintf(stderr, "Top-level: ");
+ mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
+ NULL);
+ fprintf(stderr, "\n");
+
+ // Print the pipeline nested one level down
+ // CHECK: Nested Module: func(print-op-stats)
+ fprintf(stderr, "Nested Module: ");
+ mlirPrintPassPipeline(nestedModulePm, printToStderr, NULL);
+ fprintf(stderr, "\n");
+
+ // Print the pipeline nested two levels down
+ // CHECK: Nested Module>Func: print-op-stats
+ fprintf(stderr, "Nested Module>Func: ");
+ mlirPrintPassPipeline(nestedFuncPm, printToStderr, NULL);
+ fprintf(stderr, "\n");
+
+ mlirPassManagerDestroy(pm);
+ mlirContextDestroy(ctx);
+}
+
+void testParsePassPipeline() {
+ MlirContext ctx = mlirContextCreate();
+ MlirPassManager pm = mlirPassManagerCreate(ctx);
+ // Try to parse a pipeline.
+ MlirLogicalResult status = mlirParsePassPipeline(
+ mlirPassManagerGetAsOpPassManager(pm),
+ mlirStringRefCreateFromCString(
+ "module(func(print-op-stats), func(print-op-stats))"));
+ // Expect a failure: we haven't registered the print-op-stats pass yet.
+ if (mlirLogicalResultIsSuccess(status)) {
+ fprintf(stderr, "Unexpected success parsing pipeline without registering the pass\n");
+ exit(EXIT_FAILURE);
+ }
+ // Try again after registering the pass.
+ mlirRegisterTransformsPrintOpStats();
+ status = mlirParsePassPipeline(
+ mlirPassManagerGetAsOpPassManager(pm),
+ mlirStringRefCreateFromCString(
+ "module(func(print-op-stats), func(print-op-stats))"));
+ // Expect success now that the print-op-stats pass is registered.
+ if (mlirLogicalResultIsFailure(status)) {
+ fprintf(stderr, "Unexpected failure parsing pipeline after registering the pass\n");
+ exit(EXIT_FAILURE);
+ }
+
+ // CHECK: Round-trip: module(func(print-op-stats), func(print-op-stats))
+ fprintf(stderr, "Round-trip: ");
+ mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
+ NULL);
+ fprintf(stderr, "\n");
+
+ mlirPassManagerDestroy(pm);
+ mlirContextDestroy(ctx);
+}
+
int main() {
testRunPassOnModule();
testRunPassOnNestedModule();
+ testPrintPassPipeline();
+ testParsePassPipeline();
return 0;
}
diff --git a/mlir/tools/mlir-tblgen/PassCAPIGen.cpp b/mlir/tools/mlir-tblgen/PassCAPIGen.cpp
index 42507c5da7e3..d15bcb4632d2 100644
--- a/mlir/tools/mlir-tblgen/PassCAPIGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassCAPIGen.cpp
@@ -33,6 +33,8 @@ static llvm::cl::opt<std::string>
const char *const passDecl = R"(
/* Create {0} Pass. */
MlirPass mlirCreate{0}{1}();
+/* Register {1} Pass for use in textual pass pipelines. */
+void mlirRegister{0}{1}(void);
)";
@@ -70,6 +72,9 @@ const char *const passCreateDef = R"(
MlirPass mlirCreate{0}{1}() {
return wrap({2}.release());
}
+void mlirRegister{0}{1}(void) {
+ register{1}Pass();
+}
)";
More information about the Mlir-commits
mailing list