[Mlir-commits] [mlir] [MLIR][Python] Support Python-defined passes in MLIR (PR #156000)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 29 10:10:55 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

<details>
<summary>Changes</summary>

**STATUS: This PR is currently a work in progress :)**

It aims to close #155996.

This PR exposes a class, `mlir.passmanager.Pass`, that Python code can subclass to define new MLIR passes.

Here is a simple example of a Python-defined pass:
```python
from mlir.passmanager import Pass, PassManager

class DemoPass(Pass):
  def __init__(self):
    # `name` is a required constructor argument of `Pass`.
    super().__init__("DemoPass")

  def run(self, op):
    # do something with op
    pass

pm = PassManager('any', ctx)  # `ctx` is an existing mlir.ir.Context
pm.add(DemoPass())
pm.run(module)  # `module` is an MLIR module/operation created in `ctx`
```

TODO list:
- [x] tests for this change
- [x] interop with PDL rewriting
- [x] support for cloning passes
- [x] use Python-native reference counting to manage pass lifetimes

---
Full diff: https://github.com/llvm/llvm-project/pull/156000.diff


7 Files Affected:

- (modified) mlir/include/mlir-c/Rewrite.h (+4) 
- (modified) mlir/lib/Bindings/Python/MainModule.cpp (+1) 
- (modified) mlir/lib/Bindings/Python/Pass.cpp (+103) 
- (modified) mlir/lib/Bindings/Python/Pass.h (+1) 
- (modified) mlir/lib/Bindings/Python/Rewrite.cpp (+21-10) 
- (modified) mlir/lib/CAPI/Transforms/Rewrite.cpp (+7) 
- (added) mlir/test/python/pass.py (+79) 


``````````diff
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 61d3446317550..21ae236d6f73f 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -301,6 +301,10 @@ mlirFreezeRewritePattern(MlirRewritePatternSet op);
 MLIR_CAPI_EXPORTED void
 mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
 
+/// Operation-level counterpart of mlirApplyPatternsAndFoldGreedily: applies
+/// the frozen `patterns` greedily to `op` while folding, and returns the
+/// driver's logical result (the Python binding reports a failure as
+/// "pattern application failed to converge").
+/// NOTE(review): the MlirGreedyRewriteDriverConfig argument is currently not
+/// used by the implementation.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyForOp(
+    MlirOperation op, MlirFrozenRewritePatternSet patterns,
+    MlirGreedyRewriteDriverConfig);
+
 MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
     MlirModule op, MlirFrozenRewritePatternSet patterns,
     MlirGreedyRewriteDriverConfig);
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 278847e7ac7f5..590e862a8d358 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -139,4 +139,5 @@ NB_MODULE(_mlir, m) {
   auto passModule =
       m.def_submodule("passmanager", "MLIR Pass Management Bindings");
   populatePassManagerSubmodule(passModule);
+  populatePassSubmodule(passModule);
 }
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 1030dea7f364c..3fd27fed34587 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -11,7 +11,9 @@
 #include "IRModule.h"
 #include "mlir-c/Pass.h"
 #include "mlir/Bindings/Python/Nanobind.h"
+
 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+#include "nanobind/trampoline.h"
 
 namespace nb = nanobind;
 using namespace nb::literals;
