[Mlir-commits] [mlir] 215eba4 - [mlir][CAPI] Include anchor op in mlirParsePassPipeline

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 3 08:48:26 PDT 2022


Author: rkayaith
Date: 2022-11-03T11:48:21-04:00
New Revision: 215eba4e1ea240dd9223c14b80da664f0bb930cc

URL: https://github.com/llvm/llvm-project/commit/215eba4e1ea240dd9223c14b80da664f0bb930cc
DIFF: https://github.com/llvm/llvm-project/commit/215eba4e1ea240dd9223c14b80da664f0bb930cc.diff

LOG: [mlir][CAPI] Include anchor op in mlirParsePassPipeline

The pipeline string must now include the pass manager's anchor op. This
makes the parse API properly roundtrip the printed form of a pass
manager. Since this is already an API break, I also added an extra
callback argument which is used for reporting errors.

The old functionality of appending to an existing pass manager is
available through `mlirOpPassManagerAddPipeline`.

Reviewed By: mehdi_amini, ftynse

Differential Revision: https://reviews.llvm.org/D136403

Added: 
    

Modified: 
    mlir/include/mlir-c/Pass.h
    mlir/lib/CAPI/IR/Pass.cpp
    mlir/test/CAPI/pass.c

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index 704121a0cb096..721f1f28fe916 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -123,10 +123,12 @@ MLIR_CAPI_EXPORTED void mlirPrintPassPipeline(MlirOpPassManager passManager,
                                               MlirStringCallback callback,
                                               void *userData);
 
-/// Parse a textual MLIR pass pipeline and add it to the provided OpPassManager.
-
+/// Parse a textual MLIR pass pipeline and assign it to the provided
+/// OpPassManager. If parsing fails, an error message is reported using the
+/// provided callback.
 MLIR_CAPI_EXPORTED MlirLogicalResult
-mlirParsePassPipeline(MlirOpPassManager passManager, MlirStringRef pipeline);
+mlirParsePassPipeline(MlirOpPassManager passManager, MlirStringRef pipeline,
+                      MlirStringCallback callback, void *userData);
 
 //===----------------------------------------------------------------------===//
 // External Pass API.

diff  --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index 30f5804876940..4afc668592bd8 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -86,10 +86,14 @@ void mlirPrintPassPipeline(MlirOpPassManager passManager,
 }
 
 MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
-                                        MlirStringRef pipeline) {
-  // TODO: errors are sent to std::errs() at the moment, we should pass in a
-  // stream and redirect to a diagnostic.
-  return wrap(mlir::parsePassPipeline(unwrap(pipeline), *unwrap(passManager)));
+                                        MlirStringRef pipeline,
+                                        MlirStringCallback callback,
+                                        void *userData) {
+  detail::CallbackOstream stream(callback, userData);
+  FailureOr<OpPassManager> pm = parsePassPipeline(unwrap(pipeline), stream);
+  if (succeeded(pm))
+    *unwrap(passManager) = std::move(*pm);
+  return wrap(pm);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c
index 5b04d749b1cdc..87430b9e47978 100644
--- a/mlir/test/CAPI/pass.c
+++ b/mlir/test/CAPI/pass.c
@@ -182,7 +182,8 @@ void testParsePassPipeline() {
   MlirLogicalResult status = mlirParsePassPipeline(
       mlirPassManagerGetAsOpPassManager(pm),
       mlirStringRefCreateFromCString(
-          "builtin.module(func.func(print-op-stats{json=false}))"));
+          "builtin.module(func.func(print-op-stats{json=false}))"),
+      printToStderr, NULL);
   // Expect a failure, we haven't registered the print-op-stats pass yet.
   if (mlirLogicalResultIsSuccess(status)) {
     fprintf(
@@ -195,7 +196,8 @@ void testParsePassPipeline() {
   status = mlirParsePassPipeline(
       mlirPassManagerGetAsOpPassManager(pm),
       mlirStringRefCreateFromCString(
-          "builtin.module(func.func(print-op-stats{json=false}))"));
+          "builtin.module(func.func(print-op-stats{json=false}))"),
+      printToStderr, NULL);
   // Expect success this time: a parse failure here is an error and aborts.
   if (mlirLogicalResultIsFailure(status)) {
     fprintf(stderr,
@@ -203,9 +205,7 @@ void testParsePassPipeline() {
     exit(EXIT_FAILURE);
   }
 
-  //      CHECK: Round-trip: builtin.module(
-  // CHECK-SAME:   builtin.module(func.func(print-op-stats{json=false}))
-  // CHECK-SAME: )
+  // CHECK: Round-trip: builtin.module(func.func(print-op-stats{json=false}))
   fprintf(stderr, "Round-trip: ");
   mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
                         NULL);
@@ -221,7 +221,7 @@ void testParsePassPipeline() {
     exit(EXIT_FAILURE);
   }
   //      CHECK: Appended: builtin.module(
-  // CHECK-SAME:   builtin.module(func.func(print-op-stats{json=false})),
+  // CHECK-SAME:   func.func(print-op-stats{json=false}),
   // CHECK-SAME:   func.func(print-op-stats{json=false})
   // CHECK-SAME: )
   fprintf(stderr, "Appended: ");
@@ -242,6 +242,14 @@ void testParseErrorCapture() {
   MlirOpPassManager opm = mlirPassManagerGetAsOpPassManager(pm);
   MlirStringRef invalidPipeline = mlirStringRefCreateFromCString("invalid");
 
+  // CHECK: mlirParsePassPipeline:
+  // CHECK: expected pass pipeline to be wrapped with the anchor operation type
+  fprintf(stderr, "mlirParsePassPipeline:\n");
+  if (mlirLogicalResultIsSuccess(
+          mlirParsePassPipeline(opm, invalidPipeline, printToStderr, NULL)))
+    exit(EXIT_FAILURE);
+  fprintf(stderr, "\n");
+
   // CHECK: mlirOpPassManagerAddPipeline:
   // CHECK: 'invalid' does not refer to a registered pass or pass pipeline
   fprintf(stderr, "mlirOpPassManagerAddPipeline:\n");
@@ -253,6 +261,9 @@ void testParseErrorCapture() {
   // Make sure all output is going through the callback.
   // CHECK: dontPrint: <>
   fprintf(stderr, "dontPrint: <");
+  if (mlirLogicalResultIsSuccess(
+          mlirParsePassPipeline(opm, invalidPipeline, dontPrint, NULL)))
+    exit(EXIT_FAILURE);
   if (mlirLogicalResultIsSuccess(
           mlirOpPassManagerAddPipeline(opm, invalidPipeline, dontPrint, NULL)))
     exit(EXIT_FAILURE);


        


More information about the Mlir-commits mailing list