[Mlir-commits] [mlir] sketch (PR #73368)
Maksim Levental
llvmlistbot at llvm.org
Fri Nov 24 12:38:19 PST 2023
https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/73368
(No description was provided for this pull request.)
>From 36aae61690752b8c56f332e4c4c1b85dd6380f5d Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Fri, 24 Nov 2023 14:37:37 -0600
Subject: [PATCH] sketch
---
mlir/include/mlir-c/Pass.h | 4 +-
mlir/lib/Bindings/Python/Pass.cpp | 74 +++++++++++++++++++++++++++++++
mlir/lib/CAPI/IR/Pass.cpp | 4 +-
3 files changed, 78 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index 35db138305d1e22..ada706c4bc488ce 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -153,7 +153,7 @@ struct MlirExternalPassCallbacks {
/// The callback is called before the pass is run, allowing a chance to
/// initialize any complex state necessary for running the pass.
/// See Pass::initialize(MLIRContext *).
- MlirLogicalResult (*initialize)(MlirContext ctx, void *userData);
+ MlirLogicalResult (*initialize)(void *userData, MlirContext ctx);
/// This callback is called when the pass is cloned.
/// See Pass::clonePass().
@@ -161,7 +161,7 @@ struct MlirExternalPassCallbacks {
/// This callback is called when the pass is run.
/// See Pass::runOnOperation().
- void (*run)(MlirOperation op, MlirExternalPass pass, void *userData);
+ void (*run)(void *userData, MlirOperation op, MlirExternalPass pass);
};
typedef struct MlirExternalPassCallbacks MlirExternalPassCallbacks;
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 588a8e25414c657..bda211b98e57ca5 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -51,6 +51,55 @@ class PyPassManager {
} // namespace
+namespace {
+
+/// Adapter that lets a Python object back an MLIR external pass.
+///
+/// A heap-allocated instance is handed to `mlirCreateExternalPass` as the
+/// opaque `userData` pointer. The C callbacks built by
+/// `makeTestExternalPassCallbacks` cast it back to `PythonPass *` and forward
+/// to the members below.
+///
+/// Ownership: the callbacks own `userData`. The `destruct` callback deletes
+/// it, and the `clone` callback returns a fresh `new` copy. Any instance used
+/// as `userData` must therefore be allocated with `new`.
+class PythonPass {
+public:
+  explicit PythonPass(py::object passObj) : passObj(std::move(passObj)) {}
+
+  /// Invoked once when the external pass is created. Currently a no-op.
+  void construct() {}
+
+  /// Invoked once when the external pass is destroyed. Currently a no-op;
+  /// freeing this object itself is done by the `destruct` callback.
+  void destruct() {}
+
+  /// Invoked before the pass runs. Always reports success for now.
+  MlirLogicalResult initialize(MlirContext /*ctx*/) {
+    return mlirLogicalResultSuccess();
+  }
+
+  /// Returns a heap-allocated copy that shares the same Python object.
+  /// Copying `passObj` bumps a Python refcount: call with the GIL held.
+  PythonPass *clone() const { return new PythonPass(*this); }
+
+  /// Invoked for each operation the pass is run on.
+  /// TODO: forward `op` to `passObj` (requires taking the GIL); currently a
+  /// no-op.
+  void run(MlirOperation /*op*/, MlirExternalPass /*pass*/) {}
+
+  py::object passObj;
+};
+
+/// Builds the C callback table for a `PythonPass`-backed external pass.
+///
+/// The C API invokes plain function pointers and passes `userData` as the
+/// first argument (see MlirExternalPassCallbacks in mlir-c/Pass.h). Each
+/// entry is therefore a captureless lambda, which converts implicitly to a
+/// function pointer, and it casts `userData` back to `PythonPass *`.
+///
+/// Pointers to member functions cannot be used here. They are not
+/// convertible to ordinary function pointers, and they carry no object to be
+/// invoked on.
+MlirExternalPassCallbacks makeTestExternalPassCallbacks() {
+  MlirExternalPassCallbacks callbacks{};
+  callbacks.construct = [](void *userData) {
+    static_cast<PythonPass *>(userData)->construct();
+  };
+  callbacks.destruct = [](void *userData) {
+    auto *self = static_cast<PythonPass *>(userData);
+    self->destruct();
+    // Deleting releases a reference to a Python object, which needs the GIL.
+    py::gil_scoped_acquire acquire;
+    delete self;
+  };
+  callbacks.initialize = [](void *userData, MlirContext ctx) {
+    return static_cast<PythonPass *>(userData)->initialize(ctx);
+  };
+  callbacks.clone = [](void *userData) -> void * {
+    // Clones may be requested off the Python thread (e.g. by a
+    // multi-threaded pass manager), and copying `passObj` needs the GIL.
+    py::gil_scoped_acquire acquire;
+    return static_cast<PythonPass *>(userData)->clone();
+  };
+  callbacks.run = [](void *userData, MlirOperation op,
+                     MlirExternalPass pass) {
+    static_cast<PythonPass *>(userData)->run(op, pass);
+  };
+  return callbacks;
+}
+
+} // namespace
+
/// Create the `mlir.passmanager` here.
void mlir::python::populatePassManagerSubmodule(py::module &m) {
//----------------------------------------------------------------------------
@@ -114,6 +163,27 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
"pipeline"_a,
"Add textual pipeline elements to the pass manager. Throws a "
"ValueError if the pipeline can't be parsed.")
+      .def_static(
+          "create_external_pass",
+          [](py::object &passObj) {
+            // `userData` must outlive this call: the external pass stores the
+            // pointer and hands it to its callbacks later. The address of a
+            // stack local would dangle as soon as this lambda returns.
+            auto *pass = new PythonPass(passObj);
+
+            // One process-wide allocator. Destroying an allocator also frees
+            // every TypeID it handed out, so it must stay alive as long as
+            // the passes do. Creating a fresh one per call leaked it each time.
+            static MlirTypeIDAllocator typeIDAllocator =
+                mlirTypeIDAllocatorCreate();
+            MlirTypeID passID =
+                mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
+            MlirStringRef name =
+                mlirStringRefCreateFromCString("TestExternalPass");
+            MlirStringRef description = mlirStringRefCreateFromCString("");
+            MlirStringRef emptyOpName = mlirStringRefCreateFromCString("");
+            MlirStringRef argument =
+                mlirStringRefCreateFromCString("test-external-pass");
+
+            MlirExternalPassCallbacks cbs = makeTestExternalPassCallbacks();
+
+            MlirPass externalPass = mlirCreateExternalPass(
+                passID, name, argument, description, emptyOpName,
+                /*nDependentDialects=*/0, /*dependentDialects=*/nullptr, cbs,
+                pass);
+            // TODO: transfer ownership of `externalPass` (return it to Python
+            // or add it to a PassManager); as written it is never destroyed.
+            (void)externalPass;
+          })
.def(
"run",
[](PyPassManager &passManager, PyOperationBase &op,
@@ -144,4 +214,8 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
},
"Print the textual representation for this PassManager, suitable to "
"be passed to `parse` for round-tripping.");
+
+  // Python-visible wrapper: `PythonPass(obj)` stores `obj` in a new
+  // PythonPass, constructed directly from the single positional argument.
+  py::class_<PythonPass>(m, "PythonPass", py::module_local())
+      .def(py::init<py::object>());
}
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index d242baae99c0862..68b9035848f7575 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -138,7 +138,7 @@ class ExternalPass : public Pass {
protected:
LogicalResult initialize(MLIRContext *ctx) override {
if (callbacks.initialize)
- return unwrap(callbacks.initialize(wrap(ctx), userData));
+ // userData is now the first argument, matching the reordered
+ // MlirExternalPassCallbacks::initialize signature in mlir-c/Pass.h. C
+ // clients written against the old (ctx, userData) order must be updated.
+ return unwrap(callbacks.initialize(userData, wrap(ctx)));
return success();
}
@@ -149,7 +149,7 @@ class ExternalPass : public Pass {
}
void runOnOperation() override {
- callbacks.run(wrap(getOperation()), wrap(this), userData);
+ // userData is now the first argument, matching the reordered
+ // MlirExternalPassCallbacks::run signature in mlir-c/Pass.h.
+ callbacks.run(userData, wrap(getOperation()), wrap(this));
}
std::unique_ptr<Pass> clonePass() const override {
More information about the Mlir-commits
mailing list