[Mlir-commits] [mlir] 7d04e37 - [MLIR][Python] Support Python-defined passes in MLIR (#156000)
    llvmlistbot at llvm.org 
    llvmlistbot at llvm.org
       
    Mon Sep  8 18:01:27 PDT 2025
    
    
  
Author: Twice
Date: 2025-09-08T18:01:23-07:00
New Revision: 7d04e3790483f8e2f168ec538682bc31d3011a0b
URL: https://github.com/llvm/llvm-project/commit/7d04e3790483f8e2f168ec538682bc31d3011a0b
DIFF: https://github.com/llvm/llvm-project/commit/7d04e3790483f8e2f168ec538682bc31d3011a0b.diff
LOG: [MLIR][Python] Support Python-defined passes in MLIR (#156000)
Closes #155996.
This PR adds an `add(callable, ...)` overload to
`mlir.passmanager.PassManager` that accepts a callable object, so that
passes can be defined on the Python side.
This is a simple example of Python-defined passes:
```python
from mlir.passmanager import PassManager
def demo_pass_1(op):
    # do something with op
    pass
class DemoPass:
    def __init__(self):
        # set up any pass state
        pass
    def __call__(self, op):
        # do something with op
        pass
demo_pass_2 = DemoPass()
pm = PassManager('any', ctx)
pm.add(demo_pass_1)  # the pass name defaults to demo_pass_1.__name__
pm.add(demo_pass_2, "DemoPass")  # instances have no __name__, so name the pass explicitly
pm.add("registered-passes")
pm.run(module)
```
---------
Co-authored-by: cnb.bsD2OPwAgEA <QejD2DJ2eEahUVy6Zg0aZI+cnb.bsD2OPwAgEA at noreply.cnb.cool>
Co-authored-by: Maksim Levental <maksim.levental at gmail.com>
Added: 
    mlir/test/python/python_pass.py
Modified: 
    mlir/lib/Bindings/Python/MainModule.cpp
    mlir/lib/Bindings/Python/Pass.cpp
    mlir/lib/CAPI/IR/Pass.cpp
Removed: 
    
################################################################################
diff  --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 278847e7ac7f5..d7282b3d6f713 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -136,7 +136,7 @@ NB_MODULE(_mlir, m) {
   populateRewriteSubmodule(rewriteModule);
 
   // Define and populate PassManager submodule.
-  auto passModule =
+  auto passManagerModule =
       m.def_submodule("passmanager", "MLIR Pass Management Bindings");
-  populatePassManagerSubmodule(passModule);
+  populatePassManagerSubmodule(passManagerModule);
 }
diff  --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 72bf8ed8f856f..6ee85e8a31492 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -10,8 +10,10 @@
 
 #include "IRModule.h"
 #include "mlir-c/Pass.h"
+// clang-format off
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+// clang-format on
 
 namespace nb = nanobind;
 using namespace nb::literals;
@@ -157,6 +159,45 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
           "pipeline"_a,
           "Add textual pipeline elements to the pass manager. Throws a "
           "ValueError if the pipeline can't be parsed.")
+      .def(
+          "add",
+          [](PyPassManager &passManager, const nb::callable &run,
+             std::optional<std::string> &name, const std::string &argument,
+             const std::string &description, const std::string &opName) {
+            if (!name.has_value()) {
+              name = nb::cast<std::string>(
+                  nb::borrow<nb::str>(run.attr("__name__")));
+            }
+            MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
+            MlirTypeID passID =
+                mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
+            MlirExternalPassCallbacks callbacks;
+            callbacks.construct = [](void *obj) {
+              (void)nb::handle(static_cast<PyObject *>(obj)).inc_ref();
+            };
+            callbacks.destruct = [](void *obj) {
+              (void)nb::handle(static_cast<PyObject *>(obj)).dec_ref();
+            };
+            callbacks.initialize = nullptr;
+            callbacks.clone = [](void *) -> void * {
+              throw std::runtime_error("Cloning Python passes not supported");
+            };
+            callbacks.run = [](MlirOperation op, MlirExternalPass,
+                               void *userData) {
+              nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op);
+            };
+            auto externalPass = mlirCreateExternalPass(
+                passID, mlirStringRefCreate(name->data(), name->length()),
+                mlirStringRefCreate(argument.data(), argument.length()),
+                mlirStringRefCreate(description.data(), description.length()),
+                mlirStringRefCreate(opName.data(), opName.size()),
+                /*nDependentDialects*/ 0, /*dependentDialects*/ nullptr,
+                callbacks, /*userData*/ run.ptr());
+            mlirPassManagerAddOwnedPass(passManager.get(), externalPass);
+          },
+          "run"_a, "name"_a.none() = nb::none(), "argument"_a.none() = "",
+          "description"_a.none() = "", "op_name"_a.none() = "",
+          "Add a python-defined pass to the pass manager.")
       .def(
           "run",
           [](PyPassManager &passManager, PyOperationBase &op) {
diff  --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index 3c499c3e4974d..b0a6ec1ace3cc 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -145,10 +145,14 @@ class ExternalPass : public Pass {
       : Pass(passID, opName), id(passID), name(name), argument(argument),
         description(description), dependentDialects(dependentDialects),
         callbacks(callbacks), userData(userData) {
-    callbacks.construct(userData);
+    if (callbacks.construct)
+      callbacks.construct(userData);
   }
 
-  ~ExternalPass() override { callbacks.destruct(userData); }
+  ~ExternalPass() override {
+    if (callbacks.destruct)
+      callbacks.destruct(userData);
+  }
 
   StringRef getName() const override { return name; }
   StringRef getArgument() const override { return argument; }
diff  --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py
new file mode 100644
index 0000000000000..c94f96e20966f
--- /dev/null
+++ b/mlir/test/python/python_pass.py
@@ -0,0 +1,93 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import gc, sys
+from mlir.ir import *
+from mlir.passmanager import *
+from mlir.dialects.builtin import ModuleOp
+from mlir.dialects import pdl
+from mlir.rewrite import *
+
+
+def log(*args):
+    """Print to stderr and flush immediately so output ordering is preserved."""
+    print(*args, file=sys.stderr)
+    sys.stderr.flush()
+
+
+def run(f):
+    """Run test `f` under a 'TEST:' label, then assert no Context is still alive."""
+    log("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
+
+
+def make_pdl_module():
+    """Build a PDL module with one pattern rewriting i64 arith.addi to arith.muli."""
+    with Location.unknown():
+        pdl_module = Module.create()
+        with InsertionPoint(pdl_module.body):
+            # Change all arith.addi on i64 values to arith.muli.
+            @pdl.pattern(benefit=1, sym_name="addi_to_mul")
+            def pat():
+                # Match arith.addi with i64 operands and an i64 result.
+                i64_type = pdl.TypeOp(IntegerType.get_signless(64))
+                operand0 = pdl.OperandOp(i64_type)
+                operand1 = pdl.OperandOp(i64_type)
+                op0 = pdl.OperationOp(
+                    name="arith.addi", args=[operand0, operand1], types=[i64_type]
+                )
+
+                # Replace the matched op with arith.muli.
+                @pdl.rewrite()
+                def rew():
+                    newOp = pdl.OperationOp(
+                        name="arith.muli", args=[operand0, operand1], types=[i64_type]
+                    )
+                    pdl.ReplaceOp(op0, with_op=newOp)
+
+        return pdl_module
+
+
+# CHECK-LABEL: TEST: testCustomPass
+@run
+def testCustomPass():
+    """Run two Python-defined passes and a registered pass, checking each IR dump."""
+    with Context():
+        pdl_module = make_pdl_module()
+        frozen = PDLModule(pdl_module).freeze()
+
+        module = ModuleOp.parse(
+            r"""
+            module {
+              func.func @add(%a: i64, %b: i64) -> i64 {
+                %sum = arith.addi %a, %b : i64
+                return %sum : i64
+              }
+            }
+        """
+        )
+
+        def custom_pass_1(op):
+            print("hello from pass 1!!!", file=sys.stderr)
+
+        class CustomPass2:
+            def __call__(self, m):
+                apply_patterns_and_fold_greedily(m, frozen)
+
+        custom_pass_2 = CustomPass2()
+
+        pm = PassManager("any")
+        pm.enable_ir_printing()
+
+        # CHECK: hello from pass 1!!!
+        # CHECK-LABEL: Dump After custom_pass_1
+        pm.add(custom_pass_1)
+        # CHECK-LABEL: Dump After CustomPass2
+        # CHECK: arith.muli
+        # Instances have no __name__, so this pass is named explicitly.
+        pm.add(custom_pass_2, "CustomPass2")
+        # CHECK-LABEL: Dump After ArithToLLVMConversionPass
+        # CHECK: llvm.mul
+        pm.add("convert-arith-to-llvm")
+        pm.run(module)
        
    
    
More information about the Mlir-commits
mailing list