[Mlir-commits] [mlir] [MLIR][Python][NO MERGE] Support Python-defined passes in MLIR (PR #157369)

Maksim Levental llvmlistbot at llvm.org
Sun Sep 7 20:49:22 PDT 2025


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/157369

>From dc93dba39c0fee79b172d3f0dba11f7a25784dd0 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Sun, 7 Sep 2025 19:10:29 -0400
Subject: [PATCH] [MLIR][Python] Support Python-defined passes in MLIR

based heavily on https://github.com/llvm/llvm-project/pull/156000
---
 mlir/include/mlir-c/Rewrite.h           |  4 ++
 mlir/lib/Bindings/Python/MainModule.cpp |  4 +-
 mlir/lib/Bindings/Python/Pass.cpp       | 43 +++++++++++++
 mlir/lib/Bindings/Python/Rewrite.cpp    | 34 ++++++++---
 mlir/lib/CAPI/IR/Pass.cpp               | 12 +++-
 mlir/lib/CAPI/Transforms/Rewrite.cpp    |  7 +++
 mlir/test/python/pass.py                | 81 +++++++++++++++++++++++++
 7 files changed, 170 insertions(+), 15 deletions(-)
 create mode 100644 mlir/test/python/pass.py

diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 61d3446317550..374d2fb78de88 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -301,6 +301,10 @@ mlirFreezeRewritePattern(MlirRewritePatternSet op);
 MLIR_CAPI_EXPORTED void
 mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
 
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(
+    MlirOperation op, MlirFrozenRewritePatternSet patterns,
+    MlirGreedyRewriteDriverConfig);
+
 MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
     MlirModule op, MlirFrozenRewritePatternSet patterns,
     MlirGreedyRewriteDriverConfig);
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 278847e7ac7f5..d7282b3d6f713 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -136,7 +136,7 @@ NB_MODULE(_mlir, m) {
   populateRewriteSubmodule(rewriteModule);
 
   // Define and populate PassManager submodule.
-  auto passModule =
+  auto passManagerModule =
       m.def_submodule("passmanager", "MLIR Pass Management Bindings");
-  populatePassManagerSubmodule(passModule);
+  populatePassManagerSubmodule(passManagerModule);
 }
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 88e28dca76bb9..a0d11d1a6276c 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -10,8 +10,10 @@
 
 #include "IRModule.h"
 #include "mlir-c/Pass.h"
+// clang-format off
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+// clang-format on
 
 namespace nb = nanobind;
 using namespace nb::literals;
@@ -157,6 +159,47 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
           "pipeline"_a,
           "Add textual pipeline elements to the pass manager. Throws a "
           "ValueError if the pipeline can't be parsed.")
