[Mlir-commits] [mlir] 6f5590c - [mlir][CAPI] Allow running pass manager on any operation
Rahul Kayaith
llvmlistbot at llvm.org
Wed Mar 1 15:17:22 PST 2023
Author: rkayaith
Date: 2023-03-01T18:17:13-05:00
New Revision: 6f5590ca347a5a2467b8aaea4b24bc9b70ef138f
URL: https://github.com/llvm/llvm-project/commit/6f5590ca347a5a2467b8aaea4b24bc9b70ef138f
DIFF: https://github.com/llvm/llvm-project/commit/6f5590ca347a5a2467b8aaea4b24bc9b70ef138f.diff
LOG: [mlir][CAPI] Allow running pass manager on any operation
`mlirPassManagerRun` is currently restricted to running on
`builtin.module` ops, but this restriction doesn't exist on the C++
side. This commit renames it to `mlirPassManagerRunOnOp` and updates it to
take an `MlirOperation` instead of an `MlirModule`. Callers that hold an
`MlirModule` can pass `mlirModuleGetOperation(module)` to get the same behavior.
Depends on D143352
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D143354
Added:
Modified:
mlir/include/mlir-c/Pass.h
mlir/lib/Bindings/Python/Pass.cpp
mlir/lib/CAPI/IR/Pass.cpp
mlir/test/CAPI/execution_engine.c
mlir/test/CAPI/pass.c
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index 721f1f28fe916..35db138305d1e 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -70,9 +70,9 @@ static inline bool mlirPassManagerIsNull(MlirPassManager passManager) {
MLIR_CAPI_EXPORTED MlirOpPassManager
mlirPassManagerGetAsOpPassManager(MlirPassManager passManager);
-/// Run the provided `passManager` on the given `module`.
+/// Run the provided `passManager` on the given `op`.
MLIR_CAPI_EXPORTED MlirLogicalResult
-mlirPassManagerRun(MlirPassManager passManager, MlirModule module);
+mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op);
/// Enable mlir-print-ir-after-all.
MLIR_CAPI_EXPORTED void
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index cb3c1586eb996..99f17a18bc83c 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -117,8 +117,8 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
.def(
"run",
[](PyPassManager &passManager, PyModule &module) {
- MlirLogicalResult status =
- mlirPassManagerRun(passManager.get(), module.get());
+ MlirLogicalResult status = mlirPassManagerRunOnOp(
+ passManager.get(), mlirModuleGetOperation(module.get()));
if (mlirLogicalResultIsFailure(status))
throw SetPyError(PyExc_RuntimeError,
"Failure while executing pass pipeline.");
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index b921154111c22..d242baae99c08 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -39,9 +39,9 @@ mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) {
return wrap(static_cast<OpPassManager *>(unwrap(passManager)));
}
-MlirLogicalResult mlirPassManagerRun(MlirPassManager passManager,
- MlirModule module) {
- return wrap(unwrap(passManager)->run(unwrap(module)));
+MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager,
+ MlirOperation op) {
+ return wrap(unwrap(passManager)->run(unwrap(op)));
}
void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) {
diff --git a/mlir/test/CAPI/execution_engine.c b/mlir/test/CAPI/execution_engine.c
index 582b2f8fe2223..2b4d448a65ef6 100644
--- a/mlir/test/CAPI/execution_engine.c
+++ b/mlir/test/CAPI/execution_engine.c
@@ -37,7 +37,8 @@ void lowerModuleToLLVM(MlirContext ctx, MlirModule module) {
mlirPassManagerAddOwnedPass(pm, mlirCreateConversionConvertFuncToLLVMPass());
mlirOpPassManagerAddOwnedPass(
opm, mlirCreateConversionArithToLLVMConversionPass());
- MlirLogicalResult status = mlirPassManagerRun(pm, module);
+ MlirLogicalResult status =
+ mlirPassManagerRunOnOp(pm, mlirModuleGetOperation(module));
if (mlirLogicalResultIsFailure(status)) {
fprintf(stderr, "Unexpected failure running pass pipeline\n");
exit(2);
diff --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c
index 42b5ce7080cac..3aad0016b393c 100644
--- a/mlir/test/CAPI/pass.c
+++ b/mlir/test/CAPI/pass.c
@@ -33,17 +33,16 @@ void testRunPassOnModule(void) {
MlirContext ctx = mlirContextCreate();
registerAllUpstreamDialects(ctx);
- MlirModule module = mlirModuleCreateParse(
- ctx,
- // clang-format off
- mlirStringRefCreateFromCString(
-"func.func @foo(%arg0 : i32) -> i32 { \n"
-" %res = arith.addi %arg0, %arg0 : i32 \n"
-" return %res : i32 \n"
-"}"));
- // clang-format on
- if (mlirModuleIsNull(module)) {
- fprintf(stderr, "Unexpected failure parsing module.\n");
+ const char *funcAsm = //
+ "func.func @foo(%arg0 : i32) -> i32 { \n"
+ " %res = arith.addi %arg0, %arg0 : i32 \n"
+ " return %res : i32 \n"
+ "} \n";
+ MlirOperation func =
+ mlirOperationCreateParse(ctx, mlirStringRefCreateFromCString(funcAsm),
+ mlirStringRefCreateFromCString("funcAsm"));
+ if (mlirOperationIsNull(func)) {
+ fprintf(stderr, "Unexpected failure parsing asm.\n");
exit(EXIT_FAILURE);
}
@@ -56,14 +55,14 @@ void testRunPassOnModule(void) {
MlirPassManager pm = mlirPassManagerCreate(ctx);
MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
mlirPassManagerAddOwnedPass(pm, printOpStatPass);
- MlirLogicalResult success = mlirPassManagerRun(pm, module);
+ MlirLogicalResult success = mlirPassManagerRunOnOp(pm, func);
if (mlirLogicalResultIsFailure(success)) {
fprintf(stderr, "Unexpected failure running pass manager.\n");
exit(EXIT_FAILURE);
}
mlirPassManagerDestroy(pm);
}
- mlirModuleDestroy(module);
+ mlirOperationDestroy(func);
mlirContextDestroy(ctx);
}
@@ -71,22 +70,23 @@ void testRunPassOnNestedModule(void) {
MlirContext ctx = mlirContextCreate();
registerAllUpstreamDialects(ctx);
- MlirModule module = mlirModuleCreateParse(
- ctx,
- // clang-format off
- mlirStringRefCreateFromCString(
-"func.func @foo(%arg0 : i32) -> i32 { \n"
-" %res = arith.addi %arg0, %arg0 : i32 \n"
-" return %res : i32 \n"
-"} \n"
-"module { \n"
-" func.func @bar(%arg0 : f32) -> f32 { \n"
-" %res = arith.addf %arg0, %arg0 : f32 \n"
-" return %res : f32 \n"
-" } \n"
-"}"));
- // clang-format on
- if (mlirModuleIsNull(module))
+ const char *moduleAsm = //
+ "module { \n"
+ " func.func @foo(%arg0 : i32) -> i32 { \n"
+ " %res = arith.addi %arg0, %arg0 : i32 \n"
+ " return %res : i32 \n"
+ " } \n"
+ " module { \n"
+ " func.func @bar(%arg0 : f32) -> f32 { \n"
+ " %res = arith.addf %arg0, %arg0 : f32 \n"
+ " return %res : f32 \n"
+ " } \n"
+ " } \n"
+ "} \n";
+ MlirOperation module =
+ mlirOperationCreateParse(ctx, mlirStringRefCreateFromCString(moduleAsm),
+ mlirStringRefCreateFromCString("moduleAsm"));
+ if (mlirOperationIsNull(module))
exit(1);
// Run the print-op-stats pass on functions under the top-level module:
@@ -100,7 +100,7 @@ void testRunPassOnNestedModule(void) {
pm, mlirStringRefCreateFromCString("func.func"));
MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
- MlirLogicalResult success = mlirPassManagerRun(pm, module);
+ MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsFailure(success))
exit(2);
mlirPassManagerDestroy(pm);
@@ -118,13 +118,13 @@ void testRunPassOnNestedModule(void) {
nestedModulePm, mlirStringRefCreateFromCString("func.func"));
MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
- MlirLogicalResult success = mlirPassManagerRun(pm, module);
+ MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsFailure(success))
exit(2);
mlirPassManagerDestroy(pm);
}
- mlirModuleDestroy(module);
+ mlirOperationDestroy(module);
mlirContextDestroy(ctx);
}
@@ -339,16 +339,17 @@ void testExternalPass(void) {
MlirContext ctx = mlirContextCreate();
registerAllUpstreamDialects(ctx);
- MlirModule module = mlirModuleCreateParse(
- ctx,
- // clang-format off
- mlirStringRefCreateFromCString(
-"func.func @foo(%arg0 : i32) -> i32 { \n"
-" %res = arith.addi %arg0, %arg0 : i32 \n"
-" return %res : i32 \n"
-"}"));
- // clang-format on
- if (mlirModuleIsNull(module)) {
+ const char *moduleAsm = //
+ "module { \n"
+ " func.func @foo(%arg0 : i32) -> i32 { \n"
+ " %res = arith.addi %arg0, %arg0 : i32 \n"
+ " return %res : i32 \n"
+ " } \n"
+ "}";
+ MlirOperation module =
+ mlirOperationCreateParse(ctx, mlirStringRefCreateFromCString(moduleAsm),
+ mlirStringRefCreateFromCString("moduleAsm"));
+ if (mlirOperationIsNull(module)) {
fprintf(stderr, "Unexpected failure parsing module.\n");
exit(EXIT_FAILURE);
}
@@ -377,7 +378,7 @@ void testExternalPass(void) {
MlirPassManager pm = mlirPassManagerCreate(ctx);
mlirPassManagerAddOwnedPass(pm, externalPass);
- MlirLogicalResult success = mlirPassManagerRun(pm, module);
+ MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsFailure(success)) {
fprintf(stderr, "Unexpected failure running external pass.\n");
exit(EXIT_FAILURE);
@@ -421,7 +422,7 @@ void testExternalPass(void) {
MlirOpPassManager nestedFuncPm =
mlirPassManagerGetNestedUnder(pm, funcOpName);
mlirOpPassManagerAddOwnedPass(nestedFuncPm, externalPass);
- MlirLogicalResult success = mlirPassManagerRun(pm, module);
+ MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsFailure(success)) {
fprintf(stderr, "Unexpected failure running external operation pass.\n");
exit(EXIT_FAILURE);
@@ -469,7 +470,7 @@ void testExternalPass(void) {
MlirPassManager pm = mlirPassManagerCreate(ctx);
mlirPassManagerAddOwnedPass(pm, externalPass);
- MlirLogicalResult success = mlirPassManagerRun(pm, module);
+ MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsFailure(success)) {
fprintf(stderr, "Unexpected failure running external pass.\n");
exit(EXIT_FAILURE);
@@ -516,7 +517,7 @@ void testExternalPass(void) {
MlirPassManager pm = mlirPassManagerCreate(ctx);
mlirPassManagerAddOwnedPass(pm, externalPass);
- MlirLogicalResult success = mlirPassManagerRun(pm, module);
+ MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsSuccess(success)) {
fprintf(
stderr,
@@ -564,7 +565,7 @@ void testExternalPass(void) {
MlirPassManager pm = mlirPassManagerCreate(ctx);
mlirPassManagerAddOwnedPass(pm, externalPass);
- MlirLogicalResult success = mlirPassManagerRun(pm, module);
+ MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsSuccess(success)) {
fprintf(
stderr,
@@ -587,7 +588,7 @@ void testExternalPass(void) {
}
mlirTypeIDAllocatorDestroy(typeIDAllocator);
- mlirModuleDestroy(module);
+ mlirOperationDestroy(module);
mlirContextDestroy(ctx);
}
More information about the Mlir-commits mailing list