[Mlir-commits] [mlir] [mlir][capi] make MLIR Pass C-API type safe (PR #121284)
Jiuyang Liu
llvmlistbot at llvm.org
Sat Dec 28 23:32:27 PST 2024
https://github.com/sequencer updated https://github.com/llvm/llvm-project/pull/121284
>From 91001e757595061766bc5106e1c7d156bd34c372 Mon Sep 17 00:00:00 2001
From: Jiuyang Liu <liu at jiuyang.me>
Date: Sun, 29 Dec 2024 15:17:12 +0800
Subject: [PATCH] [mlir][capi] make MLIR Pass C-API type safe
- rename the C-API function mlirCreateExternalPass to mlirExternalPassCreate to
align with the C-API naming convention;
- make mlirExternalPassCreate return MlirExternalPass for type safety;
- add a new C-API function mlirExternalPassGetPass to cast MlirExternalPass to
MlirPass;
---
mlir/include/mlir-c/Pass.h | 5 ++++-
mlir/lib/CAPI/IR/Pass.cpp | 22 +++++++++++++---------
mlir/test/CAPI/pass.c | 30 ++++++++++++++++++++----------
3 files changed, 37 insertions(+), 20 deletions(-)
diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index 8fd8e9956a65a3..acac29a513a386 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -174,12 +174,15 @@ typedef struct MlirExternalPassCallbacks MlirExternalPassCallbacks;
/// Creates an external `MlirPass` that calls the supplied `callbacks` using the
/// supplied `userData`. If `opName` is empty, the pass is a generic operation
/// pass. Otherwise it is an operation pass specific to the specified pass name.
-MLIR_CAPI_EXPORTED MlirPass mlirCreateExternalPass(
+MLIR_CAPI_EXPORTED MlirExternalPass mlirExternalPassCreate(
MlirTypeID passID, MlirStringRef name, MlirStringRef argument,
MlirStringRef description, MlirStringRef opName,
intptr_t nDependentDialects, MlirDialectHandle *dependentDialects,
MlirExternalPassCallbacks callbacks, void *userData);
+// Static cast ExternalPass to Pass.
+MLIR_CAPI_EXPORTED MlirPass mlirExternalPassGetPass(MlirExternalPass pass);
+
/// This signals that the pass has failed. This is only valid to call during
/// the `run` callback of `MlirExternalPassCallbacks`.
/// See Pass::signalPassFailure().
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index 883b7e8bb832d2..38183f04fc98c1 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -193,19 +193,23 @@ class ExternalPass : public Pass {
};
} // namespace mlir
-MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
- MlirStringRef argument,
- MlirStringRef description, MlirStringRef opName,
- intptr_t nDependentDialects,
- MlirDialectHandle *dependentDialects,
- MlirExternalPassCallbacks callbacks,
- void *userData) {
- return wrap(static_cast<mlir::Pass *>(new mlir::ExternalPass(
+MlirExternalPass mlirExternalPassCreate(MlirTypeID passID, MlirStringRef name,
+ MlirStringRef argument,
+ MlirStringRef description, MlirStringRef opName,
+ intptr_t nDependentDialects,
+ MlirDialectHandle *dependentDialects,
+ MlirExternalPassCallbacks callbacks,
+ void *userData) {
+ return wrap(new mlir::ExternalPass(
unwrap(passID), unwrap(name), unwrap(argument), unwrap(description),
opName.length > 0 ? std::optional<StringRef>(unwrap(opName))
: std::nullopt,
{dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks,
- userData)));
+ userData));
+}
+
+MlirPass mlirExternalPassGetPass(MlirExternalPass externalPass) {
+ return wrap(static_cast<mlir::Pass>(externalPass));
}
void mlirExternalPassSignalFailure(MlirExternalPass pass) {
diff --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c
index 3aad0016b393c4..8778c945ab2596 100644
--- a/mlir/test/CAPI/pass.c
+++ b/mlir/test/CAPI/pass.c
@@ -367,17 +367,19 @@ void testExternalPass(void) {
mlirStringRefCreateFromCString("test-external-pass");
TestExternalPassUserData userData = {0};
- MlirPass externalPass = mlirCreateExternalPass(
+ MlirExternalPass externalPass = mlirExternalPassCreate(
passID, name, argument, description, emptyOpName, 0, NULL,
makeTestExternalPassCallbacks(NULL, testRunExternalPass), &userData);
+ MlirPass pass = mlirExternalPassGetPass(externalPass);
+
if (userData.constructCallCount != 1) {
fprintf(stderr, "Expected constructCallCount to be 1\n");
exit(EXIT_FAILURE);
}
MlirPassManager pm = mlirPassManagerCreate(ctx);
- mlirPassManagerAddOwnedPass(pm, externalPass);
+ mlirPassManagerAddOwnedPass(pm, pass);
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsFailure(success)) {
fprintf(stderr, "Unexpected failure running external pass.\n");
@@ -408,11 +410,13 @@ void testExternalPass(void) {
MlirDialectHandle funcHandle = mlirGetDialectHandle__func__();
MlirStringRef funcOpName = mlirStringRefCreateFromCString("func.func");
- MlirPass externalPass = mlirCreateExternalPass(
+ MlirExternalPass externalPass = mlirExternalPassCreate(
passID, name, argument, description, funcOpName, 1, &funcHandle,
makeTestExternalPassCallbacks(NULL, testRunExternalFuncPass),
&userData);
+ MlirPass pass = mlirExternalPassGetPass(externalPass);
+
if (userData.constructCallCount != 1) {
fprintf(stderr, "Expected constructCallCount to be 1\n");
exit(EXIT_FAILURE);
@@ -421,7 +425,7 @@ void testExternalPass(void) {
MlirPassManager pm = mlirPassManagerCreate(ctx);
MlirOpPassManager nestedFuncPm =
mlirPassManagerGetNestedUnder(pm, funcOpName);
- mlirOpPassManagerAddOwnedPass(nestedFuncPm, externalPass);
+ mlirOpPassManagerAddOwnedPass(nestedFuncPm, pass);
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsFailure(success)) {
fprintf(stderr, "Unexpected failure running external operation pass.\n");
@@ -457,19 +461,21 @@ void testExternalPass(void) {
mlirStringRefCreateFromCString("test-external-pass");
TestExternalPassUserData userData = {0};
- MlirPass externalPass = mlirCreateExternalPass(
+ MlirExternalPass externalPass = mlirExternalPassCreate(
passID, name, argument, description, emptyOpName, 0, NULL,
makeTestExternalPassCallbacks(testInitializeExternalPass,
testRunExternalPass),
&userData);
+ MlirPass pass = mlirExternalPassGetPass(externalPass);
+
if (userData.constructCallCount != 1) {
fprintf(stderr, "Expected constructCallCount to be 1\n");
exit(EXIT_FAILURE);
}
MlirPassManager pm = mlirPassManagerCreate(ctx);
- mlirPassManagerAddOwnedPass(pm, externalPass);
+ mlirPassManagerAddOwnedPass(pm, pass);
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsFailure(success)) {
fprintf(stderr, "Unexpected failure running external pass.\n");
@@ -504,19 +510,21 @@ void testExternalPass(void) {
mlirStringRefCreateFromCString("test-external-failing-pass");
TestExternalPassUserData userData = {0};
- MlirPass externalPass = mlirCreateExternalPass(
+ MlirExternalPass externalPass = mlirExternalPassCreate(
passID, name, argument, description, emptyOpName, 0, NULL,
makeTestExternalPassCallbacks(testInitializeFailingExternalPass,
testRunExternalPass),
&userData);
+ MlirPass pass = mlirExternalPassGetPass(externalPass);
+
if (userData.constructCallCount != 1) {
fprintf(stderr, "Expected constructCallCount to be 1\n");
exit(EXIT_FAILURE);
}
MlirPassManager pm = mlirPassManagerCreate(ctx);
- mlirPassManagerAddOwnedPass(pm, externalPass);
+ mlirPassManagerAddOwnedPass(pm, pass);
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsSuccess(success)) {
fprintf(
@@ -553,18 +561,20 @@ void testExternalPass(void) {
mlirStringRefCreateFromCString("test-external-failing-pass");
TestExternalPassUserData userData = {0};
- MlirPass externalPass = mlirCreateExternalPass(
+ MlirExternalPass externalPass = mlirExternalPassCreate(
passID, name, argument, description, emptyOpName, 0, NULL,
makeTestExternalPassCallbacks(NULL, testRunFailingExternalPass),
&userData);
+ MlirPass pass = mlirExternalPassGetPass(externalPass);
+
if (userData.constructCallCount != 1) {
fprintf(stderr, "Expected constructCallCount to be 1\n");
exit(EXIT_FAILURE);
}
MlirPassManager pm = mlirPassManagerCreate(ctx);
- mlirPassManagerAddOwnedPass(pm, externalPass);
+ mlirPassManagerAddOwnedPass(pm, pass);
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsSuccess(success)) {
fprintf(
More information about the Mlir-commits
mailing list