[Mlir-commits] [mlir] 7d04e37 - [MLIR][Python] Support Python-defined passes in MLIR (#156000)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 8 18:01:27 PDT 2025
Author: Twice
Date: 2025-09-08T18:01:23-07:00
New Revision: 7d04e3790483f8e2f168ec538682bc31d3011a0b
URL: https://github.com/llvm/llvm-project/commit/7d04e3790483f8e2f168ec538682bc31d3011a0b
DIFF: https://github.com/llvm/llvm-project/commit/7d04e3790483f8e2f168ec538682bc31d3011a0b.diff
LOG: [MLIR][Python] Support Python-defined passes in MLIR (#156000)
Closes #155996.
This PR adds a method `add(callable, ...)` to
`mlir.passmanager.PassManager` that accepts a callable object, so that
passes can be defined on the Python side.
Here is a simple example of Python-defined passes:
```python
from mlir.passmanager import PassManager


# A plain function can be used as a pass; it receives the operation to run on.
def demo_pass_1(op):
    # do something with op
    pass


# Any callable object works too; __call__ needs `self` like any method.
class DemoPass:
    def __init__(self, *args):
        # configure the pass
        pass

    def __call__(self, op):
        # do something with op
        pass


demo_pass_2 = DemoPass()
# Assumes `ctx` (an mlir.ir.Context) and `module` (the op to run on) exist.
pm = PassManager('any', ctx)
pm.add(demo_pass_1)
# An explicit pass name may be given as the second argument.
pm.add(demo_pass_2, "DemoPass")
pm.add("registered-passes")
pm.run(module)
```
---------
Co-authored-by: cnb.bsD2OPwAgEA <QejD2DJ2eEahUVy6Zg0aZI+cnb.bsD2OPwAgEA at noreply.cnb.cool>
Co-authored-by: Maksim Levental <maksim.levental at gmail.com>
Added:
mlir/test/python/python_pass.py
Modified:
mlir/lib/Bindings/Python/MainModule.cpp
mlir/lib/Bindings/Python/Pass.cpp
mlir/lib/CAPI/IR/Pass.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 278847e7ac7f5..d7282b3d6f713 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -136,7 +136,7 @@ NB_MODULE(_mlir, m) {
populateRewriteSubmodule(rewriteModule);
// Define and populate PassManager submodule.
- auto passModule =
+ auto passManagerModule =
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
- populatePassManagerSubmodule(passModule);
+ populatePassManagerSubmodule(passManagerModule);
}
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 72bf8ed8f856f..6ee85e8a31492 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -10,8 +10,10 @@
#include "IRModule.h"
#include "mlir-c/Pass.h"
+// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+// clang-format on
namespace nb = nanobind;
using namespace nb::literals;
@@ -157,6 +159,64 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
"pipeline"_a,
"Add textual pipeline elements to the pass manager. Throws a "
"ValueError if the pipeline can't be parsed.")
+ .def(
+ "add",
+ [](PyPassManager &passManager, const nb::callable &run,
+ std::optional<std::string> &name, const std::string &argument,
+ const std::string &description, const std::string &opName) {
+ // Default the pass name to the callable's __name__. Instances of
+ // callable classes have no __name__, so use their class's name.
+ if (!name.has_value()) {
+ nb::object named = nb::hasattr(run, "__name__")
+ ? nb::borrow<nb::object>(run)
+ : nb::borrow<nb::object>(run.attr("__class__"));
+ name = nb::cast<std::string>(
+ nb::borrow<nb::str>(named.attr("__name__")));
+ }
+ // Each pass gets a fresh TypeID. The allocator is deliberately
+ // never destroyed: destroying it also frees the TypeIDs it
+ // handed out, and the pass keeps using passID.
+ MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
+ MlirTypeID passID =
+ mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
+ // Value-initialize so no callback slot is left indeterminate.
+ MlirExternalPassCallbacks callbacks{};
+ // construct/destruct are invoked from the ExternalPass
+ // constructor/destructor, so the pass owns a strong reference to
+ // `run` for as long as it is alive.
+ callbacks.construct = [](void *obj) {
+ (void)nb::handle(static_cast<PyObject *>(obj)).inc_ref();
+ };
+ callbacks.destruct = [](void *obj) {
+ (void)nb::handle(static_cast<PyObject *>(obj)).dec_ref();
+ };
+ callbacks.initialize = nullptr;
+ // NOTE(review): this exception, like any Python error raised by
+ // `run` below, unwinds through the C API into the pass manager;
+ // confirm that is safe for the MLIR build configuration in use.
+ callbacks.clone = [](void *) -> void * {
+ throw std::runtime_error("Cloning Python passes not supported");
+ };
+ // Invoke the Python callable on the operation the pass runs on.
+ callbacks.run = [](MlirOperation op, MlirExternalPass,
+ void *userData) {
+ nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op);
+ };
+ // NOTE(review): confirm the pass copies `opName`; this string is
+ // owned by the binding call and is destroyed when it returns.
+ auto externalPass = mlirCreateExternalPass(
+ passID, mlirStringRefCreate(name->data(), name->length()),
+ mlirStringRefCreate(argument.data(), argument.length()),
+ mlirStringRefCreate(description.data(), description.length()),
+ mlirStringRefCreate(opName.data(), opName.size()),
+ /*nDependentDialects*/ 0, /*dependentDialects*/ nullptr,
+ callbacks, /*userData*/ run.ptr());
+ // The pass manager takes ownership of the pass.
+ mlirPassManagerAddOwnedPass(passManager.get(), externalPass);
+ },
+ "run"_a, "name"_a.none() = nb::none(), "argument"_a.none() = "",
+ "description"_a.none() = "", "op_name"_a.none() = "",
+ "Add a python-defined pass to the pass manager.")
.def(
"run",
[](PyPassManager &passManager, PyOperationBase &op) {
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index 3c499c3e4974d..b0a6ec1ace3cc 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -145,10 +145,14 @@ class ExternalPass : public Pass {
: Pass(passID, opName), id(passID), name(name), argument(argument),
description(description), dependentDialects(dependentDialects),
callbacks(callbacks), userData(userData) {
- callbacks.construct(userData);
+ if (callbacks.construct)
+ callbacks.construct(userData);
}
- ~ExternalPass() override { callbacks.destruct(userData); }
+ ~ExternalPass() override {
+ if (callbacks.destruct)
+ callbacks.destruct(userData);
+ }
StringRef getName() const override { return name; }
StringRef getArgument() const override { return argument; }
diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py
new file mode 100644
index 0000000000000..c94f96e20966f
--- /dev/null
+++ b/mlir/test/python/python_pass.py
@@ -0,0 +1,88 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import gc, sys
+from mlir.ir import *
+from mlir.passmanager import *
+from mlir.dialects.builtin import ModuleOp
+from mlir.dialects import pdl
+from mlir.rewrite import *
+
+
+def log(*args):
+    print(*args, file=sys.stderr)
+    sys.stderr.flush()  # flush eagerly so FileCheck sees lines in order
+
+
+def run(f):
+    log("\nTEST:", f.__name__)
+    f()
+    gc.collect()  # drop unreachable objects before the leak check
+    assert Context._get_live_count() == 0  # no Context may outlive the test
+
+
+def make_pdl_module():
+    with Location.unknown():
+        pdl_module = Module.create()
+        with InsertionPoint(pdl_module.body):
+            # Rewrite every arith.addi on i64 values into arith.muli.
+            @pdl.pattern(benefit=1, sym_name="addi_to_mul")
+            def pat():
+                # Match arith.addi with i64 operands and an i64 result.
+                i64_type = pdl.TypeOp(IntegerType.get_signless(64))
+                operand0 = pdl.OperandOp(i64_type)
+                operand1 = pdl.OperandOp(i64_type)
+                op0 = pdl.OperationOp(
+                    name="arith.addi", args=[operand0, operand1], types=[i64_type]
+                )
+
+                # Replace the matched op with arith.muli.
+                @pdl.rewrite()
+                def rew():
+                    newOp = pdl.OperationOp(
+                        name="arith.muli", args=[operand0, operand1], types=[i64_type]
+                    )
+                    pdl.ReplaceOp(op0, with_op=newOp)
+
+        return pdl_module
+
+
+# CHECK-LABEL: TEST: testCustomPass
+@run
+def testCustomPass():
+    with Context():
+        pdl_module = make_pdl_module()
+        frozen = PDLModule(pdl_module).freeze()
+
+        module = ModuleOp.parse(
+            r"""
+            module {
+              func.func @add(%a: i64, %b: i64) -> i64 {
+                %sum = arith.addi %a, %b : i64
+                return %sum : i64
+              }
+            }
+            """
+        )
+
+        def custom_pass_1(op):
+            print("hello from pass 1!!!", file=sys.stderr)
+
+        class CustomPass2:
+            def __call__(self, m):
+                apply_patterns_and_fold_greedily(m, frozen)
+
+        custom_pass_2 = CustomPass2()
+
+        pm = PassManager("any")
+        pm.enable_ir_printing()
+
+        # CHECK: hello from pass 1!!!
+        # CHECK-LABEL: Dump After custom_pass_1
+        pm.add(custom_pass_1)
+        # CHECK-LABEL: Dump After CustomPass2
+        # CHECK: arith.muli
+        pm.add(custom_pass_2, "CustomPass2")
+        # CHECK-LABEL: Dump After ArithToLLVMConversionPass
+        # CHECK: llvm.mul
+        pm.add("convert-arith-to-llvm")
+        pm.run(module)
More information about the Mlir-commits
mailing list