[Mlir-commits] [mlir] [mlir][capi] make MLIR Pass C-API type safe (PR #121284)

Jiuyang Liu llvmlistbot at llvm.org
Sun Dec 29 01:18:36 PST 2024


https://github.com/sequencer updated https://github.com/llvm/llvm-project/pull/121284

>From c7331b086520b075855911b2b584a9dd9785c6a9 Mon Sep 17 00:00:00 2001
From: Jiuyang Liu <liu at jiuyang.me>
Date: Sun, 29 Dec 2024 15:17:12 +0800
Subject: [PATCH] [mlir][capi] make MLIR Pass C-API type safe

- rename the C-API function mlirCreateExternalPass to mlirExternalPassCreate
  to align with the C-API naming convention;
- make mlirExternalPassCreate return MlirExternalPass for type safety;
- add a new C-API function mlirExternalPassGetPass to cast MlirExternalPass to
  MlirPass;
---
 mlir/include/mlir-c/Pass.h |  5 ++++-
 mlir/lib/CAPI/IR/Pass.cpp  | 22 +++++++++++++---------
 mlir/test/CAPI/pass.c      | 30 ++++++++++++++++++++----------
 3 files changed, 37 insertions(+), 20 deletions(-)

diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index 8fd8e9956a65a3..acac29a513a386 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -174,12 +174,15 @@ typedef struct MlirExternalPassCallbacks MlirExternalPassCallbacks;
-/// Creates an external `MlirPass` that calls the supplied `callbacks` using the
-/// supplied `userData`. If `opName` is empty, the pass is a generic operation
-/// pass. Otherwise it is an operation pass specific to the specified pass name.
-MLIR_CAPI_EXPORTED MlirPass mlirCreateExternalPass(
+/// Creates an external `MlirExternalPass` that calls the supplied `callbacks`
+/// using the supplied `userData`. If `opName` is empty, the pass is a generic
+/// operation pass. Otherwise it is an operation pass specific to `opName`.
+MLIR_CAPI_EXPORTED MlirExternalPass mlirExternalPassCreate(
     MlirTypeID passID, MlirStringRef name, MlirStringRef argument,
     MlirStringRef description, MlirStringRef opName,
     intptr_t nDependentDialects, MlirDialectHandle *dependentDialects,
     MlirExternalPassCallbacks callbacks, void *userData);
 
+/// Returns `pass` upcast to `MlirPass`, e.g. to add it to a pass manager.
+MLIR_CAPI_EXPORTED MlirPass mlirExternalPassGetPass(MlirExternalPass pass);
+
 /// This signals that the pass has failed. This is only valid to call during
 /// the `run` callback of `MlirExternalPassCallbacks`.
 /// See Pass::signalPassFailure().
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index 883b7e8bb832d2..c751d5c467e4c1 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -193,19 +193,24 @@ class ExternalPass : public Pass {
 };
 } // namespace mlir
 
-MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
-                                MlirStringRef argument,
-                                MlirStringRef description, MlirStringRef opName,
-                                intptr_t nDependentDialects,
-                                MlirDialectHandle *dependentDialects,
-                                MlirExternalPassCallbacks callbacks,
-                                void *userData) {
-  return wrap(static_cast<mlir::Pass *>(new mlir::ExternalPass(
+MlirExternalPass mlirExternalPassCreate(MlirTypeID passID, MlirStringRef name,
+                                        MlirStringRef argument,
+                                        MlirStringRef description,
+                                        MlirStringRef opName,
+                                        intptr_t nDependentDialects,
+                                        MlirDialectHandle *dependentDialects,
+                                        MlirExternalPassCallbacks callbacks,
+                                        void *userData) {
+  return wrap(new mlir::ExternalPass(
       unwrap(passID), unwrap(name), unwrap(argument), unwrap(description),
       opName.length > 0 ? std::optional<StringRef>(unwrap(opName))
                         : std::nullopt,
       {dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks,
-      userData)));
+      userData));
+}
+
+MlirPass mlirExternalPassGetPass(MlirExternalPass externalPass) {
+  return wrap(static_cast<mlir::Pass *>(unwrap(externalPass)));
 }
 
 void mlirExternalPassSignalFailure(MlirExternalPass pass) {
diff --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c
index 3aad0016b393c4..8778c945ab2596 100644
--- a/mlir/test/CAPI/pass.c
+++ b/mlir/test/CAPI/pass.c
@@ -367,17 +367,19 @@ void testExternalPass(void) {
         mlirStringRefCreateFromCString("test-external-pass");
     TestExternalPassUserData userData = {0};
 
-    MlirPass externalPass = mlirCreateExternalPass(
+    MlirExternalPass externalPass = mlirExternalPassCreate(
         passID, name, argument, description, emptyOpName, 0, NULL,
         makeTestExternalPassCallbacks(NULL, testRunExternalPass), &userData);
 
+    MlirPass pass = mlirExternalPassGetPass(externalPass);
+
     if (userData.constructCallCount != 1) {
       fprintf(stderr, "Expected constructCallCount to be 1\n");
       exit(EXIT_FAILURE);
     }
 
     MlirPassManager pm = mlirPassManagerCreate(ctx);
-    mlirPassManagerAddOwnedPass(pm, externalPass);
+    mlirPassManagerAddOwnedPass(pm, pass);
     MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
     if (mlirLogicalResultIsFailure(success)) {
       fprintf(stderr, "Unexpected failure running external pass.\n");
@@ -408,11 +410,13 @@ void testExternalPass(void) {
     MlirDialectHandle funcHandle = mlirGetDialectHandle__func__();
     MlirStringRef funcOpName = mlirStringRefCreateFromCString("func.func");
 
-    MlirPass externalPass = mlirCreateExternalPass(
+    MlirExternalPass externalPass = mlirExternalPassCreate(
         passID, name, argument, description, funcOpName, 1, &funcHandle,
         makeTestExternalPassCallbacks(NULL, testRunExternalFuncPass),
         &userData);
 
+    MlirPass pass = mlirExternalPassGetPass(externalPass);
+
     if (userData.constructCallCount != 1) {
       fprintf(stderr, "Expected constructCallCount to be 1\n");
       exit(EXIT_FAILURE);
@@ -421,7 +425,7 @@ void testExternalPass(void) {
     MlirPassManager pm = mlirPassManagerCreate(ctx);
     MlirOpPassManager nestedFuncPm =
         mlirPassManagerGetNestedUnder(pm, funcOpName);
-    mlirOpPassManagerAddOwnedPass(nestedFuncPm, externalPass);
+    mlirOpPassManagerAddOwnedPass(nestedFuncPm, pass);
     MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
     if (mlirLogicalResultIsFailure(success)) {
       fprintf(stderr, "Unexpected failure running external operation pass.\n");
@@ -457,19 +461,21 @@ void testExternalPass(void) {
         mlirStringRefCreateFromCString("test-external-pass");
     TestExternalPassUserData userData = {0};
 
-    MlirPass externalPass = mlirCreateExternalPass(
+    MlirExternalPass externalPass = mlirExternalPassCreate(
         passID, name, argument, description, emptyOpName, 0, NULL,
         makeTestExternalPassCallbacks(testInitializeExternalPass,
                                       testRunExternalPass),
         &userData);
 
+    MlirPass pass = mlirExternalPassGetPass(externalPass);
+
     if (userData.constructCallCount != 1) {
       fprintf(stderr, "Expected constructCallCount to be 1\n");
       exit(EXIT_FAILURE);
     }
 
     MlirPassManager pm = mlirPassManagerCreate(ctx);
-    mlirPassManagerAddOwnedPass(pm, externalPass);
+    mlirPassManagerAddOwnedPass(pm, pass);
     MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
     if (mlirLogicalResultIsFailure(success)) {
       fprintf(stderr, "Unexpected failure running external pass.\n");
@@ -504,19 +510,21 @@ void testExternalPass(void) {
         mlirStringRefCreateFromCString("test-external-failing-pass");
     TestExternalPassUserData userData = {0};
 
-    MlirPass externalPass = mlirCreateExternalPass(
+    MlirExternalPass externalPass = mlirExternalPassCreate(
         passID, name, argument, description, emptyOpName, 0, NULL,
         makeTestExternalPassCallbacks(testInitializeFailingExternalPass,
                                       testRunExternalPass),
         &userData);
 
+    MlirPass pass = mlirExternalPassGetPass(externalPass);
+
     if (userData.constructCallCount != 1) {
       fprintf(stderr, "Expected constructCallCount to be 1\n");
       exit(EXIT_FAILURE);
     }
 
     MlirPassManager pm = mlirPassManagerCreate(ctx);
-    mlirPassManagerAddOwnedPass(pm, externalPass);
+    mlirPassManagerAddOwnedPass(pm, pass);
     MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
     if (mlirLogicalResultIsSuccess(success)) {
       fprintf(
@@ -553,18 +561,20 @@ void testExternalPass(void) {
         mlirStringRefCreateFromCString("test-external-failing-pass");
     TestExternalPassUserData userData = {0};
 
-    MlirPass externalPass = mlirCreateExternalPass(
+    MlirExternalPass externalPass = mlirExternalPassCreate(
         passID, name, argument, description, emptyOpName, 0, NULL,
         makeTestExternalPassCallbacks(NULL, testRunFailingExternalPass),
         &userData);
 
+    MlirPass pass = mlirExternalPassGetPass(externalPass);
+
     if (userData.constructCallCount != 1) {
       fprintf(stderr, "Expected constructCallCount to be 1\n");
       exit(EXIT_FAILURE);
     }
 
     MlirPassManager pm = mlirPassManagerCreate(ctx);
-    mlirPassManagerAddOwnedPass(pm, externalPass);
+    mlirPassManagerAddOwnedPass(pm, pass);
     MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
     if (mlirLogicalResultIsSuccess(success)) {
       fprintf(



More information about the Mlir-commits mailing list