[Mlir-commits] [mlir] 6f5590c - [mlir][CAPI] Allow running pass manager on any operation

Rahul Kayaith llvmlistbot at llvm.org
Wed Mar 1 15:17:22 PST 2023


Author: rkayaith
Date: 2023-03-01T18:17:13-05:00
New Revision: 6f5590ca347a5a2467b8aaea4b24bc9b70ef138f

URL: https://github.com/llvm/llvm-project/commit/6f5590ca347a5a2467b8aaea4b24bc9b70ef138f
DIFF: https://github.com/llvm/llvm-project/commit/6f5590ca347a5a2467b8aaea4b24bc9b70ef138f.diff

LOG: [mlir][CAPI] Allow running pass manager on any operation

`mlirPassManagerRun` is currently restricted to running on
`builtin.module` ops, but this restriction doesn't exist on the C++
side. This patch renames it to `mlirPassManagerRunOnOp` and updates it to
take an `MlirOperation` instead of an `MlirModule`.

Depends on D143352

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D143354

Added: 
    

Modified: 
    mlir/include/mlir-c/Pass.h
    mlir/lib/Bindings/Python/Pass.cpp
    mlir/lib/CAPI/IR/Pass.cpp
    mlir/test/CAPI/execution_engine.c
    mlir/test/CAPI/pass.c

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index 721f1f28fe916..35db138305d1e 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -70,9 +70,9 @@ static inline bool mlirPassManagerIsNull(MlirPassManager passManager) {
 MLIR_CAPI_EXPORTED MlirOpPassManager
 mlirPassManagerGetAsOpPassManager(MlirPassManager passManager);
 
-/// Run the provided `passManager` on the given `module`.
+/// Run the provided `passManager` on the given `op`.
 MLIR_CAPI_EXPORTED MlirLogicalResult
-mlirPassManagerRun(MlirPassManager passManager, MlirModule module);
+mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op);
 
 /// Enable mlir-print-ir-after-all.
 MLIR_CAPI_EXPORTED void

