[Mlir-commits] [mlir] [MLIR][Python] Support Python-defined passes in MLIR (PR #156000)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Aug 29 09:32:01 PDT 2025
https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/156000
>From 8386c87c431585c4412a52d07287b7423e34f602 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 29 Aug 2025 17:43:04 +0800
Subject: [PATCH 1/6] [MLIR][Python] Support Python-defined passes in MLIR
---
mlir/lib/Bindings/Python/MainModule.cpp | 1 +
mlir/lib/Bindings/Python/Pass.cpp | 80 ++++++++++++++++++++++++-
mlir/lib/Bindings/Python/Pass.h | 1 +
3 files changed, 81 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 278847e7ac7f5..590e862a8d358 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -139,4 +139,5 @@ NB_MODULE(_mlir, m) {
auto passModule =
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
populatePassManagerSubmodule(passModule);
+ populatePassSubmodule(passModule);
}
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 1030dea7f364c..4aa93df938295 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -9,9 +9,11 @@
#include "Pass.h"
#include "IRModule.h"
+#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
#include "mlir-c/Pass.h"
#include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+#include "nanobind/trampoline.h"
+#include "llvm/Support/ErrorHandling.h"
namespace nb = nanobind;
using namespace nb::literals;
@@ -20,6 +22,63 @@ using namespace mlir::python;
namespace {
+// A base class for defining passes in Python
+// Users are expected to subclass this and implement the `run` method, e.g.
+// ```
+// class MyPass(mlir.passmanager.Pass):
+// def run(self, operation):
+// # do something with operation
+// pass
+// ```
+class PyPassBase {
+public:
+ PyPassBase() : callbacks{} {
+ callbacks.construct = [](void *) {};
+ callbacks.destruct = [](void *) {};
+ callbacks.run = [](MlirOperation op, MlirExternalPass, void *obj) {
+ static_cast<PyPassBase *>(obj)->run(op);
+ };
+ // TODO: currently we don't support pass cloning in python
+ // due to lifetime management issues.
+ callbacks.clone = [](void *obj) -> void * {
+ // since the caller here should be MLIR C++ code,
+ // we need to avoid using exceptions like throw py::value_error(...).
+ llvm_unreachable("cloning of python-defined passes is not supported");
+ };
+ }
+
+ // this method should be overridden by subclasses in Python.
+ // It is reached through the `callbacks.run` hook installed in the
+ // constructor, which casts the opaque user-data pointer back to PyPassBase.
+ virtual void run(MlirOperation op) = 0;
+
+ // Virtual so that destroying through a PyPassBase pointer also runs the
+ // destructor of derived (e.g. trampoline) classes.
+ virtual ~PyPassBase() = default;
+
+ // NOTE(review): the name, argument and description passed below are fixed
+ // placeholder strings; only the TypeID (derived from `this`) differs per pass.
+ // Make an MlirPass instance on-the-fly that wraps this object.
+ // Note that passmanager will take the ownership of the returned
+ // object and release it when appropriate.
+ // Also, `*this` must remain alive as long as the returned object is alive.
+ MlirPass make() {
+ return mlirCreateExternalPass(
+ mlirTypeIDCreate(this),
+ mlirStringRefCreateFromCString("python-example-pass"),
+ mlirStringRefCreateFromCString(""),
+ mlirStringRefCreateFromCString("Python Example Pass"),
+ mlirStringRefCreateFromCString(""), 0, nullptr, callbacks, this);
+ }
+
+private:
+ MlirExternalPassCallbacks callbacks;
+};
+
+// A trampoline class upon PyPassBase.
+// It must derive publicly so nanobind can upcast it to PyPassBase. Refer to
+// https://nanobind.readthedocs.io/en/latest/classes.html#overriding-virtual-functions-in-python
+class PyPass : public PyPassBase {
+public:
+ NB_TRAMPOLINE(PyPassBase, 1);
+
+ void run(MlirOperation op) override { NB_OVERRIDE_PURE(run, op); }
+};
+
/// Owning Wrapper around a PassManager.
class PyPassManager {
public:
@@ -52,6 +111,16 @@ class PyPassManager {
} // namespace
+/// Create the `Pass` base class that Python code subclasses to define passes.
+void mlir::python::populatePassSubmodule(nb::module_ &m) {
+ //----------------------------------------------------------------------------
+ // Mapping of the Python-defined Pass interface
+ //----------------------------------------------------------------------------
+ nb::class_<PyPassBase, PyPass>(m, "Pass")
+ .def(nb::init<>(), "Create a new Pass.")
+ .def("run", &PyPassBase::run, "operation"_a,
+ "Run the pass on the provided operation.");
+}
+
/// Create the `mlir.passmanager` here.
void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
@@ -157,6 +226,15 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
"pipeline"_a,
"Add textual pipeline elements to the pass manager. Throws a "
"ValueError if the pipeline can't be parsed.")
+ .def(
+ "add",
+ [](PyPassManager &passManager, PyPassBase &pass) {
+ mlirPassManagerAddOwnedPass(passManager.get(), pass.make());
+ },
+ "pass"_a, "Add a python-defined pass to the pass manager.",
+ // NOTE that we should keep the pass object alive as long as the
+ // passManager to prevent dangling objects.
+ nb::keep_alive<1, 2>())
.def(
"run",
[](PyPassManager &passManager, PyOperationBase &op,
diff --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h
index bc40943521829..ba3fbb707fed7 100644
--- a/mlir/lib/Bindings/Python/Pass.h
+++ b/mlir/lib/Bindings/Python/Pass.h
@@ -15,6 +15,7 @@ namespace mlir {
namespace python {
void populatePassManagerSubmodule(nanobind::module_ &m);
+void populatePassSubmodule(nanobind::module_ &m);
} // namespace python
} // namespace mlir
>From cb826212272d3276350019d40703113a2f30e983 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 29 Aug 2025 18:39:17 +0800
Subject: [PATCH 2/6] add ctor with args for Pass
---
mlir/lib/Bindings/Python/Pass.cpp | 40 ++++++++++++++++++++++++-------
1 file changed, 32 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 4aa93df938295..898d3c096c1d8 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -11,6 +11,7 @@
#include "IRModule.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
#include "mlir-c/Pass.h"
+#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "nanobind/trampoline.h"
#include "llvm/Support/ErrorHandling.h"
@@ -32,7 +33,10 @@ namespace {
// ```
class PyPassBase {
public:
- PyPassBase() : callbacks{} {
+ PyPassBase(std::string name, std::string argument, std::string description,
+ std::string opName)
+ : callbacks{}, name(std::move(name)), argument(std::move(argument)),
+ description(std::move(description)), opName(std::move(opName)) {
callbacks.construct = [](void *) {};
callbacks.destruct = [](void *) {};
callbacks.run = [](MlirOperation op, MlirExternalPass, void *obj) {
@@ -58,15 +62,25 @@ class PyPassBase {
// Also, `*this` must remain alive as long as the returned object is alive.
MlirPass make() {
return mlirCreateExternalPass(
- mlirTypeIDCreate(this),
- mlirStringRefCreateFromCString("python-example-pass"),
- mlirStringRefCreateFromCString(""),
- mlirStringRefCreateFromCString("Python Example Pass"),
- mlirStringRefCreateFromCString(""), 0, nullptr, callbacks, this);
+ mlirTypeIDCreate(this), mlirStringRefCreate(name.data(), name.length()),
+ mlirStringRefCreate(argument.data(), argument.length()),
+ mlirStringRefCreate(description.data(), description.length()),
+ mlirStringRefCreate(opName.data(), opName.size()), 0, nullptr,
+ callbacks, this);
}
+ const std::string &getName() const { return name; }
+ const std::string &getArgument() const { return argument; }
+ const std::string &getDescription() const { return description; }
+ const std::string &getOpName() const { return opName; }
+
private:
MlirExternalPassCallbacks callbacks;
+
+ std::string name;
+ std::string argument;
+ std::string description;
+ std::string opName;
};
// A trampoline class upon PyPassBase.
@@ -116,9 +130,19 @@ void mlir::python::populatePassSubmodule(nanobind::module_ &m) {
// Mapping of the Python-defined Pass interface
//----------------------------------------------------------------------------
nb::class_<PyPassBase, PyPass>(m, "Pass")
- .def(nb::init<>(), "Create a new Pass.")
+ .def(nb::init<std::string, std::string, std::string, std::string>(),
+ "name"_a, nb::kw_only(), "argument"_a = "", "description"_a = "",
+ "op_name"_a = "", "Create a new Pass.")
.def("run", &PyPassBase::run, "operation"_a,
- "Run the pass on the provided operation.");
+ "Run the pass on the provided operation.")
+ .def_prop_ro("name",
+ [](const PyPassBase &self) { return self.getName(); })
+ .def_prop_ro("argument",
+ [](const PyPassBase &self) { return self.getArgument(); })
+ .def_prop_ro("description",
+ [](const PyPassBase &self) { return self.getDescription(); })
+ .def_prop_ro("op_name",
+ [](const PyPassBase &self) { return self.getOpName(); });
}
/// Create the `mlir.passmanager` here.
>From 7556ca2ad0deb938765034cda9c49edd3c7c5975 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 29 Aug 2025 22:26:40 +0800
Subject: [PATCH 3/6] fix lifetime issue
---
mlir/lib/Bindings/Python/Pass.cpp | 28 ++++++++++++++--------------
1 file changed, 14 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 898d3c096c1d8..babb8a723ca04 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -35,20 +35,22 @@ class PyPassBase {
public:
PyPassBase(std::string name, std::string argument, std::string description,
std::string opName)
- : callbacks{}, name(std::move(name)), argument(std::move(argument)),
+ : name(std::move(name)), argument(std::move(argument)),
description(std::move(description)), opName(std::move(opName)) {
- callbacks.construct = [](void *) {};
- callbacks.destruct = [](void *) {};
+ callbacks.construct = [](void *obj) {};
+ callbacks.destruct = [](void *obj) {
+ nb::handle(static_cast<PyObject *>(obj)).dec_ref();
+ };
callbacks.run = [](MlirOperation op, MlirExternalPass, void *obj) {
- static_cast<PyPassBase *>(obj)->run(op);
+ auto handle = nb::handle(static_cast<PyObject *>(obj));
+ nb::cast<PyPassBase *>(handle)->run(op);
};
- // TODO: currently we don't support pass cloning in python
- // due to lifetime management issues.
callbacks.clone = [](void *obj) -> void * {
- // since the caller here should be MLIR C++ code,
- // we need to avoid using exceptions like throw py::value_error(...).
- llvm_unreachable("cloning of python-defined passes is not supported");
+ // `obj` is the PyObject* stored by make(); wrap it before deep-copying.
+ nb::object copy = nb::module_::import_("copy");
+ nb::object deepcopy = copy.attr("deepcopy");
+ return deepcopy(nb::handle(static_cast<PyObject *>(obj))).release().ptr();
};
+ callbacks.initialize = nullptr;
}
// this method should be overridden by subclasses in Python.
@@ -61,12 +63,13 @@ class PyPassBase {
// object and release it when appropriate.
// Also, `*this` must remain alive as long as the returned object is alive.
MlirPass make() {
+ auto *obj = nb::find(this).release().ptr();
return mlirCreateExternalPass(
mlirTypeIDCreate(this), mlirStringRefCreate(name.data(), name.length()),
mlirStringRefCreate(argument.data(), argument.length()),
mlirStringRefCreate(description.data(), description.length()),
mlirStringRefCreate(opName.data(), opName.size()), 0, nullptr,
- callbacks, this);
+ callbacks, obj);
}
const std::string &getName() const { return name; }
@@ -255,10 +258,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
[](PyPassManager &passManager, PyPassBase &pass) {
mlirPassManagerAddOwnedPass(passManager.get(), pass.make());
},
- "pass"_a, "Add a python-defined pass to the pass manager.",
- // NOTE that we should keep the pass object alive as long as the
- // passManager to prevent dangling objects.
- nb::keep_alive<1, 2>())
+ "pass"_a, "Add a python-defined pass to the pass manager.")
.def(
"run",
[](PyPassManager &passManager, PyOperationBase &op,
>From d5055a79a683a7cd71533db7bed537d6028f1950 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 30 Aug 2025 00:20:59 +0800
Subject: [PATCH 4/6] add test case
---
mlir/include/mlir-c/Rewrite.h | 4 ++
mlir/lib/Bindings/Python/Pass.cpp | 4 +-
mlir/lib/Bindings/Python/Rewrite.cpp | 31 ++++++++----
mlir/lib/CAPI/Transforms/Rewrite.cpp | 7 +++
mlir/test/python/pass.py | 73 ++++++++++++++++++++++++++++
5 files changed, 106 insertions(+), 13 deletions(-)
create mode 100644 mlir/test/python/pass.py
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 61d3446317550..21ae236d6f73f 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -301,6 +301,10 @@ mlirFreezeRewritePattern(MlirRewritePatternSet op);
MLIR_CAPI_EXPORTED void
mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
+/// Like mlirApplyPatternsAndFoldGreedily, but applies `patterns` to an
+/// arbitrary operation instead of a module. NOTE(review): the config argument
+/// is currently unused by the implementation; confirm before relying on it.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyForOp(
+ MlirOperation op, MlirFrozenRewritePatternSet patterns,
+ MlirGreedyRewriteDriverConfig);
+
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
MlirModule op, MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig);
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index babb8a723ca04..920d604d24680 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -9,12 +9,10 @@
#include "Pass.h"
#include "IRModule.h"
-#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
#include "mlir-c/Pass.h"
-#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "nanobind/trampoline.h"
-#include "llvm/Support/ErrorHandling.h"
+#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
namespace nb = nanobind;
using namespace nb::literals;
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 0373f9c7affe9..675bd685ec2db 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -99,14 +99,25 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
&PyFrozenRewritePatternSet::createFromCapsule);
m.def(
- "apply_patterns_and_fold_greedily",
- [](MlirModule module, MlirFrozenRewritePatternSet set) {
- auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
- if (mlirLogicalResultIsFailure(status))
- // FIXME: Not sure this is the right error to throw here.
- throw nb::value_error("pattern application failed to converge");
- },
- "module"_a, "set"_a,
- "Applys the given patterns to the given module greedily while folding "
- "results.");
+ "apply_patterns_and_fold_greedily",
+ [](MlirModule module, MlirFrozenRewritePatternSet set) {
+ auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
+ if (mlirLogicalResultIsFailure(status))
+ // FIXME: Not sure this is the right error to throw here.
+ throw nb::value_error("pattern application failed to converge");
+ },
+ "module"_a, "set"_a,
+ "Applies the given patterns to the given module greedily while folding "
+ "results.")
+ .def(
+ "apply_patterns_and_fold_greedily_for_op",
+ [](MlirOperation op, MlirFrozenRewritePatternSet set) {
+ auto status = mlirApplyPatternsAndFoldGreedilyForOp(op, set, {});
+ if (mlirLogicalResultIsFailure(status))
+ // FIXME: Not sure this is the right error to throw here.
+ throw nb::value_error("pattern application failed to converge");
+ },
+ "op"_a, "set"_a,
+ "Applies the given patterns to the given op greedily while folding "
+ "results.");
}
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index a4df97f7beace..d606445cfad31 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -294,6 +294,13 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op,
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
}
+// Operation-level counterpart of mlirApplyPatternsAndFoldGreedily above.
+MlirLogicalResult mlirApplyPatternsAndFoldGreedilyForOp(
+ MlirOperation op, MlirFrozenRewritePatternSet patterns,
+ MlirGreedyRewriteDriverConfig /*config: currently ignored*/) {
+ mlir::Operation *target = unwrap(op);
+ return wrap(mlir::applyPatternsGreedily(target, *unwrap(patterns)));
+}
+
//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/pass.py b/mlir/test/python/pass.py
new file mode 100644
index 0000000000000..32943caec19a4
--- /dev/null
+++ b/mlir/test/python/pass.py
@@ -0,0 +1,73 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import gc, sys
+from mlir.ir import *
+from mlir.passmanager import *
+from mlir.dialects.builtin import ModuleOp
+from mlir.dialects import pdl
+from mlir.rewrite import *
+
+def log(*args):
+ print(*args, file=sys.stderr)
+ sys.stderr.flush()
+
+
+def run(f):
+ log("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+
+def make_pdl_module():
+ with Location.unknown():
+ pdl_module = Module.create()
+ with InsertionPoint(pdl_module.body):
+ # Change all arith.addi with index types to arith.muli.
+ @pdl.pattern(benefit=1, sym_name="addi_to_mul")
+ def pat():
+ # Match arith.addi with index types.
+ index_type = pdl.TypeOp(IndexType.get())
+ operand0 = pdl.OperandOp(index_type)
+ operand1 = pdl.OperandOp(index_type)
+ op0 = pdl.OperationOp(
+ name="arith.addi", args=[operand0, operand1], types=[index_type]
+ )
+
+ # Replace the matched op with arith.muli.
+ @pdl.rewrite()
+ def rew():
+ newOp = pdl.OperationOp(
+ name="arith.muli", args=[operand0, operand1], types=[index_type]
+ )
+ pdl.ReplaceOp(op0, with_op=newOp)
+
+ return pdl_module
+
+# CHECK-LABEL: TEST: testCustomPass
+@run
+def testCustomPass():
+ with Context():
+ pdl_module = make_pdl_module()
+
+ # A Python-defined pass that applies the PDL addi->muli pattern to the op
+ # it is scheduled on.
+ class CustomPass(Pass):
+ def __init__(self):
+ super().__init__("CustomPass", op_name="builtin.module")
+ def run(self, m):
+ frozen = PDLModule(pdl_module).freeze()
+ apply_patterns_and_fold_greedily_for_op(m, frozen)
+
+ module = ModuleOp.parse(r"""
+ module {
+ func.func @add(%a: index, %b: index) -> index {
+ %sum = arith.addi %a, %b : index
+ return %sum : index
+ }
+ }
+ """)
+
+ # IR printing dumps the module after each pass, so the output below shows
+ # whether the custom pass rewrote the index-typed addi into a muli.
+ # CHECK-LABEL: Dump After CustomPass
+ # CHECK: arith.muli
+ pm = PassManager('any')
+ pm.enable_ir_printing()
+ pm.add(CustomPass())
+ pm.run(module)
>From 6d2f4720b5c724a5a65d257e05b60ea7adae9569 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 30 Aug 2025 00:31:05 +0800
Subject: [PATCH 5/6] fix header
---
mlir/lib/Bindings/Python/Pass.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 920d604d24680..cda1f7af243d3 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -11,8 +11,8 @@
#include "IRModule.h"
#include "mlir-c/Pass.h"
#include "mlir/Bindings/Python/Nanobind.h"
-#include "nanobind/trampoline.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+#include "nanobind/trampoline.h"
namespace nb = nanobind;
using namespace nb::literals;
>From 1a98ae84df54f1d441a063dd267ee70b574d3b22 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 30 Aug 2025 00:31:46 +0800
Subject: [PATCH 6/6] format
---
mlir/test/python/pass.py | 3 +++
1 file changed, 3 insertions(+)
diff --git a/mlir/test/python/pass.py b/mlir/test/python/pass.py
index 32943caec19a4..4389e3bf38686 100644
--- a/mlir/test/python/pass.py
+++ b/mlir/test/python/pass.py
@@ -7,6 +7,7 @@
from mlir.dialects import pdl
from mlir.rewrite import *
+
def log(*args):
print(*args, file=sys.stderr)
sys.stderr.flush()
@@ -18,6 +19,7 @@ def run(f):
gc.collect()
assert Context._get_live_count() == 0
+
def make_pdl_module():
with Location.unknown():
pdl_module = Module.create()
@@ -43,6 +45,7 @@ def rew():
return pdl_module
+
# CHECK-LABEL: TEST: testCustomPass
@run
def testCustomPass():
More information about the Mlir-commits
mailing list