[Mlir-commits] [mlir] sketch (PR #73368)

Maksim Levental llvmlistbot at llvm.org
Fri Nov 24 12:38:19 PST 2023


https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/73368

None

>From 36aae61690752b8c56f332e4c4c1b85dd6380f5d Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Fri, 24 Nov 2023 14:37:37 -0600
Subject: [PATCH] sketch

---
 mlir/include/mlir-c/Pass.h        |  4 +-
 mlir/lib/Bindings/Python/Pass.cpp | 74 +++++++++++++++++++++++++++++++
 mlir/lib/CAPI/IR/Pass.cpp         |  4 +-
 3 files changed, 78 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index 35db138305d1e22..ada706c4bc488ce 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -153,7 +153,7 @@ struct MlirExternalPassCallbacks {
   /// The callback is called before the pass is run, allowing a chance to
   /// initialize any complex state necessary for running the pass.
   /// See Pass::initialize(MLIRContext *).
-  MlirLogicalResult (*initialize)(MlirContext ctx, void *userData);
+  MlirLogicalResult (*initialize)(void *userData, MlirContext ctx);
 
   /// This callback is called when the pass is cloned.
   /// See Pass::clonePass().
@@ -161,7 +161,7 @@ struct MlirExternalPassCallbacks {
 
   /// This callback is called when the pass is run.
   /// See Pass::runOnOperation().
-  void (*run)(MlirOperation op, MlirExternalPass pass, void *userData);
+  void (*run)(void *userData, MlirOperation op, MlirExternalPass pass);
 };
 typedef struct MlirExternalPassCallbacks MlirExternalPassCallbacks;
 
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 588a8e25414c657..bda211b98e57ca5 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -51,6 +51,55 @@ class PyPassManager {
 
 } // namespace
 
+namespace {
+/// Holds the Python object that implements an external pass. A heap-allocated
+/// instance is handed to mlirCreateExternalPass as `userData`; the created
+/// pass owns it from then on and frees it through the `destruct` callback.
+class PythonPass {
+public:
+  explicit PythonPass(py::object passObj) : passObj(std::move(passObj)) {}
+
+  py::object passObj;
+};
+} // namespace
+
+// Plain-function trampolines for MlirExternalPassCallbacks: a pointer to member
+// function cannot portably become a C function pointer, so each callback is a
+// free function that recovers the PythonPass from `userData`. NOTE(review):
+// none of them relies on `userData` coming first, so the argument reordering
+// in mlir-c/Pass.h could be reverted.
+
+static void constructPythonPass(void *userData) {} // Nothing to set up.
+
+static void destructPythonPass(void *userData) {
+  // Dropping the py::object changes a Python refcount, which needs the GIL.
+  py::gil_scoped_acquire acquire;
+  delete static_cast<PythonPass *>(userData);
+}
+
+static void *clonePythonPass(void *userData) {
+  // The clone shares the same Python object; copying it needs the GIL.
+  py::gil_scoped_acquire acquire;
+  return new PythonPass(*static_cast<PythonPass *>(userData));
+}
+
+static void runPythonPass(void *userData, MlirOperation op,
+                          MlirExternalPass pass) {
+  // TODO: invoke the Python object on `op` and report errors through
+  // mlirExternalPassSignalFailure(pass).
+}
+
+/// Returns the C API callbacks that drive a PythonPass.
+static MlirExternalPassCallbacks makePythonPassCallbacks() {
+  MlirExternalPassCallbacks callbacks = {};
+  callbacks.construct = constructPythonPass;
+  callbacks.destruct = destructPythonPass;
+  callbacks.initialize = nullptr; // No per-context initialization needed.
+  callbacks.clone = clonePythonPass;
+  callbacks.run = runPythonPass;
+  return callbacks;
+}
+
 /// Create the `mlir.passmanager` here.
 void mlir::python::populatePassManagerSubmodule(py::module &m) {
   //----------------------------------------------------------------------------
@@ -114,6 +163,27 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
           "pipeline"_a,
           "Add textual pipeline elements to the pass manager. Throws a "
           "ValueError if the pipeline can't be parsed.")
+      .def_static(
+          "create_external_pass",
+          [](py::object passObj) {
+            // TypeIDs must outlive every pass that uses them, so all external
+            // passes share one allocator that is intentionally never freed.
+            static MlirTypeIDAllocator typeIDAllocator =
+                mlirTypeIDAllocatorCreate();
+            MlirTypeID passID =
+                mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
+            MlirStringRef name =
+                mlirStringRefCreateFromCString("TestExternalPass");
+            MlirStringRef description = mlirStringRefCreateFromCString("");
+            MlirStringRef emptyOpName = mlirStringRefCreateFromCString("");
+            MlirStringRef argument =
+                mlirStringRefCreateFromCString("test-external-pass");
+            // The pass takes ownership of the heap-allocated PythonPass.
+            // TODO: hand the pass to a PassManager; until then it leaks.
+            (void)mlirCreateExternalPass(
+                passID, name, argument, description, emptyOpName, 0, nullptr,
+                makePythonPassCallbacks(), new PythonPass(passObj));
+          })
       .def(
           "run",
           [](PyPassManager &passManager, PyOperationBase &op,
@@ -144,4 +214,8 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
           },
           "Print the textual representation for this PassManager, suitable to "
           "be passed to `parse` for round-tripping.");
+
+  py::class_<PythonPass>(m, "PythonPass", py::module_local())
+      .def(py::init<py::object>(),
+           "Wraps a Python object implementing an external pass.");
 }
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index d242baae99c0862..68b9035848f7575 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -138,7 +138,7 @@ class ExternalPass : public Pass {
 protected:
   LogicalResult initialize(MLIRContext *ctx) override {
     if (callbacks.initialize)
-      return unwrap(callbacks.initialize(wrap(ctx), userData));
+      return unwrap(callbacks.initialize(userData, wrap(ctx)));
     return success();
   }
 
@@ -149,7 +149,7 @@ class ExternalPass : public Pass {
   }
 
   void runOnOperation() override {
-    callbacks.run(wrap(getOperation()), wrap(this), userData);
+    callbacks.run(userData, wrap(getOperation()), wrap(this));
   }
 
   std::unique_ptr<Pass> clonePass() const override {



More information about the Mlir-commits mailing list