diff  --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index cb3c1586eb996..99f17a18bc83c 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -117,8 +117,8 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
       .def(
           "run",
           [](PyPassManager &passManager, PyModule &module) {
-            MlirLogicalResult status =
-                mlirPassManagerRun(passManager.get(), module.get());
+            MlirLogicalResult status = mlirPassManagerRunOnOp(
+                passManager.get(), mlirModuleGetOperation(module.get()));
             if (mlirLogicalResultIsFailure(status))
               throw SetPyError(PyExc_RuntimeError,
                                "Failure while executing pass pipeline.");

diff  --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index b921154111c22..d242baae99c08 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -39,9 +39,9 @@ mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) {
   return wrap(static_cast<OpPassManager *>(unwrap(passManager)));
 }
 
-MlirLogicalResult mlirPassManagerRun(MlirPassManager passManager,
-                                     MlirModule module) {
-  return wrap(unwrap(passManager)->run(unwrap(module)));
+MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager,
+                                         MlirOperation op) {
+  return wrap(unwrap(passManager)->run(unwrap(op)));
 }
 
 void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) {

diff  --git a/mlir/test/CAPI/execution_engine.c b/mlir/test/CAPI/execution_engine.c
index 582b2f8fe2223..2b4d448a65ef6 100644
--- a/mlir/test/CAPI/execution_engine.c
+++ b/mlir/test/CAPI/execution_engine.c
@@ -37,7 +37,8 @@ void lowerModuleToLLVM(MlirContext ctx, MlirModule module) {
   mlirPassManagerAddOwnedPass(pm, mlirCreateConversionConvertFuncToLLVMPass());
   mlirOpPassManagerAddOwnedPass(
       opm, mlirCreateConversionArithToLLVMConversionPass());
-  MlirLogicalResult status = mlirPassManagerRun(pm, module);
+  MlirLogicalResult status =
+      mlirPassManagerRunOnOp(pm, mlirModuleGetOperation(module));
   if (mlirLogicalResultIsFailure(status)) {
     fprintf(stderr, "Unexpected failure running pass pipeline\n");
     exit(2);

diff  --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c
index 42b5ce7080cac..3aad0016b393c 100644
--- a/mlir/test/CAPI/pass.c
+++ b/mlir/test/CAPI/pass.c
@@ -33,17 +33,16 @@ void testRunPassOnModule(void) {
   MlirContext ctx = mlirContextCreate();
   registerAllUpstreamDialects(ctx);
 
-  MlirModule module = mlirModuleCreateParse(
-      ctx,
-      // clang-format off
-                            mlirStringRefCreateFromCString(
-"func.func @foo(%arg0 : i32) -> i32 {                                   \n"
-"  %res = arith.addi %arg0, %arg0 : i32                                     \n"
-"  return %res : i32                                                        \n"
-"}"));
-  // clang-format on
-  if (mlirModuleIsNull(module)) {
-    fprintf(stderr, "Unexpected failure parsing module.\n");
+  const char *funcAsm = //
+      "func.func @foo(%arg0 : i32) -> i32 {   \n"
+      "  %res = arith.addi %arg0, %arg0 : i32 \n"
+      "  return %res : i32                    \n"
+      "}                                      \n";
+  MlirOperation func =
+      mlirOperationCreateParse(ctx, mlirStringRefCreateFromCString(funcAsm),
+                               mlirStringRefCreateFromCString("funcAsm"));
+  if (mlirOperationIsNull(func)) {
+    fprintf(stderr, "Unexpected failure parsing asm.\n");
     exit(EXIT_FAILURE);
   }
 
@@ -56,14 +55,14 @@ void testRunPassOnModule(void) {
     MlirPassManager pm = mlirPassManagerCreate(ctx);
     MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
     mlirPassManagerAddOwnedPass(pm, printOpStatPass);
-    MlirLogicalResult success = mlirPassManagerRun(pm, module);
+    MlirLogicalResult success = mlirPassManagerRunOnOp(pm, func);
     if (mlirLogicalResultIsFailure(success)) {
       fprintf(stderr, "Unexpected failure running pass manager.\n");
       exit(EXIT_FAILURE);
     }
     mlirPassManagerDestroy(pm);
   }
-  mlirModuleDestroy(module);
+  mlirOperationDestroy(func);
   mlirContextDestroy(ctx);
 }
 
@@ -71,22 +70,23 @@ void testRunPassOnNestedModule(void) {
   MlirContext ctx = mlirContextCreate();
   registerAllUpstreamDialects(ctx);
 
-  MlirModule module = mlirModuleCreateParse(
-      ctx,
-      // clang-format off
-                            mlirStringRefCreateFromCString(
-"func.func @foo(%arg0 : i32) -> i32 {                                   \n"
-"  %res = arith.addi %arg0, %arg0 : i32                                     \n"
-"  return %res : i32                                                        \n"
-"}                                                                          \n"
-"module {                                                                   \n"
-"  func.func @bar(%arg0 : f32) -> f32 {                                     \n"
-"    %res = arith.addf %arg0, %arg0 : f32                                   \n"
-"    return %res : f32                                                      \n"
-"  }                                                                        \n"
-"}"));
-  // clang-format on
-  if (mlirModuleIsNull(module))
+  const char *moduleAsm = //
+      "module {                                   \n"
+      "  func.func @foo(%arg0 : i32) -> i32 {     \n"
+      "    %res = arith.addi %arg0, %arg0 : i32   \n"
+      "    return %res : i32                      \n"
+      "  }                                        \n"
+      "  module {                                 \n"
+      "    func.func @bar(%arg0 : f32) -> f32 {   \n"
+      "      %res = arith.addf %arg0, %arg0 : f32 \n"
+      "      return %res : f32                    \n"
+      "    }                                      \n"
+      "  }                                        \n"
+      "}                                          \n";
+  MlirOperation module =
+      mlirOperationCreateParse(ctx, mlirStringRefCreateFromCString(moduleAsm),
+                               mlirStringRefCreateFromCString("moduleAsm"));
+  if (mlirOperationIsNull(module))
     exit(1);
 
   // Run the print-op-stats pass on functions under the top-level module:
@@ -100,7 +100,7 @@ void testRunPassOnNestedModule(void) {
         pm, mlirStringRefCreateFromCString("func.func"));
     MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
     mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
-    MlirLogicalResult success = mlirPassManagerRun(pm, module);
+    MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
     if (mlirLogicalResultIsFailure(success))
       exit(2);
     mlirPassManagerDestroy(pm);
@@ -118,13 +118,13 @@ void testRunPassOnNestedModule(void) {
         nestedModulePm, mlirStringRefCreateFromCString("func.func"));
     MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
     mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
-    MlirLogicalResult success = mlirPassManagerRun(pm, module);
+    MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
     if (mlirLogicalResultIsFailure(success))
       exit(2);
     mlirPassManagerDestroy(pm);
   }
 
-  mlirModuleDestroy(module);
+  mlirOperationDestroy(module);
   mlirContextDestroy(ctx);
 }
 
@@ -339,16 +339,17 @@ void testExternalPass(void) {
   MlirContext ctx = mlirContextCreate();
   registerAllUpstreamDialects(ctx);
 
-  MlirModule module = mlirModuleCreateParse(
-      ctx,
-      // clang-format off
-      mlirStringRefCreateFromCString(
-"func.func @foo(%arg0 : i32) -> i32 {                                   \n"
-"  %res = arith.addi %arg0, %arg0 : i32                                     \n"
-"  return %res : i32                                                        \n"
-"}"));
-  // clang-format on
-  if (mlirModuleIsNull(module)) {
+  const char *moduleAsm = //
+      "module {                                 \n"
+      "  func.func @foo(%arg0 : i32) -> i32 {   \n"
+      "    %res = arith.addi %arg0, %arg0 : i32 \n"
+      "    return %res : i32                    \n"
+      "  }                                      \n"
+      "}";
+  MlirOperation module =
+      mlirOperationCreateParse(ctx, mlirStringRefCreateFromCString(moduleAsm),
+                               mlirStringRefCreateFromCString("moduleAsm"));
+  if (mlirOperationIsNull(module)) {
     fprintf(stderr, "Unexpected failure parsing module.\n");
     exit(EXIT_FAILURE);
   }
@@ -377,7 +378,7 @@ void testExternalPass(void) {
 
     MlirPassManager pm = mlirPassManagerCreate(ctx);
     mlirPassManagerAddOwnedPass(pm, externalPass);
-    MlirLogicalResult success = mlirPassManagerRun(pm, module);
+    MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
     if (mlirLogicalResultIsFailure(success)) {
       fprintf(stderr, "Unexpected failure running external pass.\n");
       exit(EXIT_FAILURE);
@@ -421,7 +422,7 @@ void testExternalPass(void) {
     MlirOpPassManager nestedFuncPm =
         mlirPassManagerGetNestedUnder(pm, funcOpName);
     mlirOpPassManagerAddOwnedPass(nestedFuncPm, externalPass);
-    MlirLogicalResult success = mlirPassManagerRun(pm, module);
+    MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
     if (mlirLogicalResultIsFailure(success)) {
       fprintf(stderr, "Unexpected failure running external operation pass.\n");
       exit(EXIT_FAILURE);
@@ -469,7 +470,7 @@ void testExternalPass(void) {
 
     MlirPassManager pm = mlirPassManagerCreate(ctx);
     mlirPassManagerAddOwnedPass(pm, externalPass);
-    MlirLogicalResult success = mlirPassManagerRun(pm, module);
+    MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
     if (mlirLogicalResultIsFailure(success)) {
       fprintf(stderr, "Unexpected failure running external pass.\n");
       exit(EXIT_FAILURE);
@@ -516,7 +517,7 @@ void testExternalPass(void) {
 
     MlirPassManager pm = mlirPassManagerCreate(ctx);
     mlirPassManagerAddOwnedPass(pm, externalPass);
-    MlirLogicalResult success = mlirPassManagerRun(pm, module);
+    MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
     if (mlirLogicalResultIsSuccess(success)) {
       fprintf(
           stderr,
@@ -564,7 +565,7 @@ void testExternalPass(void) {
 
     MlirPassManager pm = mlirPassManagerCreate(ctx);
     mlirPassManagerAddOwnedPass(pm, externalPass);
-    MlirLogicalResult success = mlirPassManagerRun(pm, module);
+    MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
     if (mlirLogicalResultIsSuccess(success)) {
       fprintf(
           stderr,
@@ -587,7 +588,7 @@ void testExternalPass(void) {
   }
 
   mlirTypeIDAllocatorDestroy(typeIDAllocator);
-  mlirModuleDestroy(module);
+  mlirOperationDestroy(module);
   mlirContextDestroy(ctx);
 }
 


        


More information about the Mlir-commits mailing list