[Mlir-commits] [mlir] [MLIR][Python][NO MERGE] Support Python-defined passes in MLIR (PR #157369)
Maksim Levental
llvmlistbot at llvm.org
Sun Sep 7 20:49:22 PDT 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/157369
>From dc93dba39c0fee79b172d3f0dba11f7a25784dd0 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Sun, 7 Sep 2025 19:10:29 -0400
Subject: [PATCH] [MLIR][Python] Support Python-defined passes in MLIR
based heavily on https://github.com/llvm/llvm-project/pull/156000
---
mlir/include/mlir-c/Rewrite.h | 4 ++
mlir/lib/Bindings/Python/MainModule.cpp | 4 +-
mlir/lib/Bindings/Python/Pass.cpp | 43 +++++++++++++
mlir/lib/Bindings/Python/Rewrite.cpp | 34 ++++++++---
mlir/lib/CAPI/IR/Pass.cpp | 12 +++-
mlir/lib/CAPI/Transforms/Rewrite.cpp | 7 +++
mlir/test/python/pass.py | 81 +++++++++++++++++++++++++
7 files changed, 170 insertions(+), 15 deletions(-)
create mode 100644 mlir/test/python/pass.py
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 61d3446317550..374d2fb78de88 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -301,6 +301,10 @@ mlirFreezeRewritePattern(MlirRewritePatternSet op);
MLIR_CAPI_EXPORTED void
mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(
+ MlirOperation op, MlirFrozenRewritePatternSet patterns,
+ MlirGreedyRewriteDriverConfig);
+
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
MlirModule op, MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig);
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 278847e7ac7f5..d7282b3d6f713 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -136,7 +136,7 @@ NB_MODULE(_mlir, m) {
populateRewriteSubmodule(rewriteModule);
// Define and populate PassManager submodule.
- auto passModule =
+ auto passManagerModule =
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
- populatePassManagerSubmodule(passModule);
+ populatePassManagerSubmodule(passManagerModule);
}
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 88e28dca76bb9..a0d11d1a6276c 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -10,8 +10,10 @@
#include "IRModule.h"
#include "mlir-c/Pass.h"
+// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+// clang-format on
namespace nb = nanobind;
using namespace nb::literals;
@@ -157,6 +159,47 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
"pipeline"_a,
"Add textual pipeline elements to the pass manager. Throws a "
"ValueError if the pipeline can't be parsed.")
+ .def(
+ "add_python_pass",
+ [](PyPassManager &passManager, const nb::callable &run,
+ std::optional<std::string> &name, const std::string &argument,
+ const std::string &description, const std::string &opName) {
+ if (!name.has_value()) {
+ name = nb::cast<std::string>(
+ nb::borrow<nb::str>(run.attr("__name__")));
+ }
+ // NOTE(review): typeIDAllocator is never destroyed; confirm intentional.
+ MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
+ MlirTypeID passID =
+ mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
+ MlirExternalPassCallbacks callbacks;
+ callbacks.construct = [](void *obj) {
+ (void)nb::handle(static_cast<PyObject *>(obj)).inc_ref();
+ };
+ callbacks.destruct = [](void *obj) {
+ (void)nb::handle(static_cast<PyObject *>(obj)).dec_ref();
+ };
+ callbacks.clone = [](void *) -> void * {
+ throw std::runtime_error("Cloning Python passes not supported");
+ };
+ // construct/destruct own the reference; run must only borrow it.
+ callbacks.run = [](MlirOperation op, MlirExternalPass,
+ void *userData) {
+ nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op);
+ };
+ callbacks.initialize = nullptr;
+ auto externalPass = mlirCreateExternalPass(
+ passID, mlirStringRefCreate(name->data(), name->length()),
+ mlirStringRefCreate(argument.data(), argument.length()),
+ mlirStringRefCreate(description.data(), description.length()),
+ mlirStringRefCreate(opName.data(), opName.size()),
+ /*nDependentDialects*/ 0, /*dependentDialects*/ nullptr,
+ callbacks, /*userData*/ run.ptr());
+ mlirPassManagerAddOwnedPass(passManager.get(), externalPass);
+ },
+ "run"_a, "name"_a.none() = nb::none(), "argument"_a.none() = "",
+ "description"_a.none() = "", "op_name"_a.none() = "",
+ "Add a python-defined pass to the pass manager.")
.def(
"run",
[](PyPassManager &passManager, PyOperationBase &op) {
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 0373f9c7affe9..191a7cca53d93 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -10,8 +10,10 @@
#include "IRModule.h"
#include "mlir-c/Rewrite.h"
+// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+// clang-format on
#include "mlir/Config/mlir-config.h"
namespace nb = nanobind;
@@ -99,14 +101,26 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
&PyFrozenRewritePatternSet::createFromCapsule);
m.def(
- "apply_patterns_and_fold_greedily",
- [](MlirModule module, MlirFrozenRewritePatternSet set) {
- auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
- if (mlirLogicalResultIsFailure(status))
- // FIXME: Not sure this is the right error to throw here.
- throw nb::value_error("pattern application failed to converge");
- },
- "module"_a, "set"_a,
- "Applys the given patterns to the given module greedily while folding "
- "results.");
+ "apply_patterns_and_fold_greedily",
+ [](MlirModule module, MlirFrozenRewritePatternSet set) {
+ auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
+ if (mlirLogicalResultIsFailure(status))
+ // FIXME: Not sure this is the right error to throw here.
+ throw nb::value_error("pattern application failed to converge");
+ },
+ "module"_a, "set"_a,
+ "Applys the given patterns to the given module greedily while folding "
+ "results.")
+ .def(
+ "apply_patterns_and_fold_greedily_with_op",
+ [](MlirOperation op, MlirFrozenRewritePatternSet set) {
+ auto status = mlirApplyPatternsAndFoldGreedilyWithOp(op, set, {});
+ if (mlirLogicalResultIsFailure(status)) {
+ throw std::runtime_error(
+ "pattern application failed to converge");
+ }
+ },
+ "op"_a, "set"_a,
+ "Applys the given patterns to the given op greedily while folding "
+ "results.");
}
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index 3c499c3e4974d..52d6a0ad75d26 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -145,10 +145,14 @@ class ExternalPass : public Pass {
: Pass(passID, opName), id(passID), name(name), argument(argument),
description(description), dependentDialects(dependentDialects),
callbacks(callbacks), userData(userData) {
- callbacks.construct(userData);
+ if (callbacks.construct)
+ callbacks.construct(userData);
}
- ~ExternalPass() override { callbacks.destruct(userData); }
+ ~ExternalPass() override {
+ if (callbacks.destruct)
+ callbacks.destruct(userData);
+ }
StringRef getName() const override { return name; }
StringRef getArgument() const override { return argument; }
@@ -180,7 +184,9 @@ class ExternalPass : public Pass {
}
std::unique_ptr<Pass> clonePass() const override {
- void *clonedUserData = callbacks.clone(userData);
+ // No clone hook: give the copy a null payload, never an indeterminate one.
+ void *clonedUserData =
+ callbacks.clone ? callbacks.clone(userData) : nullptr;
return std::make_unique<ExternalPass>(id, name, argument, description,
getOpName(), dependentDialects,
callbacks, clonedUserData);
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index a4df97f7beace..6f85357a14a18 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -294,6 +294,13 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op,
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
}
+MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(
+ MlirOperation op, MlirFrozenRewritePatternSet patterns,
+ MlirGreedyRewriteDriverConfig /*unused*/) {
+ auto *rootOp = unwrap(op);
+ return wrap(mlir::applyPatternsGreedily(rootOp, *unwrap(patterns)));
+}
+
//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/pass.py b/mlir/test/python/pass.py
new file mode 100644
index 0000000000000..e06fc77d9708f
--- /dev/null
+++ b/mlir/test/python/pass.py
@@ -0,0 +1,81 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import gc, sys
+from mlir.ir import *
+from mlir.passmanager import *
+from mlir.dialects.builtin import ModuleOp
+from mlir.dialects import pdl
+from mlir.rewrite import *
+
+
+def run(f):
+ # Note, everything in this file is dumped to stderr because that's where
+ # `IR Dump After` dumps too (so we can't cross the "streams")
+ print("\nTEST:", f.__name__, file=sys.stderr)
+ f()
+ gc.collect()
+
+
+def make_pdl_module():
+ with Location.unknown():
+ pdl_module = Module.create()
+ with InsertionPoint(pdl_module.body):
+ # Change all arith.addi on signless i64 values to arith.muli.
+ @pdl.pattern(benefit=1, sym_name="addi_to_mul")
+ def pat():
+ # Match arith.addi whose operands and result are signless i64.
+ i64_type = pdl.TypeOp(IntegerType.get_signless(64))
+ operand0 = pdl.OperandOp(i64_type)
+ operand1 = pdl.OperandOp(i64_type)
+ op0 = pdl.OperationOp(
+ name="arith.addi", args=[operand0, operand1], types=[i64_type]
+ )
+
+ # Replace the matched op with arith.muli.
+ @pdl.rewrite()
+ def rew():
+ newOp = pdl.OperationOp(
+ name="arith.muli", args=[operand0, operand1], types=[i64_type]
+ )
+ pdl.ReplaceOp(op0, with_op=newOp)
+
+ return pdl_module
+
+
+# CHECK-LABEL: TEST: testCustomPass
+@run
+def testCustomPass():
+ with Context():
+ pdl_module = make_pdl_module()
+ frozen = PDLModule(pdl_module).freeze()
+
+ module = ModuleOp.parse(
+ r"""
+ module {
+ func.func @add(%a: i64, %b: i64) -> i64 {
+ %sum = arith.addi %a, %b : i64
+ return %sum : i64
+ }
+ }
+ """
+ )
+
+ def custom_pass_1(op):
+ print("hello from pass 1!!!", file=sys.stderr)
+
+ def custom_pass_2(op):
+ apply_patterns_and_fold_greedily_with_op(op, frozen)
+
+ pm = PassManager("any")
+ pm.enable_ir_printing()
+
+ # CHECK: hello from pass 1!!!
+ # CHECK-LABEL: Dump After custom_pass_1
+ # CHECK-LABEL: Dump After CustomPass2
+ # CHECK: arith.muli
+ pm.add_python_pass(custom_pass_1)
+ pm.add_python_pass(custom_pass_2, "CustomPass2")
+ # # CHECK-LABEL: Dump After ArithToLLVMConversionPass
+ # # CHECK: llvm.mul
+ pm.add("convert-arith-to-llvm")
+ pm.run(module)
More information about the Mlir-commits
mailing list