[Mlir-commits] [mlir] aeb4b1a - Add facilities to print/parse a pass pipeline through the C API

Mehdi Amini llvmlistbot at llvm.org
Wed Nov 4 09:30:15 PST 2020


Author: Mehdi Amini
Date: 2020-11-04T17:29:49Z
New Revision: aeb4b1a9d8c9f9c4a4530cd3f2394b52c2187d51

URL: https://github.com/llvm/llvm-project/commit/aeb4b1a9d8c9f9c4a4530cd3f2394b52c2187d51
DIFF: https://github.com/llvm/llvm-project/commit/aeb4b1a9d8c9f9c4a4530cd3f2394b52c2187d51.diff

LOG: Add facilities to print/parse a pass pipeline through the C API

This also includes and exercises a registration function for individual
passes.

Differential Revision: https://reviews.llvm.org/D90728

Added: 
    

Modified: 
    mlir/include/mlir-c/Pass.h
    mlir/lib/CAPI/IR/Pass.cpp
    mlir/test/CAPI/pass.c
    mlir/tools/mlir-tblgen/PassCAPIGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index 4e5a4abe87bc..d95e7ccb74e3 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -53,6 +53,10 @@ MlirPassManager mlirPassManagerCreate(MlirContext ctx);
 /** Destroy the provided PassManager. */
 void mlirPassManagerDestroy(MlirPassManager passManager);
 
+/** Cast a top-level PassManager to a generic OpPassManager. */
+MlirOpPassManager
+mlirPassManagerGetAsOpPassManager(MlirPassManager passManager);
+
 /** Run the provided `passManager` on the given `module`. */
 MlirLogicalResult mlirPassManagerRun(MlirPassManager passManager,
                                      MlirModule module);
@@ -83,6 +87,17 @@ void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass);
 void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager,
                                    MlirPass pass);
 
+/** Print a textual MLIR pass pipeline by sending chunks of the string
+ * representation and forwarding `userData` to `callback`. Note that the callback
+ * may be called several times with consecutive chunks of the string. */
+void mlirPrintPassPipeline(MlirOpPassManager passManager,
+                           MlirStringCallback callback, void *userData);
+
+/** Parse a textual MLIR pass pipeline and add it to the provided OpPassManager.
+ */
+MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
+                                        MlirStringRef pipeline);
+
 #ifdef __cplusplus
 }
 #endif

diff  --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index 7d35875d28a6..c67cd474abbe 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -28,6 +28,11 @@ void mlirPassManagerDestroy(MlirPassManager passManager) {
   delete unwrap(passManager);
 }
 
+MlirOpPassManager
+mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) {
+  return wrap(static_cast<OpPassManager *>(unwrap(passManager)));
+}
+
 MlirLogicalResult mlirPassManagerRun(MlirPassManager passManager,
                                      MlirModule module) {
   return wrap(unwrap(passManager)->run(unwrap(module)));
@@ -51,3 +56,16 @@ void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager,
                                    MlirPass pass) {
   unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
 }
+
+void mlirPrintPassPipeline(MlirOpPassManager passManager,
+                           MlirStringCallback callback, void *userData) {
+  detail::CallbackOstream stream(callback, userData);
+  unwrap(passManager)->printAsTextualPipeline(stream);
+}
+
+MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
+                                        MlirStringRef pipeline) {
+  // TODO: errors are sent to llvm::errs() at the moment, we should pass in a
+  // stream and redirect to a diagnostic.
+  return wrap(mlir::parsePassPipeline(unwrap(pipeline), *unwrap(passManager)));
+}

diff  --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c
index ff67eda68cc9..686c7ade44f6 100644
--- a/mlir/test/CAPI/pass.c
+++ b/mlir/test/CAPI/pass.c
@@ -33,8 +33,10 @@ void testRunPassOnModule() {
 "  return %res : i32                                                        \n"
 "}");
   // clang-format on
-  if (mlirModuleIsNull(module))
+  if (mlirModuleIsNull(module)) {
+    fprintf(stderr, "Unexpected failure parsing module.\n");
     exit(EXIT_FAILURE);
+  }
 
   // Run the print-op-stats pass on the top-level module:
   // CHECK-LABEL: Operations encountered:
@@ -47,8 +49,10 @@ void testRunPassOnModule() {
     MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
     mlirPassManagerAddOwnedPass(pm, printOpStatPass);
     MlirLogicalResult success = mlirPassManagerRun(pm, module);
-    if (mlirLogicalResultIsFailure(success))
+    if (mlirLogicalResultIsFailure(success)) {
+      fprintf(stderr, "Unexpected failure running pass manager.\n");
       exit(EXIT_FAILURE);
+    }
     mlirPassManagerDestroy(pm);
   }
   mlirModuleDestroy(module);
@@ -117,8 +121,81 @@ void testRunPassOnNestedModule() {
   mlirContextDestroy(ctx);
 }
 
+static void printToStderr(const char *str, intptr_t len, void *userData) {
+  (void)userData;
+  fwrite(str, 1, len, stderr);
+}
+
+void testPrintPassPipeline() {
+  MlirContext ctx = mlirContextCreate();
+  MlirPassManager pm = mlirPassManagerCreate(ctx);
+  // Populate the pass-manager
+  MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder(
+      pm, mlirStringRefCreateFromCString("module"));
+  MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder(
+      nestedModulePm, mlirStringRefCreateFromCString("func"));
+  MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
+  mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
+
+  // Print the top level pass manager
+  // CHECK: Top-level: module(func(print-op-stats))
+  fprintf(stderr, "Top-level: ");
+  mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
+                        NULL);
+  fprintf(stderr, "\n");
+
+  // Print the pipeline nested one level down
+  // CHECK: Nested Module: func(print-op-stats)
+  fprintf(stderr, "Nested Module: ");
+  mlirPrintPassPipeline(nestedModulePm, printToStderr, NULL);
+  fprintf(stderr, "\n");
+
+  // Print the pipeline nested two levels down
+  // CHECK: Nested Module>Func: print-op-stats
+  fprintf(stderr, "Nested Module>Func: ");
+  mlirPrintPassPipeline(nestedFuncPm, printToStderr, NULL);
+  fprintf(stderr, "\n");
+
+  mlirPassManagerDestroy(pm);
+  mlirContextDestroy(ctx);
+}
+
+void testParsePassPipeline() {
+  MlirContext ctx = mlirContextCreate();
+  MlirPassManager pm = mlirPassManagerCreate(ctx);
+  // Try parsing a pipeline.
+  MlirLogicalResult status = mlirParsePassPipeline(
+      mlirPassManagerGetAsOpPassManager(pm),
+      mlirStringRefCreateFromCString(
+          "module(func(print-op-stats), func(print-op-stats))"));
+  // Expect a failure, we haven't registered the print-op-stats pass yet.
+  if (mlirLogicalResultIsSuccess(status)) {
+    fprintf(stderr, "Unexpected success parsing pipeline without registering the pass\n");
+    exit(EXIT_FAILURE);
+  }
+  // Try again after registering the pass.
+  mlirRegisterTransformsPrintOpStats();
+  status = mlirParsePassPipeline(
+      mlirPassManagerGetAsOpPassManager(pm),
+      mlirStringRefCreateFromCString(
+          "module(func(print-op-stats), func(print-op-stats))"));
+  // Expect success now that the print-op-stats pass has been registered.
+  if (mlirLogicalResultIsFailure(status)) {
+    fprintf(stderr, "Unexpected failure parsing pipeline after registering the pass\n");
+    exit(EXIT_FAILURE);
+  }
+
+  // CHECK: Round-trip: module(func(print-op-stats), func(print-op-stats))
+  fprintf(stderr, "Round-trip: ");
+  mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
+                        NULL);
+  fprintf(stderr, "\n");
+}
+
 int main() {
   testRunPassOnModule();
   testRunPassOnNestedModule();
+  testPrintPassPipeline();
+  testParsePassPipeline();
   return 0;
 }

diff  --git a/mlir/tools/mlir-tblgen/PassCAPIGen.cpp b/mlir/tools/mlir-tblgen/PassCAPIGen.cpp
index 42507c5da7e3..d15bcb4632d2 100644
--- a/mlir/tools/mlir-tblgen/PassCAPIGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassCAPIGen.cpp
@@ -33,6 +33,7 @@ static llvm::cl::opt<std::string>
 const char *const passDecl = R"(
 /* Create {0} Pass. */
 MlirPass mlirCreate{0}{1}();
+void mlirRegister{0}{1}();
 
 )";
 
@@ -70,6 +71,9 @@ const char *const passCreateDef = R"(
 MlirPass mlirCreate{0}{1}() {
   return wrap({2}.release());
 }
+void mlirRegister{0}{1}() {
+  register{1}Pass();
+}
 
 )";
 


        


More information about the Mlir-commits mailing list