+      .def(
+          "add_python_pass",
+          [](PyPassManager &passManager, const nb::callable &run,
+             std::optional<std::string> &name, const std::string &argument,
+             const std::string &description, const std::string &opName) {
+            if (!name.has_value())
+              name = nb::cast<std::string>(
+                  nb::borrow<nb::str>(run.attr("__name__")));
+
+            MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
+            MlirTypeID passID =
+                mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
+            MlirExternalPassCallbacks callbacks;
+            // The pass holds one reference to `run` for its whole lifetime.
+            callbacks.construct = [](void *obj) {
+              (void)nb::handle(static_cast<PyObject *>(obj)).inc_ref();
+            };
+            callbacks.destruct = [](void *obj) {
+              (void)nb::handle(static_cast<PyObject *>(obj)).dec_ref();
+            };
+            callbacks.initialize = nullptr;
+            callbacks.clone = [](void *) -> void * {
+              throw std::runtime_error("Cloning Python passes not supported");
+            };
+            callbacks.run = [](MlirOperation op, MlirExternalPass,
+                               void *userData) {
+              // Borrow: stealing would drop the pass's reference on every run.
+              nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op);
+            };
+            auto externalPass = mlirCreateExternalPass(
+                passID, mlirStringRefCreate(name->data(), name->length()),
+                mlirStringRefCreate(argument.data(), argument.length()),
+                mlirStringRefCreate(description.data(), description.length()),
+                mlirStringRefCreate(opName.data(), opName.size()),
+                /*nDependentDialects*/ 0, /*dependentDialects*/ nullptr,
+                callbacks, /*userData*/ run.ptr());
+            mlirPassManagerAddOwnedPass(passManager.get(), externalPass);
+          },
+          "run"_a, "name"_a.none() = nb::none(), "argument"_a.none() = "",
+          "description"_a.none() = "", "op_name"_a.none() = "",
+          "Add a python-defined pass to the pass manager.")
       .def(
           "run",
           [](PyPassManager &passManager, PyOperationBase &op) {
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 0373f9c7affe9..191a7cca53d93 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -10,8 +10,10 @@
 
 #include "IRModule.h"
 #include "mlir-c/Rewrite.h"
+// clang-format off
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+// clang-format on
 #include "mlir/Config/mlir-config.h"
 
 namespace nb = nanobind;
@@ -99,14 +101,26 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
            &PyFrozenRewritePatternSet::createFromCapsule);
   m.def(
-      "apply_patterns_and_fold_greedily",
-      [](MlirModule module, MlirFrozenRewritePatternSet set) {
-        auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
-        if (mlirLogicalResultIsFailure(status))
-          // FIXME: Not sure this is the right error to throw here.
-          throw nb::value_error("pattern application failed to converge");
-      },
-      "module"_a, "set"_a,
-      "Applys the given patterns to the given module greedily while folding "
-      "results.");
+       "apply_patterns_and_fold_greedily",
+       [](MlirModule module, MlirFrozenRewritePatternSet set) {
+         auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
+         if (mlirLogicalResultIsFailure(status))
+           // FIXME: Not sure this is the right error to throw here.
+           throw nb::value_error("pattern application failed to converge");
+       },
+       "module"_a, "set"_a,
+       "Applys the given patterns to the given module greedily while folding "
+       "results.")
+      .def(
+          "apply_patterns_and_fold_greedily_with_op",
+          [](MlirOperation op, MlirFrozenRewritePatternSet set) {
+            auto status = mlirApplyPatternsAndFoldGreedilyWithOp(op, set, {});
+            // Raise the same exception type as the module-based binding so
+            // callers can handle non-convergence uniformly.
+            if (mlirLogicalResultIsFailure(status))
+              throw nb::value_error("pattern application failed to converge");
+          },
+          "op"_a, "set"_a,
+          "Applies the given patterns to the given op greedily while folding "
+          "results.");
 }
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index 3c499c3e4974d..52d6a0ad75d26 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -145,10 +145,14 @@ class ExternalPass : public Pass {
       : Pass(passID, opName), id(passID), name(name), argument(argument),
         description(description), dependentDialects(dependentDialects),
         callbacks(callbacks), userData(userData) {
-    callbacks.construct(userData);
+    if (callbacks.construct)
+      callbacks.construct(userData);
   }
 
-  ~ExternalPass() override { callbacks.destruct(userData); }
+  ~ExternalPass() override {
+    if (callbacks.destruct)
+      callbacks.destruct(userData);
+  }
 
   StringRef getName() const override { return name; }
   StringRef getArgument() const override { return argument; }
@@ -180,7 +184,9 @@ class ExternalPass : public Pass {
   }
 
   std::unique_ptr<Pass> clonePass() const override {
-    void *clonedUserData = callbacks.clone(userData);
+    void *clonedUserData = nullptr; // Null when no clone callback is set.
+    if (callbacks.clone)
+      clonedUserData = callbacks.clone(userData);
     return std::make_unique<ExternalPass>(id, name, argument, description,
                                           getOpName(), dependentDialects,
                                           callbacks, clonedUserData);
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index a4df97f7beace..6f85357a14a18 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -294,6 +294,13 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op,
   return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
 }
 
+MlirLogicalResult
+mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
+                                       MlirFrozenRewritePatternSet patterns,
+                                       MlirGreedyRewriteDriverConfig) {
+  return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
+}
+
 //===----------------------------------------------------------------------===//
 /// PDLPatternModule API
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/pass.py b/mlir/test/python/pass.py
new file mode 100644
index 0000000000000..e06fc77d9708f
--- /dev/null
+++ b/mlir/test/python/pass.py
@@ -0,0 +1,81 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import gc, sys
+from mlir.ir import *
+from mlir.passmanager import *
+from mlir.dialects.builtin import ModuleOp
+from mlir.dialects import pdl
+from mlir.rewrite import *
+
+
+def run(f):
+    # Note, everything in this file is dumped to stderr because that's where
+    # `IR Dump After` dumps too (so we can't cross the "streams")
+    print("\nTEST:", f.__name__, file=sys.stderr)
+    f()
+    gc.collect()
+
+
+def make_pdl_module():
+    with Location.unknown():
+        pdl_module = Module.create()
+        with InsertionPoint(pdl_module.body):
+            # Change all arith.addi on signless i64 values to arith.muli.
+            @pdl.pattern(benefit=1, sym_name="addi_to_mul")
+            def pat():
+                # Match arith.addi with i64 operands and an i64 result.
+                i64_type = pdl.TypeOp(IntegerType.get_signless(64))
+                operand0 = pdl.OperandOp(i64_type)
+                operand1 = pdl.OperandOp(i64_type)
+                op0 = pdl.OperationOp(
+                    name="arith.addi", args=[operand0, operand1], types=[i64_type]
+                )
+
+                # Replace the matched op with arith.muli.
+                @pdl.rewrite()
+                def rew():
+                    newOp = pdl.OperationOp(
+                        name="arith.muli", args=[operand0, operand1], types=[i64_type]
+                    )
+                    pdl.ReplaceOp(op0, with_op=newOp)
+
+        return pdl_module
+
+
+# CHECK-LABEL: TEST: testCustomPass
+@run
+def testCustomPass():
+    with Context():
+        pdl_module = make_pdl_module()
+        frozen = PDLModule(pdl_module).freeze()
+
+        module = ModuleOp.parse(
+            r"""
+            module {
+              func.func @add(%a: i64, %b: i64) -> i64 {
+                %sum = arith.addi %a, %b : i64
+                return %sum : i64
+              }
+            }
+        """
+        )
+
+        def custom_pass_1(op):
+            print("hello from pass 1!!!", file=sys.stderr)
+
+        def custom_pass_2(op):
+            apply_patterns_and_fold_greedily_with_op(op, frozen)
+
+        pm = PassManager("any")
+        pm.enable_ir_printing()
+
+        # CHECK: hello from pass 1!!!
+        # CHECK-LABEL: Dump After custom_pass_1
+        # CHECK-LABEL: Dump After CustomPass2
+        # CHECK: arith.muli
+        pm.add_python_pass(custom_pass_1)
+        pm.add_python_pass(custom_pass_2, "CustomPass2")
+        # # CHECK-LABEL: Dump After ArithToLLVMConversionPass
+        # # CHECK: llvm.mul
+        pm.add("convert-arith-to-llvm")
+        pm.run(module)



More information about the Mlir-commits mailing list