[Mlir-commits] [mlir] [MLIR][Python] Support Python-defined passes in MLIR (PR #156000)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 8 00:58:32 PDT 2025


https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/156000

>From 8386c87c431585c4412a52d07287b7423e34f602 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 29 Aug 2025 17:43:04 +0800
Subject: [PATCH 01/15] [MLIR][Python] Support Python-defined passes in MLIR

---
 mlir/lib/Bindings/Python/MainModule.cpp |  1 +
 mlir/lib/Bindings/Python/Pass.cpp       | 80 ++++++++++++++++++++++++-
 mlir/lib/Bindings/Python/Pass.h         |  1 +
 3 files changed, 81 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 278847e7ac7f5..590e862a8d358 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -139,4 +139,5 @@ NB_MODULE(_mlir, m) {
   auto passModule =
       m.def_submodule("passmanager", "MLIR Pass Management Bindings");
   populatePassManagerSubmodule(passModule);
+  populatePassSubmodule(passModule);
 }
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 1030dea7f364c..4aa93df938295 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -9,9 +9,11 @@
 #include "Pass.h"
 
 #include "IRModule.h"
+#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
 #include "mlir-c/Pass.h"
 #include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+#include "nanobind/trampoline.h"
+#include "llvm/Support/ErrorHandling.h"
 
 namespace nb = nanobind;
 using namespace nb::literals;
@@ -20,6 +22,63 @@ using namespace mlir::python;
 
 namespace {
 
+// A base class for defining passes in Python
+// Users are expected to subclass this and implement the `run` method, e.g.
+// ```
+// class MyPass(mlir.passmanager.Pass):
+//   def run(self, operation):
+//     # do something with operation
+//     pass
+// ```
+class PyPassBase {
+public:
+  PyPassBase() : callbacks{} {
+    callbacks.construct = [](void *) {};
+    callbacks.destruct = [](void *) {};
+    callbacks.run = [](MlirOperation op, MlirExternalPass, void *obj) {
+      static_cast<PyPassBase *>(obj)->run(op);
+    };
+    // TODO: currently we don't support pass cloning in python
+    // due to lifetime management issues.
+    callbacks.clone = [](void *obj) -> void * {
+      // since the caller here should be MLIR C++ code,
+      // we need to avoid using exceptions like throw py::value_error(...).
+      llvm_unreachable("cloning of python-defined passes is not supported");
+    };
+  }
+
+  // this method should be overridden by subclasses in Python.
+  virtual void run(MlirOperation op) = 0;
+
+  virtual ~PyPassBase() = default;
+
+  // Make an MlirPass instance on-the-fly that wraps this object.
+  // Note that passmanager will take the ownership of the returned
+  // object and release it when appropriate.
+  // Also, `*this` must remain alive as long as the returned object is alive.
+  MlirPass make() {
+    return mlirCreateExternalPass(
+        mlirTypeIDCreate(this),
+        mlirStringRefCreateFromCString("python-example-pass"),
+        mlirStringRefCreateFromCString(""),
+        mlirStringRefCreateFromCString("Python Example Pass"),
+        mlirStringRefCreateFromCString(""), 0, nullptr, callbacks, this);
+  }
+
+private:
+  MlirExternalPassCallbacks callbacks;
+};
+
+// A trampoline class upon PyPassBase.
+// Refer to
+// https://nanobind.readthedocs.io/en/latest/classes.html#overriding-virtual-functions-in-python
+class PyPass : PyPassBase {
+public:
+  NB_TRAMPOLINE(PyPassBase, 1);
+
+  void run(MlirOperation op) override { NB_OVERRIDE_PURE(run, op); }
+};
+
 /// Owning Wrapper around a PassManager.
 class PyPassManager {
 public:
@@ -52,6 +111,16 @@ class PyPassManager {
 
 } // namespace
 
+void mlir::python::populatePassSubmodule(nanobind::module_ &m) {
+  //----------------------------------------------------------------------------
+  // Mapping of the Python-defined Pass interface
+  //----------------------------------------------------------------------------
+  nb::class_<PyPassBase, PyPass>(m, "Pass")
+      .def(nb::init<>(), "Create a new Pass.")
+      .def("run", &PyPassBase::run, "operation"_a,
+           "Run the pass on the provided operation.");
+}
+
 /// Create the `mlir.passmanager` here.
 void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
   //----------------------------------------------------------------------------
@@ -157,6 +226,15 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
           "pipeline"_a,
           "Add textual pipeline elements to the pass manager. Throws a "
           "ValueError if the pipeline can't be parsed.")
+      .def(
+          "add",
+          [](PyPassManager &passManager, PyPassBase &pass) {
+            mlirPassManagerAddOwnedPass(passManager.get(), pass.make());
+          },
+          "pass"_a, "Add a python-defined pass to the pass manager.",
+          // NOTE that we should keep the pass object alive as long as the
+          // passManager to prevent dangling objects.
+          nb::keep_alive<1, 2>())
       .def(
           "run",
           [](PyPassManager &passManager, PyOperationBase &op,
diff --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h
index bc40943521829..ba3fbb707fed7 100644
--- a/mlir/lib/Bindings/Python/Pass.h
+++ b/mlir/lib/Bindings/Python/Pass.h
@@ -15,6 +15,7 @@ namespace mlir {
 namespace python {
 
 void populatePassManagerSubmodule(nanobind::module_ &m);
+void populatePassSubmodule(nanobind::module_ &m);
 
 } // namespace python
 } // namespace mlir

>From cb826212272d3276350019d40703113a2f30e983 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 29 Aug 2025 18:39:17 +0800
Subject: [PATCH 02/15] add ctor with args for Pass

---
 mlir/lib/Bindings/Python/Pass.cpp | 40 ++++++++++++++++++++++++-------
 1 file changed, 32 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 4aa93df938295..898d3c096c1d8 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -11,6 +11,7 @@
 #include "IRModule.h"
 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
 #include "mlir-c/Pass.h"
+#include "mlir-c/Support.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "nanobind/trampoline.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -32,7 +33,10 @@ namespace {
 // ```
 class PyPassBase {
 public:
-  PyPassBase() : callbacks{} {
+  PyPassBase(std::string name, std::string argument, std::string description,
+             std::string opName)
+      : callbacks{}, name(std::move(name)), argument(std::move(argument)),
+        description(std::move(description)), opName(std::move(opName)) {
     callbacks.construct = [](void *) {};
     callbacks.destruct = [](void *) {};
     callbacks.run = [](MlirOperation op, MlirExternalPass, void *obj) {
@@ -58,15 +62,25 @@ class PyPassBase {
   // Also, `*this` must remain alive as long as the returned object is alive.
   MlirPass make() {
     return mlirCreateExternalPass(
-        mlirTypeIDCreate(this),
-        mlirStringRefCreateFromCString("python-example-pass"),
-        mlirStringRefCreateFromCString(""),
-        mlirStringRefCreateFromCString("Python Example Pass"),
-        mlirStringRefCreateFromCString(""), 0, nullptr, callbacks, this);
+        mlirTypeIDCreate(this), mlirStringRefCreate(name.data(), name.length()),
+        mlirStringRefCreate(argument.data(), argument.length()),
+        mlirStringRefCreate(description.data(), description.length()),
+        mlirStringRefCreate(opName.data(), opName.size()), 0, nullptr,
+        callbacks, this);
   }
 
+  const std::string &getName() const { return name; }
+  const std::string &getArgument() const { return argument; }
+  const std::string &getDescription() const { return description; }
+  const std::string &getOpName() const { return opName; }
+
 private:
   MlirExternalPassCallbacks callbacks;
+
+  std::string name;
+  std::string argument;
+  std::string description;
+  std::string opName;
 };
 
 // A trampoline class upon PyPassBase.
@@ -116,9 +130,19 @@ void mlir::python::populatePassSubmodule(nanobind::module_ &m) {
   // Mapping of the Python-defined Pass interface
   //----------------------------------------------------------------------------
   nb::class_<PyPassBase, PyPass>(m, "Pass")
-      .def(nb::init<>(), "Create a new Pass.")
+      .def(nb::init<std::string, std::string, std::string, std::string>(),
+           "name"_a, nb::kw_only(), "argument"_a = "", "description"_a = "",
+           "op_name"_a = "", "Create a new Pass.")
       .def("run", &PyPassBase::run, "operation"_a,
-           "Run the pass on the provided operation.");
+           "Run the pass on the provided operation.")
+      .def_prop_ro("name",
+                   [](const PyPassBase &self) { return self.getName(); })
+      .def_prop_ro("argument",
+                   [](const PyPassBase &self) { return self.getArgument(); })
+      .def_prop_ro("description",
+                   [](const PyPassBase &self) { return self.getDescription(); })
+      .def_prop_ro("op_name",
+                   [](const PyPassBase &self) { return self.getOpName(); });
 }
 
 /// Create the `mlir.passmanager` here.

>From 7556ca2ad0deb938765034cda9c49edd3c7c5975 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 29 Aug 2025 22:26:40 +0800
Subject: [PATCH 03/15] fix lifetime issue

---
 mlir/lib/Bindings/Python/Pass.cpp | 28 ++++++++++++++--------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 898d3c096c1d8..babb8a723ca04 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -35,20 +35,22 @@ class PyPassBase {
 public:
   PyPassBase(std::string name, std::string argument, std::string description,
              std::string opName)
-      : callbacks{}, name(std::move(name)), argument(std::move(argument)),
+      : name(std::move(name)), argument(std::move(argument)),
         description(std::move(description)), opName(std::move(opName)) {
-    callbacks.construct = [](void *) {};
-    callbacks.destruct = [](void *) {};
+    callbacks.construct = [](void *obj) {};
+    callbacks.destruct = [](void *obj) {
+      nb::handle(static_cast<PyObject *>(obj)).dec_ref();
+    };
     callbacks.run = [](MlirOperation op, MlirExternalPass, void *obj) {
-      static_cast<PyPassBase *>(obj)->run(op);
+      auto handle = nb::handle(static_cast<PyObject *>(obj));
+      nb::cast<PyPassBase *>(handle)->run(op);
     };
-    // TODO: currently we don't support pass cloning in python
-    // due to lifetime management issues.
     callbacks.clone = [](void *obj) -> void * {
-      // since the caller here should be MLIR C++ code,
-      // we need to avoid using exceptions like throw py::value_error(...).
-      llvm_unreachable("cloning of python-defined passes is not supported");
+      nb::object copy = nb::module_::import_("copy");
+      nb::object deepcopy = copy.attr("deepcopy");
+      return deepcopy(obj).release().ptr();
     };
+    callbacks.initialize = nullptr;
   }
 
   // this method should be overridden by subclasses in Python.
@@ -61,12 +63,13 @@ class PyPassBase {
   // object and release it when appropriate.
   // Also, `*this` must remain alive as long as the returned object is alive.
   MlirPass make() {
+    auto *obj = nb::find(this).release().ptr();
     return mlirCreateExternalPass(
         mlirTypeIDCreate(this), mlirStringRefCreate(name.data(), name.length()),
         mlirStringRefCreate(argument.data(), argument.length()),
         mlirStringRefCreate(description.data(), description.length()),
         mlirStringRefCreate(opName.data(), opName.size()), 0, nullptr,
-        callbacks, this);
+        callbacks, obj);
   }
 
   const std::string &getName() const { return name; }
@@ -255,10 +258,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
           [](PyPassManager &passManager, PyPassBase &pass) {
             mlirPassManagerAddOwnedPass(passManager.get(), pass.make());
           },
-          "pass"_a, "Add a python-defined pass to the pass manager.",
-          // NOTE that we should keep the pass object alive as long as the
-          // passManager to prevent dangling objects.
-          nb::keep_alive<1, 2>())
+          "pass"_a, "Add a python-defined pass to the pass manager.")
       .def(
           "run",
           [](PyPassManager &passManager, PyOperationBase &op,

>From d5055a79a683a7cd71533db7bed537d6028f1950 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 30 Aug 2025 00:20:59 +0800
Subject: [PATCH 04/15] add test case

---
 mlir/include/mlir-c/Rewrite.h        |  4 ++
 mlir/lib/Bindings/Python/Pass.cpp    |  4 +-
 mlir/lib/Bindings/Python/Rewrite.cpp | 31 ++++++++----
 mlir/lib/CAPI/Transforms/Rewrite.cpp |  7 +++
 mlir/test/python/pass.py             | 73 ++++++++++++++++++++++++++++
 5 files changed, 106 insertions(+), 13 deletions(-)
 create mode 100644 mlir/test/python/pass.py

diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 61d3446317550..21ae236d6f73f 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -301,6 +301,10 @@ mlirFreezeRewritePattern(MlirRewritePatternSet op);
 MLIR_CAPI_EXPORTED void
 mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
 
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyForOp(
+    MlirOperation op, MlirFrozenRewritePatternSet patterns,
+    MlirGreedyRewriteDriverConfig);
+
 MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
     MlirModule op, MlirFrozenRewritePatternSet patterns,
     MlirGreedyRewriteDriverConfig);
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index babb8a723ca04..920d604d24680 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -9,12 +9,10 @@
 #include "Pass.h"
 
 #include "IRModule.h"
-#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
 #include "mlir-c/Pass.h"
-#include "mlir-c/Support.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "nanobind/trampoline.h"
-#include "llvm/Support/ErrorHandling.h"
+#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
 
 namespace nb = nanobind;
 using namespace nb::literals;
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 0373f9c7affe9..675bd685ec2db 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -99,14 +99,25 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
            &PyFrozenRewritePatternSet::createFromCapsule);
   m.def(
-      "apply_patterns_and_fold_greedily",
-      [](MlirModule module, MlirFrozenRewritePatternSet set) {
-        auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
-        if (mlirLogicalResultIsFailure(status))
-          // FIXME: Not sure this is the right error to throw here.
-          throw nb::value_error("pattern application failed to converge");
-      },
-      "module"_a, "set"_a,
-      "Applys the given patterns to the given module greedily while folding "
-      "results.");
+       "apply_patterns_and_fold_greedily",
+       [](MlirModule module, MlirFrozenRewritePatternSet set) {
+         auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
+         if (mlirLogicalResultIsFailure(status))
+           // FIXME: Not sure this is the right error to throw here.
+           throw nb::value_error("pattern application failed to converge");
+       },
+       "module"_a, "set"_a,
+       "Applys the given patterns to the given module greedily while folding "
+       "results.")
+      .def(
+          "apply_patterns_and_fold_greedily_for_op",
+          [](MlirOperation op, MlirFrozenRewritePatternSet set) {
+            auto status = mlirApplyPatternsAndFoldGreedilyForOp(op, set, {});
+            if (mlirLogicalResultIsFailure(status))
+              // FIXME: Not sure this is the right error to throw here.
+              throw nb::value_error("pattern application failed to converge");
+          },
+          "op"_a, "set"_a,
+          "Applys the given patterns to the given op greedily while folding "
+          "results.");
 }
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index a4df97f7beace..d606445cfad31 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -294,6 +294,13 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op,
   return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
 }
 
+MlirLogicalResult
+mlirApplyPatternsAndFoldGreedilyForOp(MlirOperation op,
+                                      MlirFrozenRewritePatternSet patterns,
+                                      MlirGreedyRewriteDriverConfig) {
+  return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
+}
+
 //===----------------------------------------------------------------------===//
 /// PDLPatternModule API
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/pass.py b/mlir/test/python/pass.py
new file mode 100644
index 0000000000000..32943caec19a4
--- /dev/null
+++ b/mlir/test/python/pass.py
@@ -0,0 +1,73 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import gc, sys
+from mlir.ir import *
+from mlir.passmanager import *
+from mlir.dialects.builtin import ModuleOp
+from mlir.dialects import pdl
+from mlir.rewrite import *
+
+def log(*args):
+    print(*args, file=sys.stderr)
+    sys.stderr.flush()
+
+
+def run(f):
+    log("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
+
+def make_pdl_module():
+    with Location.unknown():
+        pdl_module = Module.create()
+        with InsertionPoint(pdl_module.body):
+            # Change all arith.addi with index types to arith.muli.
+            @pdl.pattern(benefit=1, sym_name="addi_to_mul")
+            def pat():
+                # Match arith.addi with index types.
+                index_type = pdl.TypeOp(IndexType.get())
+                operand0 = pdl.OperandOp(index_type)
+                operand1 = pdl.OperandOp(index_type)
+                op0 = pdl.OperationOp(
+                    name="arith.addi", args=[operand0, operand1], types=[index_type]
+                )
+
+                # Replace the matched op with arith.muli.
+                @pdl.rewrite()
+                def rew():
+                    newOp = pdl.OperationOp(
+                        name="arith.muli", args=[operand0, operand1], types=[index_type]
+                    )
+                    pdl.ReplaceOp(op0, with_op=newOp)
+
+        return pdl_module
+
+# CHECK-LABEL: TEST: testCustomPass
+@run
+def testCustomPass():
+    with Context():
+        pdl_module = make_pdl_module()
+
+        class CustomPass(Pass):
+            def __init__(self):
+                super().__init__("CustomPass", op_name="builtin.module")
+            def run(self, m):
+                frozen = PDLModule(pdl_module).freeze()
+                apply_patterns_and_fold_greedily_for_op(m, frozen)
+
+        module = ModuleOp.parse(r"""
+            module {
+              func.func @add(%a: index, %b: index) -> index {
+                %sum = arith.addi %a, %b : index
+                return %sum : index
+              }
+            }
+        """)
+
+        # CHECK-LABEL: Dump After CustomPass
+        # CHECK: arith.muli
+        pm = PassManager('any')
+        pm.enable_ir_printing()
+        pm.add(CustomPass())
+        pm.run(module)

>From 6d2f4720b5c724a5a65d257e05b60ea7adae9569 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 30 Aug 2025 00:31:05 +0800
Subject: [PATCH 05/15] fix header

---
 mlir/lib/Bindings/Python/Pass.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 920d604d24680..cda1f7af243d3 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -11,8 +11,8 @@
 #include "IRModule.h"
 #include "mlir-c/Pass.h"
 #include "mlir/Bindings/Python/Nanobind.h"
-#include "nanobind/trampoline.h"
 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+#include "nanobind/trampoline.h"
 
 namespace nb = nanobind;
 using namespace nb::literals;

>From 1a98ae84df54f1d441a063dd267ee70b574d3b22 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 30 Aug 2025 00:31:46 +0800
Subject: [PATCH 06/15] format

---
 mlir/test/python/pass.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/test/python/pass.py b/mlir/test/python/pass.py
index 32943caec19a4..4389e3bf38686 100644
--- a/mlir/test/python/pass.py
+++ b/mlir/test/python/pass.py
@@ -7,6 +7,7 @@
 from mlir.dialects import pdl
 from mlir.rewrite import *
 
+
 def log(*args):
     print(*args, file=sys.stderr)
     sys.stderr.flush()
@@ -18,6 +19,7 @@ def run(f):
     gc.collect()
     assert Context._get_live_count() == 0
 
+
 def make_pdl_module():
     with Location.unknown():
         pdl_module = Module.create()
@@ -43,6 +45,7 @@ def rew():
 
         return pdl_module
 
+
 # CHECK-LABEL: TEST: testCustomPass
 @run
 def testCustomPass():

>From 751fe84f22fbfeba3cffea0e94766b5cdb82b204 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 30 Aug 2025 00:58:52 +0800
Subject: [PATCH 07/15] format

---
 mlir/test/python/pass.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/mlir/test/python/pass.py b/mlir/test/python/pass.py
index 4389e3bf38686..1c338e4bc6b49 100644
--- a/mlir/test/python/pass.py
+++ b/mlir/test/python/pass.py
@@ -55,22 +55,25 @@ def testCustomPass():
         class CustomPass(Pass):
             def __init__(self):
                 super().__init__("CustomPass", op_name="builtin.module")
+
             def run(self, m):
                 frozen = PDLModule(pdl_module).freeze()
                 apply_patterns_and_fold_greedily_for_op(m, frozen)
 
-        module = ModuleOp.parse(r"""
+        module = ModuleOp.parse(
+            r"""
             module {
               func.func @add(%a: index, %b: index) -> index {
                 %sum = arith.addi %a, %b : index
                 return %sum : index
               }
             }
-        """)
+        """
+        )
 
         # CHECK-LABEL: Dump After CustomPass
         # CHECK: arith.muli
-        pm = PassManager('any')
+        pm = PassManager("any")
         pm.enable_ir_printing()
         pm.add(CustomPass())
         pm.run(module)

>From 7966ddde2644ad195a63940237702715ff166ad5 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 30 Aug 2025 01:02:45 +0800
Subject: [PATCH 08/15] format

---
 mlir/lib/Bindings/Python/Pass.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index cda1f7af243d3..e482d1eb26d92 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -11,6 +11,7 @@
 #include "IRModule.h"
 #include "mlir-c/Pass.h"
 #include "mlir/Bindings/Python/Nanobind.h"
+
 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
 #include "nanobind/trampoline.h"
 

>From 6a9ec66730ac0666a88b5ec9aacf6162d8ceb8db Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 30 Aug 2025 01:08:54 +0800
Subject: [PATCH 09/15] remove useless comment

---
 mlir/lib/Bindings/Python/Pass.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index e482d1eb26d92..3fd27fed34587 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -26,8 +26,11 @@ namespace {
 // Users are expected to subclass this and implement the `run` method, e.g.
 // ```
 // class MyPass(mlir.passmanager.Pass):
+//   def __init__(self):
+//     super().__init__("MyPass", ..)
+//     # other init stuff..
 //   def run(self, operation):
-//     # do something with operation
+//     # do something with operation..
 //     pass
 // ```
 class PyPassBase {
@@ -60,7 +63,6 @@ class PyPassBase {
   // Make an MlirPass instance on-the-fly that wraps this object.
   // Note that passmanager will take the ownership of the returned
   // object and release it when appropriate.
-  // Also, `*this` must remain alive as long as the returned object is alive.
   MlirPass make() {
     auto *obj = nb::find(this).release().ptr();
     return mlirCreateExternalPass(

>From 2965a9e06be820d91f8ebcfcc43ad1ea51ceac4d Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 30 Aug 2025 13:31:50 +0800
Subject: [PATCH 10/15] improve test

---
 mlir/test/python/pass.py | 24 ++++++++++++++----------
 1 file changed, 14 insertions(+), 10 deletions(-)

diff --git a/mlir/test/python/pass.py b/mlir/test/python/pass.py
index 1c338e4bc6b49..47d11ff1c1cc8 100644
--- a/mlir/test/python/pass.py
+++ b/mlir/test/python/pass.py
@@ -28,18 +28,18 @@ def make_pdl_module():
             @pdl.pattern(benefit=1, sym_name="addi_to_mul")
             def pat():
                 # Match arith.addi with index types.
-                index_type = pdl.TypeOp(IndexType.get())
-                operand0 = pdl.OperandOp(index_type)
-                operand1 = pdl.OperandOp(index_type)
+                i64_type = pdl.TypeOp(IntegerType.get_signless(64))
+                operand0 = pdl.OperandOp(i64_type)
+                operand1 = pdl.OperandOp(i64_type)
                 op0 = pdl.OperationOp(
-                    name="arith.addi", args=[operand0, operand1], types=[index_type]
+                    name="arith.addi", args=[operand0, operand1], types=[i64_type]
                 )
 
                 # Replace the matched op with arith.muli.
                 @pdl.rewrite()
                 def rew():
                     newOp = pdl.OperationOp(
-                        name="arith.muli", args=[operand0, operand1], types=[index_type]
+                        name="arith.muli", args=[operand0, operand1], types=[i64_type]
                     )
                     pdl.ReplaceOp(op0, with_op=newOp)
 
@@ -63,17 +63,21 @@ def run(self, m):
         module = ModuleOp.parse(
             r"""
             module {
-              func.func @add(%a: index, %b: index) -> index {
-                %sum = arith.addi %a, %b : index
-                return %sum : index
+              func.func @add(%a: i64, %b: i64) -> i64 {
+                %sum = arith.addi %a, %b : i64
+                return %sum : i64
               }
             }
         """
         )
 
-        # CHECK-LABEL: Dump After CustomPass
-        # CHECK: arith.muli
         pm = PassManager("any")
         pm.enable_ir_printing()
+
+        # CHECK-LABEL: Dump After CustomPass
+        # CHECK: arith.muli
         pm.add(CustomPass())
+        # CHECK-LABEL: Dump After ArithToLLVMConversionPass
+        # CHECK: llvm.mul
+        pm.add("convert-arith-to-llvm")
         pm.run(module)

>From ca80408c0356806de30bd5e85dcd48d9d3d24663 Mon Sep 17 00:00:00 2001
From: "cnb.bsD2OPwAgEA"
 <QejD2DJ2eEahUVy6Zg0aZI+cnb.bsD2OPwAgEA at noreply.cnb.cool>
Date: Mon, 1 Sep 2025 22:27:49 +0800
Subject: [PATCH 11/15] rename mlirApplyPatternsAndFoldGreedilyForOp to
 mlirApplyPatternsAndFoldGreedilyWithOp

---
 mlir/include/mlir-c/Rewrite.h        | 2 +-
 mlir/lib/Bindings/Python/Rewrite.cpp | 4 ++--
 mlir/lib/CAPI/Transforms/Rewrite.cpp | 6 +++---
 mlir/test/python/pass.py             | 2 +-
 4 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 21ae236d6f73f..374d2fb78de88 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -301,7 +301,7 @@ mlirFreezeRewritePattern(MlirRewritePatternSet op);
 MLIR_CAPI_EXPORTED void
 mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
 
-MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyForOp(
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(
     MlirOperation op, MlirFrozenRewritePatternSet patterns,
     MlirGreedyRewriteDriverConfig);
 
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 675bd685ec2db..e764535c1a4c0 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -110,9 +110,9 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
        "Applys the given patterns to the given module greedily while folding "
        "results.")
       .def(
-          "apply_patterns_and_fold_greedily_for_op",
+          "apply_patterns_and_fold_greedily_with_op",
           [](MlirOperation op, MlirFrozenRewritePatternSet set) {
-            auto status = mlirApplyPatternsAndFoldGreedilyForOp(op, set, {});
+            auto status = mlirApplyPatternsAndFoldGreedilyWithOp(op, set, {});
             if (mlirLogicalResultIsFailure(status))
               // FIXME: Not sure this is the right error to throw here.
               throw nb::value_error("pattern application failed to converge");
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index d606445cfad31..6f85357a14a18 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -295,9 +295,9 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op,
 }
 
 MlirLogicalResult
-mlirApplyPatternsAndFoldGreedilyForOp(MlirOperation op,
-                                      MlirFrozenRewritePatternSet patterns,
-                                      MlirGreedyRewriteDriverConfig) {
+mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
+                                       MlirFrozenRewritePatternSet patterns,
+                                       MlirGreedyRewriteDriverConfig) {
   return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
 }
 
diff --git a/mlir/test/python/pass.py b/mlir/test/python/pass.py
index 47d11ff1c1cc8..fb37cabde1ed1 100644
--- a/mlir/test/python/pass.py
+++ b/mlir/test/python/pass.py
@@ -58,7 +58,7 @@ def __init__(self):
 
             def run(self, m):
                 frozen = PDLModule(pdl_module).freeze()
-                apply_patterns_and_fold_greedily_for_op(m, frozen)
+                apply_patterns_and_fold_greedily_with_op(m, frozen)
 
         module = ModuleOp.parse(
             r"""

>From c8c2faeb61ffe22fd7b60e6669bde604b7931940 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 2 Sep 2025 14:26:44 +0800
Subject: [PATCH 12/15] add clang-format annotation for header ordering

---
 mlir/lib/Bindings/Python/Pass.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 3fd27fed34587..0ee26bb153760 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -10,9 +10,10 @@
 
 #include "IRModule.h"
 #include "mlir-c/Pass.h"
+// clang-format off
 #include "mlir/Bindings/Python/Nanobind.h"
-
 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+// clang-format on
 #include "nanobind/trampoline.h"
 
 namespace nb = nanobind;

>From 01e68c50bc6c12809a2790653f41d403debd1dde Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 2 Sep 2025 14:44:05 +0800
Subject: [PATCH 13/15] rename mlir.passmanager.Pass to mlir.passes.Pass

---
 mlir/lib/Bindings/Python/MainModule.cpp | 8 +++++---
 mlir/lib/Bindings/Python/Pass.cpp       | 2 +-
 mlir/python/CMakeLists.txt              | 1 +
 mlir/python/mlir/passes.py              | 5 +++++
 mlir/test/python/pass.py                | 3 ++-
 5 files changed, 14 insertions(+), 5 deletions(-)
 create mode 100644 mlir/python/mlir/passes.py

diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 590e862a8d358..94604a567858a 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -136,8 +136,10 @@ NB_MODULE(_mlir, m) {
   populateRewriteSubmodule(rewriteModule);
 
   // Define and populate PassManager submodule.
-  auto passModule =
+  auto passManagerModule =
       m.def_submodule("passmanager", "MLIR Pass Management Bindings");
-  populatePassManagerSubmodule(passModule);
-  populatePassSubmodule(passModule);
+  populatePassManagerSubmodule(passManagerModule);
+  auto passesModule =
+      m.def_submodule("passes", "MLIR Pass Infrastructure Bindings");
+  populatePassSubmodule(passesModule);
 }
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 0ee26bb153760..73d6c556e5181 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -26,7 +26,7 @@ namespace {
 // A base class for defining passes in Python
 // Users are expected to subclass this and implement the `run` method, e.g.
 // ```
-// class MyPass(mlir.passmanager.Pass):
+// class MyPass(Pass):
 //   def __init__(self):
 //     super().__init__("MyPass", ..)
 //     # other init stuff..
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 7a0c95ebb8200..fde53a4d64d1c 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -20,6 +20,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
   SOURCES
     _mlir_libs/__init__.py
     ir.py
+    passes.py
     passmanager.py
     rewrite.py
     dialects/_ods_common.py
diff --git a/mlir/python/mlir/passes.py b/mlir/python/mlir/passes.py
new file mode 100644
index 0000000000000..aab9d6b252bbc
--- /dev/null
+++ b/mlir/python/mlir/passes.py
@@ -0,0 +1,5 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._mlir_libs._mlir.passes import *
diff --git a/mlir/test/python/pass.py b/mlir/test/python/pass.py
index fb37cabde1ed1..0f6d818e850dc 100644
--- a/mlir/test/python/pass.py
+++ b/mlir/test/python/pass.py
@@ -3,6 +3,7 @@
 import gc, sys
 from mlir.ir import *
 from mlir.passmanager import *
+from mlir.passes import *
 from mlir.dialects.builtin import ModuleOp
 from mlir.dialects import pdl
 from mlir.rewrite import *
@@ -51,13 +52,13 @@ def rew():
 def testCustomPass():
     with Context():
         pdl_module = make_pdl_module()
+        frozen = PDLModule(pdl_module).freeze()
 
         class CustomPass(Pass):
             def __init__(self):
                 super().__init__("CustomPass", op_name="builtin.module")
 
             def run(self, m):
-                frozen = PDLModule(pdl_module).freeze()
                 apply_patterns_and_fold_greedily_with_op(m, frozen)
 
         module = ModuleOp.parse(

>From e565ffbcbde97f8f7b09b935a3a546451be17e99 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 8 Sep 2025 15:49:17 +0800
Subject: [PATCH 14/15] Merge changes from #157369

Co-authored-by: Maksim Levental <maksim.levental at gmail.com>
---
 mlir/lib/Bindings/Python/MainModule.cpp      |   3 -
 mlir/lib/Bindings/Python/Pass.cpp            | 137 +++++--------------
 mlir/lib/Bindings/Python/Pass.h              |   1 -
 mlir/lib/CAPI/IR/Pass.cpp                    |   8 +-
 mlir/python/CMakeLists.txt                   |   1 -
 mlir/python/mlir/passes.py                   |   5 -
 mlir/test/python/{pass.py => python_pass.py} |  24 ++--
 7 files changed, 57 insertions(+), 122 deletions(-)
 delete mode 100644 mlir/python/mlir/passes.py
 rename mlir/test/python/{pass.py => python_pass.py} (83%)

diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 94604a567858a..d7282b3d6f713 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -139,7 +139,4 @@ NB_MODULE(_mlir, m) {
   auto passManagerModule =
       m.def_submodule("passmanager", "MLIR Pass Management Bindings");
   populatePassManagerSubmodule(passManagerModule);
-  auto passesModule =
-      m.def_submodule("passes", "MLIR Pass Infrastructure Bindings");
-  populatePassSubmodule(passesModule);
 }
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 73d6c556e5181..a00f7a2fa41b5 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -14,7 +14,6 @@
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
 // clang-format on
-#include "nanobind/trampoline.h"
 
 namespace nb = nanobind;
 using namespace nb::literals;
@@ -23,81 +22,6 @@ using namespace mlir::python;
 
 namespace {
 
-// A base class for defining passes in Python
-// Users are expected to subclass this and implement the `run` method, e.g.
-// ```
-// class MyPass(Pass):
-//   def __init__(self):
-//     super().__init__("MyPass", ..)
-//     # other init stuff..
-//   def run(self, operation):
-//     # do something with operation..
-//     pass
-// ```
-class PyPassBase {
-public:
-  PyPassBase(std::string name, std::string argument, std::string description,
-             std::string opName)
-      : name(std::move(name)), argument(std::move(argument)),
-        description(std::move(description)), opName(std::move(opName)) {
-    callbacks.construct = [](void *obj) {};
-    callbacks.destruct = [](void *obj) {
-      nb::handle(static_cast<PyObject *>(obj)).dec_ref();
-    };
-    callbacks.run = [](MlirOperation op, MlirExternalPass, void *obj) {
-      auto handle = nb::handle(static_cast<PyObject *>(obj));
-      nb::cast<PyPassBase *>(handle)->run(op);
-    };
-    callbacks.clone = [](void *obj) -> void * {
-      nb::object copy = nb::module_::import_("copy");
-      nb::object deepcopy = copy.attr("deepcopy");
-      return deepcopy(obj).release().ptr();
-    };
-    callbacks.initialize = nullptr;
-  }
-
-  // this method should be overridden by subclasses in Python.
-  virtual void run(MlirOperation op) = 0;
-
-  virtual ~PyPassBase() = default;
-
-  // Make an MlirPass instance on-the-fly that wraps this object.
-  // Note that passmanager will take the ownership of the returned
-  // object and release it when appropriate.
-  MlirPass make() {
-    auto *obj = nb::find(this).release().ptr();
-    return mlirCreateExternalPass(
-        mlirTypeIDCreate(this), mlirStringRefCreate(name.data(), name.length()),
-        mlirStringRefCreate(argument.data(), argument.length()),
-        mlirStringRefCreate(description.data(), description.length()),
-        mlirStringRefCreate(opName.data(), opName.size()), 0, nullptr,
-        callbacks, obj);
-  }
-
-  const std::string &getName() const { return name; }
-  const std::string &getArgument() const { return argument; }
-  const std::string &getDescription() const { return description; }
-  const std::string &getOpName() const { return opName; }
-
-private:
-  MlirExternalPassCallbacks callbacks;
-
-  std::string name;
-  std::string argument;
-  std::string description;
-  std::string opName;
-};
-
-// A trampoline class upon PyPassBase.
-// Refer to
-// https://nanobind.readthedocs.io/en/latest/classes.html#overriding-virtual-functions-in-python
-class PyPass : PyPassBase {
-public:
-  NB_TRAMPOLINE(PyPassBase, 1);
-
-  void run(MlirOperation op) override { NB_OVERRIDE_PURE(run, op); }
-};
-
 /// Owning Wrapper around a PassManager.
 class PyPassManager {
 public:
@@ -130,26 +54,6 @@ class PyPassManager {
 
 } // namespace
 
-void mlir::python::populatePassSubmodule(nanobind::module_ &m) {
-  //----------------------------------------------------------------------------
-  // Mapping of the Python-defined Pass interface
-  //----------------------------------------------------------------------------
-  nb::class_<PyPassBase, PyPass>(m, "Pass")
-      .def(nb::init<std::string, std::string, std::string, std::string>(),
-           "name"_a, nb::kw_only(), "argument"_a = "", "description"_a = "",
-           "op_name"_a = "", "Create a new Pass.")
-      .def("run", &PyPassBase::run, "operation"_a,
-           "Run the pass on the provided operation.")
-      .def_prop_ro("name",
-                   [](const PyPassBase &self) { return self.getName(); })
-      .def_prop_ro("argument",
-                   [](const PyPassBase &self) { return self.getArgument(); })
-      .def_prop_ro("description",
-                   [](const PyPassBase &self) { return self.getDescription(); })
-      .def_prop_ro("op_name",
-                   [](const PyPassBase &self) { return self.getOpName(); });
-}
-
 /// Create the `mlir.passmanager` here.
 void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
   //----------------------------------------------------------------------------
@@ -256,11 +160,44 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
           "Add textual pipeline elements to the pass manager. Throws a "
           "ValueError if the pipeline can't be parsed.")
       .def(
-          "add",
-          [](PyPassManager &passManager, PyPassBase &pass) {
-            mlirPassManagerAddOwnedPass(passManager.get(), pass.make());
+          "add_python_pass",
+          [](PyPassManager &passManager, const nb::callable &run,
+             std::optional<std::string> &name, const std::string &argument,
+             const std::string &description, const std::string &opName) {
+            if (!name.has_value()) {
+              name = nb::cast<std::string>(
+                  nb::borrow<nb::str>(run.attr("__name__")));
+            }
+            MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
+            MlirTypeID passID =
+                mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
+            MlirExternalPassCallbacks callbacks;
+            callbacks.construct = [](void *obj) {
+              (void)nb::handle(static_cast<PyObject *>(obj)).inc_ref();
+            };
+            callbacks.destruct = [](void *obj) {
+              (void)nb::handle(static_cast<PyObject *>(obj)).dec_ref();
+            };
+            callbacks.initialize = nullptr;
+            callbacks.clone = [](void *) -> void * {
+              throw std::runtime_error("Cloning Python passes not supported");
+            };
+            callbacks.run = [](MlirOperation op, MlirExternalPass,
+                               void *userData) {
+              nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op);
+            };
+            auto externalPass = mlirCreateExternalPass(
+                passID, mlirStringRefCreate(name->data(), name->length()),
+                mlirStringRefCreate(argument.data(), argument.length()),
+                mlirStringRefCreate(description.data(), description.length()),
+                mlirStringRefCreate(opName.data(), opName.size()),
+                /*nDependentDialects*/ 0, /*dependentDialects*/ nullptr,
+                callbacks, /*userData*/ run.ptr());
+            mlirPassManagerAddOwnedPass(passManager.get(), externalPass);
           },
-          "pass"_a, "Add a python-defined pass to the pass manager.")
+          "run"_a, "name"_a.none() = nb::none(), "argument"_a.none() = "",
+          "description"_a.none() = "", "op_name"_a.none() = "",
+          "Add a python-defined pass to the pass manager.")
       .def(
           "run",
           [](PyPassManager &passManager, PyOperationBase &op,
diff --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h
index ba3fbb707fed7..bc40943521829 100644
--- a/mlir/lib/Bindings/Python/Pass.h
+++ b/mlir/lib/Bindings/Python/Pass.h
@@ -15,7 +15,6 @@ namespace mlir {
 namespace python {
 
 void populatePassManagerSubmodule(nanobind::module_ &m);
-void populatePassSubmodule(nanobind::module_ &m);
 
 } // namespace python
 } // namespace mlir
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index 3c499c3e4974d..b0a6ec1ace3cc 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -145,10 +145,14 @@ class ExternalPass : public Pass {
       : Pass(passID, opName), id(passID), name(name), argument(argument),
         description(description), dependentDialects(dependentDialects),
         callbacks(callbacks), userData(userData) {
-    callbacks.construct(userData);
+    if (callbacks.construct)
+      callbacks.construct(userData);
   }
 
-  ~ExternalPass() override { callbacks.destruct(userData); }
+  ~ExternalPass() override {
+    if (callbacks.destruct)
+      callbacks.destruct(userData);
+  }
 
   StringRef getName() const override { return name; }
   StringRef getArgument() const override { return argument; }
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index fde53a4d64d1c..7a0c95ebb8200 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -20,7 +20,6 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
   SOURCES
     _mlir_libs/__init__.py
     ir.py
-    passes.py
     passmanager.py
     rewrite.py
     dialects/_ods_common.py
diff --git a/mlir/python/mlir/passes.py b/mlir/python/mlir/passes.py
deleted file mode 100644
index aab9d6b252bbc..0000000000000
--- a/mlir/python/mlir/passes.py
+++ /dev/null
@@ -1,5 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from ._mlir_libs._mlir.passes import *
diff --git a/mlir/test/python/pass.py b/mlir/test/python/python_pass.py
similarity index 83%
rename from mlir/test/python/pass.py
rename to mlir/test/python/python_pass.py
index 0f6d818e850dc..b83f7484ab506 100644
--- a/mlir/test/python/pass.py
+++ b/mlir/test/python/python_pass.py
@@ -3,7 +3,6 @@
 import gc, sys
 from mlir.ir import *
 from mlir.passmanager import *
-from mlir.passes import *
 from mlir.dialects.builtin import ModuleOp
 from mlir.dialects import pdl
 from mlir.rewrite import *
@@ -54,13 +53,6 @@ def testCustomPass():
         pdl_module = make_pdl_module()
         frozen = PDLModule(pdl_module).freeze()
 
-        class CustomPass(Pass):
-            def __init__(self):
-                super().__init__("CustomPass", op_name="builtin.module")
-
-            def run(self, m):
-                apply_patterns_and_fold_greedily_with_op(m, frozen)
-
         module = ModuleOp.parse(
             r"""
             module {
@@ -72,12 +64,24 @@ def run(self, m):
         """
         )
 
+        def custom_pass_1(op):
+            print("hello from pass 1!!!", file=sys.stderr)
+
+        class CustomPass2:
+            def __call__(self, m):
+                apply_patterns_and_fold_greedily_with_op(m, frozen)
+
+        custom_pass_2 = CustomPass2()
+
         pm = PassManager("any")
         pm.enable_ir_printing()
 
-        # CHECK-LABEL: Dump After CustomPass
+        # CHECK: hello from pass 1!!!
+        # CHECK-LABEL: Dump After custom_pass_1
+        pm.add_python_pass(custom_pass_1)
+        # CHECK-LABEL: Dump After CustomPass2
         # CHECK: arith.muli
-        pm.add(CustomPass())
+        pm.add_python_pass(custom_pass_2, "CustomPass2")
         # CHECK-LABEL: Dump After ArithToLLVMConversionPass
         # CHECK: llvm.mul
         pm.add("convert-arith-to-llvm")

>From 9f526c7fa1cf2793f14730e6cb55854800867fd1 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 8 Sep 2025 15:58:13 +0800
Subject: [PATCH 15/15] fix Rewrite.cpp: sort includes, throw std::runtime_error on non-convergence

---
 mlir/lib/Bindings/Python/Rewrite.cpp | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index e764535c1a4c0..731ad7a3b529b 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -9,9 +9,9 @@
 #include "Rewrite.h"
 
 #include "IRModule.h"
+#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
 #include "mlir-c/Rewrite.h"
 #include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
 #include "mlir/Config/mlir-config.h"
 
 namespace nb = nanobind;
@@ -103,8 +103,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
        [](MlirModule module, MlirFrozenRewritePatternSet set) {
          auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
          if (mlirLogicalResultIsFailure(status))
-           // FIXME: Not sure this is the right error to throw here.
-           throw nb::value_error("pattern application failed to converge");
+           throw std::runtime_error("pattern application failed to converge");
        },
        "module"_a, "set"_a,
        "Applys the given patterns to the given module greedily while folding "
@@ -114,8 +113,8 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
           [](MlirOperation op, MlirFrozenRewritePatternSet set) {
             auto status = mlirApplyPatternsAndFoldGreedilyWithOp(op, set, {});
             if (mlirLogicalResultIsFailure(status))
-              // FIXME: Not sure this is the right error to throw here.
-              throw nb::value_error("pattern application failed to converge");
+              throw std::runtime_error(
+                  "pattern application failed to converge");
           },
           "op"_a, "set"_a,
           "Applys the given patterns to the given op greedily while folding "



More information about the Mlir-commits mailing list