@@ -20,6 +22,81 @@ using namespace mlir::python;
 
 namespace {
 
+// A base class for defining passes in Python.
+// Users are expected to subclass this and implement the `run` method, e.g.
+// ```
+// class MyPass(mlir.passmanager.Pass):
+//   def __init__(self):
+//     super().__init__("MyPass", ..)
+//     # other init stuff..
+//   def run(self, operation):
+//     # do something with operation..
+//     pass
+// ```
+//
+// Every MlirPass produced by `make()` owns one strong reference to the Python
+// object; the `destruct` callback releases it.
+//
+// The callbacks are invoked by the MLIR pass manager, which may call them from
+// a thread that does not currently hold the GIL (e.g. when passes are run in
+// parallel), so every callback acquires the GIL before touching Python objects.
+class PyPassBase {
+public:
+  PyPassBase(std::string name, std::string argument, std::string description,
+             std::string opName)
+      : name(std::move(name)), argument(std::move(argument)),
+        description(std::move(description)), opName(std::move(opName)) {
+    // No per-pass construction work: the Python object is already fully
+    // initialized when `make()` hands it to the pass manager.
+    callbacks.construct = [](void *) {};
+    // Drop the reference taken in `make()` (or returned by `clone`).
+    callbacks.destruct = [](void *obj) {
+      nb::gil_scoped_acquire gil;
+      nb::handle(static_cast<PyObject *>(obj)).dec_ref();
+    };
+    callbacks.run = [](MlirOperation op, MlirExternalPass pass, void *obj) {
+      nb::gil_scoped_acquire gil;
+      nb::handle self(static_cast<PyObject *>(obj));
+      try {
+        nb::cast<PyPassBase *>(self)->run(op);
+      } catch (nb::python_error &e) {
+        // Don't let a Python exception unwind through the MLIR pass manager:
+        // report it via sys.unraisablehook and mark the pass as failed so the
+        // pipeline run reports failure instead.
+        e.discard_as_unraisable(self);
+        mlirExternalPassSignalFailure(pass);
+      }
+    };
+    // Clone by deep-copying the Python object; the returned new reference is
+    // owned by the cloned pass and released by `destruct`.
+    // NOTE(review): this requires the Python subclass to support
+    // `copy.deepcopy` -- TODO confirm for nanobind-bound instances.
+    callbacks.clone = [](void *obj) -> void * {
+      nb::gil_scoped_acquire gil;
+      nb::object deepcopy = nb::module_::import_("copy").attr("deepcopy");
+      // Wrap `obj` as a handle: passing the raw `void *` would make nanobind
+      // convert the pointer value itself instead of the Python pass object.
+      nb::object copied = deepcopy(nb::handle(static_cast<PyObject *>(obj)));
+      return copied.release().ptr();
+    };
+    callbacks.initialize = nullptr;
+  }
+
+  /// Runs the pass on `op`; must be overridden by Python subclasses.
+  virtual void run(MlirOperation op) = 0;
+
+  virtual ~PyPassBase() = default;
+
+  /// Makes an MlirPass on the fly that wraps this object. The returned pass
+  /// holds a new strong reference to the Python object. The pass manager takes
+  /// ownership of the pass, and the `destruct` callback releases the
+  /// reference. The pass TypeID is derived from this object's address, so
+  /// each Python instance gets its own TypeID.
+  MlirPass make() {
+    auto *obj = nb::find(this).release().ptr();
+    return mlirCreateExternalPass(
+        mlirTypeIDCreate(this), mlirStringRefCreate(name.data(), name.length()),
+        mlirStringRefCreate(argument.data(), argument.length()),
+        mlirStringRefCreate(description.data(), description.length()),
+        mlirStringRefCreate(opName.data(), opName.size()), 0, nullptr,
+        callbacks, obj);
+  }
+
+  const std::string &getName() const { return name; }
+  const std::string &getArgument() const { return argument; }
+  const std::string &getDescription() const { return description; }
+  const std::string &getOpName() const { return opName; }
+
+private:
+  MlirExternalPassCallbacks callbacks;
+
+  std::string name;
+  std::string argument;
+  std::string description;
+  std::string opName;
+};
+
+// A trampoline class upon PyPassBase.
+// Refer to
+// https://nanobind.readthedocs.io/en/latest/classes.html#overriding-virtual-functions-in-python
+// The inheritance must be public so that PyPass* converts to PyPassBase*.
+class PyPass : public PyPassBase {
+public:
+  NB_TRAMPOLINE(PyPassBase, 1);
+
+  void run(MlirOperation op) override { NB_OVERRIDE_PURE(run, op); }
+};
+
 /// Owning Wrapper around a PassManager.
 class PyPassManager {
 public:
@@ -52,6 +129,26 @@ class PyPassManager {
 
 } // namespace
 
+/// Registers the `Pass` base class (backed by PyPassBase, with PyPass as the
+/// nanobind trampoline) in `m` so Python code can subclass it and implement
+/// `run`. The constructor takes `name` positionally; `argument`,
+/// `description` and `op_name` are keyword-only strings defaulting to "".
+void mlir::python::populatePassSubmodule(nanobind::module_ &m) {
+  //----------------------------------------------------------------------------
+  // Mapping of the Python-defined Pass interface
+  //----------------------------------------------------------------------------
+  nb::class_<PyPassBase, PyPass> passClass(m, "Pass");
+
+  // Constructor and the overridable entry point.
+  passClass.def(nb::init<std::string, std::string, std::string, std::string>(),
+                "name"_a, nb::kw_only(), "argument"_a = "",
+                "description"_a = "", "op_name"_a = "", "Create a new Pass.");
+  passClass.def("run", &PyPassBase::run, "operation"_a,
+                "Run the pass on the provided operation.");
+
+  // Read-only views of the metadata given to the constructor.
+  passClass.def_prop_ro("name", &PyPassBase::getName);
+  passClass.def_prop_ro("argument", &PyPassBase::getArgument);
+  passClass.def_prop_ro("description", &PyPassBase::getDescription);
+  passClass.def_prop_ro("op_name", &PyPassBase::getOpName);
+}
+
 /// Create the `mlir.passmanager` here.
 void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
   //----------------------------------------------------------------------------
@@ -157,6 +254,12 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
           "pipeline"_a,
           "Add textual pipeline elements to the pass manager. Throws a "
           "ValueError if the pipeline can't be parsed.")
+      .def(
+          "add",
+          [](PyPassManager &passManager, PyPassBase &pass) {
+            mlirPassManagerAddOwnedPass(passManager.get(), pass.make());
+          },
+          "pass"_a, "Add a python-defined pass to the pass manager.")
       .def(
           "run",
           [](PyPassManager &passManager, PyOperationBase &op,
diff --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h
index bc40943521829..ba3fbb707fed7 100644
--- a/mlir/lib/Bindings/Python/Pass.h
+++ b/mlir/lib/Bindings/Python/Pass.h
@@ -15,6 +15,7 @@ namespace mlir {
 namespace python {
 
 void populatePassManagerSubmodule(nanobind::module_ &m);
+void populatePassSubmodule(nanobind::module_ &m);
 
 } // namespace python
 } // namespace mlir
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 0373f9c7affe9..675bd685ec2db 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -99,14 +99,25 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
            &PyFrozenRewritePatternSet::createFromCapsule);
   m.def(
-      "apply_patterns_and_fold_greedily",
-      [](MlirModule module, MlirFrozenRewritePatternSet set) {
-        auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
-        if (mlirLogicalResultIsFailure(status))
-          // FIXME: Not sure this is the right error to throw here.
-          throw nb::value_error("pattern application failed to converge");
-      },
-      "module"_a, "set"_a,
-      "Applys the given patterns to the given module greedily while folding "
-      "results.");
+       "apply_patterns_and_fold_greedily",
+       [](MlirModule module, MlirFrozenRewritePatternSet set) {
+         auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
+         if (mlirLogicalResultIsFailure(status))
+           // FIXME: Not sure this is the right error to throw here.
+           throw nb::value_error("pattern application failed to converge");
+       },
+       "module"_a, "set"_a,
+       "Applys the given patterns to the given module greedily while folding "
+       "results.")
+      .def(
+          "apply_patterns_and_fold_greedily_for_op",
+          [](MlirOperation op, MlirFrozenRewritePatternSet set) {
+            auto status = mlirApplyPatternsAndFoldGreedilyForOp(op, set, {});
+            if (mlirLogicalResultIsFailure(status))
+              // FIXME: Not sure this is the right error to throw here.
+              throw nb::value_error("pattern application failed to converge");
+          },
+          "op"_a, "set"_a,
+          "Applys the given patterns to the given op greedily while folding "
+          "results.");
 }
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index a4df97f7beace..d606445cfad31 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -294,6 +294,13 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op,
   return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
 }
 
+/// Operation-level variant of mlirApplyPatternsAndFoldGreedily: runs the
+/// greedy pattern driver with `patterns` on `op` and wraps its result.
+/// NOTE(review): like the module variant above, the driver config argument is
+/// ignored and applyPatternsGreedily is called with its default config --
+/// TODO honor the config.
+MlirLogicalResult
+mlirApplyPatternsAndFoldGreedilyForOp(MlirOperation op,
+                                      MlirFrozenRewritePatternSet patterns,
+                                      MlirGreedyRewriteDriverConfig) {
+  return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
+}
+
 //===----------------------------------------------------------------------===//
 /// PDLPatternModule API
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/pass.py b/mlir/test/python/pass.py
new file mode 100644
index 0000000000000..1c338e4bc6b49
--- /dev/null
+++ b/mlir/test/python/pass.py
@@ -0,0 +1,79 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import gc, sys
+from mlir.ir import *
+from mlir.passmanager import *
+from mlir.dialects.builtin import ModuleOp
+from mlir.dialects import pdl
+from mlir.rewrite import *
+
+
+def log(*args):
+    print(*args, file=sys.stderr)
+    sys.stderr.flush()
+
+
+def run(f):
+    log("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
+
+
+def make_pdl_module():
+    """Build a module holding one PDL pattern that rewrites index-typed arith.addi into arith.muli.
+
+    Expects an active Context (testCustomPass calls it inside `with Context()`).
+    """
+    with Location.unknown():
+        pdl_module = Module.create()
+        with InsertionPoint(pdl_module.body):
+            # Change all arith.addi with index types to arith.muli.
+            # NOTE(review): `pat` and `rew` are never called explicitly; presumably
+            # the pdl.pattern / pdl.rewrite decorators emit the PDL ops when applied.
+            @pdl.pattern(benefit=1, sym_name="addi_to_mul")
+            def pat():
+                # Match arith.addi with index types.
+                index_type = pdl.TypeOp(IndexType.get())
+                operand0 = pdl.OperandOp(index_type)
+                operand1 = pdl.OperandOp(index_type)
+                op0 = pdl.OperationOp(
+                    name="arith.addi", args=[operand0, operand1], types=[index_type]
+                )
+
+                # Replace the matched op with arith.muli.
+                @pdl.rewrite()
+                def rew():
+                    newOp = pdl.OperationOp(
+                        name="arith.muli", args=[operand0, operand1], types=[index_type]
+                    )
+                    pdl.ReplaceOp(op0, with_op=newOp)
+
+        return pdl_module
+
+
+# CHECK-LABEL: TEST: testCustomPass
+ at run
+def testCustomPass():
+    with Context():
+        pdl_module = make_pdl_module()
+
+        class CustomPass(Pass):
+            def __init__(self):
+                super().__init__("CustomPass", op_name="builtin.module")
+
+            def run(self, m):
+                frozen = PDLModule(pdl_module).freeze()
+                apply_patterns_and_fold_greedily_for_op(m, frozen)
+
+        module = ModuleOp.parse(
+            r"""
+            module {
+              func.func @add(%a: index, %b: index) -> index {
+                %sum = arith.addi %a, %b : index
+                return %sum : index
+              }
+            }
+        """
+        )
+
+        # CHECK-LABEL: Dump After CustomPass
+        # CHECK: arith.muli
+        pm = PassManager("any")
+        pm.enable_ir_printing()
+        pm.add(CustomPass())
+        pm.run(module)

``````````

</details>


https://github.com/llvm/llvm-project/pull/156000


More information about the Mlir-commits mailing list