[Mlir-commits] [mlir] [MLIR][Python] Support Python-defined rewrite patterns (PR #162699)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 9 09:54:02 PDT 2025
https://github.com/PragmaTwice created https://github.com/llvm/llvm-project/pull/162699
This PR adds support for defining custom **`RewritePattern`** implementations directly in the Python bindings.
Previously, users could define similar patterns through the PDL dialect’s bindings. For more complex patterns, however, that approach often requires writing several Python callbacks as PDL native constraints or rewrite functions, which makes the overall logic less intuitive, even though the PDL-based approach can be more performant than a pure-Python implementation.
With this change, we introduce an additional, more straightforward way to define patterns purely in Python, complementing the existing PDL-based approach.
### Example
```python
def to_muli(op, rewriter, pattern):
    with rewriter.ip:
        new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location)
    rewriter.replace_op(op, new_op.owner)
with Context():
    patterns = RewritePatternSet()
    patterns.add(arith.AddIOp, to_muli) # a pattern that rewrites arith.addi to arith.muli
    frozen = patterns.freeze()
    module = ...
    apply_patterns_and_fold_greedily(module, frozen)
```
>From 4689bc244c266bdabb7c8416ff06f9face681c45 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Oct 2025 00:03:10 +0800
Subject: [PATCH] [MLIR][Python] Support Python-defined rewrite patterns
---
mlir/include/mlir-c/Rewrite.h | 33 +++++++++++
mlir/lib/Bindings/Python/Rewrite.cpp | 81 +++++++++++++++++++++++++-
mlir/lib/CAPI/Transforms/Rewrite.cpp | 87 +++++++++++++++++++++++++++-
mlir/test/python/rewrite.py | 49 ++++++++++++++++
4 files changed, 245 insertions(+), 5 deletions(-)
create mode 100644 mlir/test/python/rewrite.py
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 5dd285ee076c4..68bb112404170 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -38,6 +38,7 @@ DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
DEFINE_C_API_STRUCT(MlirPatternRewriter, void);
+DEFINE_C_API_STRUCT(MlirRewritePattern, const void);
//===----------------------------------------------------------------------===//
/// RewriterBase API inherited from OpBuilder
@@ -324,6 +325,38 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
MLIR_CAPI_EXPORTED MlirRewriterBase
mlirPatternRewriterAsBase(MlirPatternRewriter rewriter);
+//===----------------------------------------------------------------------===//
+/// RewritePattern API
+//===----------------------------------------------------------------------===//
+
+/// Benefit value of a rewrite pattern; `mlirOpRewritePattenCreate` converts it
+/// to a C++ `PatternBenefit`.
+typedef unsigned short MlirPatternBenefit;
+
+/// Callbacks that implement a rewrite pattern outside of C++. Each callback
+/// receives the `userData` pointer that was supplied when the pattern was
+/// created.
+typedef struct {
+  /// Called when the pattern object is constructed; skipped when null.
+  void (*construct)(void *userData);
+  /// Called when the pattern object is destroyed; skipped when null.
+  void (*destruct)(void *userData);
+  /// Called to match and rewrite `op` using `rewriter`; `pattern` is the
+  /// pattern being applied. The implementation invokes this callback without
+  /// a null check, so it must be set.
+  MlirLogicalResult (*matchAndRewrite)(MlirRewritePattern pattern,
+                                       MlirOperation op,
+                                       MlirPatternRewriter rewriter,
+                                       void *userData);
+} MlirRewritePatternCallbacks;
+
+/// Creates a heap-allocated rewrite pattern rooted at the operation named
+/// `rootName` whose behavior is supplied by `callbacks`. `userData` is stored
+/// in the pattern and passed to every callback; `callbacks.construct`, when
+/// set, runs during creation. `generatedNames` points to `nGeneratedNames`
+/// names that are forwarded to the `RewritePattern` constructor; it is not
+/// dereferenced when the count is zero.
+/// Ownership of the returned pattern is taken by `mlirRewritePatternSetAdd`.
+/// NOTE(review): the symbol is spelled "Patten" rather than "Pattern" --
+/// confirm the spelling before it becomes part of the public C API.
+MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate(
+    MlirStringRef rootName, MlirPatternBenefit benefit, MlirContext context,
+    MlirRewritePatternCallbacks callbacks, void *userData,
+    size_t nGeneratedNames, MlirStringRef *generatedNames);
+
+//===----------------------------------------------------------------------===//
+/// RewritePatternSet API
+//===----------------------------------------------------------------------===//
+
+/// Creates a new, heap-allocated rewrite pattern set for `context`. Release
+/// it with `mlirRewritePatternSetDestroy`.
+MLIR_CAPI_EXPORTED MlirRewritePatternSet
+mlirRewritePatternSetCreate(MlirContext context);
+
+/// Deletes a pattern set created by `mlirRewritePatternSetCreate`.
+MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set);
+
+/// Adds `pattern` to `set`. The set takes ownership of the pattern, so the
+/// handle must not be added a second time.
+MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
+                                                 MlirRewritePattern pattern);
+
//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 9e3d9703c82e8..3740c59e62001 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -45,6 +45,14 @@ class PyPatternRewriter {
return PyInsertionPoint(PyOperation::forOperation(ctx, op));
}
+  /// Replaces `op` with the operation `newOp` by forwarding to
+  /// `mlirRewriterBaseReplaceOpWithOperation` on the wrapped rewriter.
+  void replaceOp(MlirOperation op, MlirOperation newOp) {
+    mlirRewriterBaseReplaceOpWithOperation(base, op, newOp);
+  }
+
+  /// Replaces `op` with `values` by forwarding to
+  /// `mlirRewriterBaseReplaceOpWithValues`; the vector supplies both the
+  /// value count and the value array.
+  void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) {
+    mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
+  }
+
private:
MlirRewriterBase base;
PyMlirContextRef ctx;
@@ -165,13 +173,82 @@ class PyFrozenRewritePatternSet {
MlirFrozenRewritePatternSet set;
};
+/// Owning wrapper around an `MlirRewritePatternSet` that collects rewrite
+/// patterns implemented as Python callables.
+///
+/// The underlying set is created in the constructor and destroyed in the
+/// destructor, so the wrapper must have exactly one owner. Copying is
+/// disabled: a copy would hold the same raw handle and destroy it twice.
+class PyRewritePatternSet {
+public:
+  PyRewritePatternSet(MlirContext ctx)
+      : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
+  ~PyRewritePatternSet() { mlirRewritePatternSetDestroy(set); }
+
+  PyRewritePatternSet(const PyRewritePatternSet &) = delete;
+  PyRewritePatternSet &operator=(const PyRewritePatternSet &) = delete;
+
+  /// Registers `matchAndRewrite` as a pattern rooted at the operation named
+  /// `rootName` with the given `benefit`. The callable is invoked as
+  /// `matchAndRewrite(op, rewriter, pattern)` and its result is converted with
+  /// `logicalResultFromObject`.
+  void add(MlirStringRef rootName, MlirPatternBenefit benefit,
+           const nb::callable &matchAndRewrite) {
+    MlirRewritePatternCallbacks callbacks;
+    // The pattern stores only the raw PyObject pointer; hold one reference
+    // for as long as the pattern object exists.
+    callbacks.construct = [](void *userData) {
+      nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+    };
+    callbacks.destruct = [](void *userData) {
+      nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+    };
+    callbacks.matchAndRewrite = [](MlirRewritePattern pattern, MlirOperation op,
+                                   MlirPatternRewriter rewriter,
+                                   void *userData) -> MlirLogicalResult {
+      nb::handle f(static_cast<PyObject *>(userData));
+      nb::object res = f(op, PyPatternRewriter(rewriter), pattern);
+      return logicalResultFromObject(res);
+    };
+    MlirRewritePattern pattern = mlirOpRewritePattenCreate(
+        rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
+        /* nGeneratedNames */ 0,
+        /* generatedNames */ nullptr);
+    // Ownership of `pattern` passes to the set here.
+    mlirRewritePatternSetAdd(set, pattern);
+  }
+
+  /// Freezes the collected patterns. `mlirFreezeRewritePattern` moves the
+  /// patterns out of `set`; the emptied set object itself is still released
+  /// by the destructor.
+  PyFrozenRewritePatternSet freeze() { return mlirFreezeRewritePattern(set); }
+
+private:
+  MlirRewritePatternSet set;
+  MlirContext ctx;
+};
+
} // namespace
/// Create the `mlir.rewrite` here.
void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
+ //----------------------------------------------------------------------------
+ // Mapping of the PatternRewriter
+ //----------------------------------------------------------------------------
nb::class_<PyPatternRewriter>(m, "PatternRewriter")
.def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
- "The current insertion point of the PatternRewriter.");
+ "The current insertion point of the PatternRewriter.")
+ .def("replace_op", [](PyPatternRewriter &self, MlirOperation op,
+ MlirOperation newOp) { self.replaceOp(op, newOp); })
+ .def("replace_op", [](PyPatternRewriter &self, MlirOperation op,
+ const std::vector<MlirValue> &values) {
+ self.replaceOp(op, values);
+ });
+
+ //----------------------------------------------------------------------------
+ // Mapping of the RewritePatternSet
+ //----------------------------------------------------------------------------
+ nb::class_<MlirRewritePattern>(m, "RewritePattern");
+ nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
+ .def(
+ "__init__",
+ [](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
+ new (&self) PyRewritePatternSet(context.get()->get());
+ },
+ "context"_a = nb::none())
+ .def(
+ "add",
+ [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
+ unsigned benefit) {
+ std::string opName =
+ nb::cast<std::string>(root.attr("OPERATION_NAME"));
+ self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
+ fn);
+ },
+ "root"_a, "fn"_a, "benefit"_a = 1)
+ .def("freeze", &PyRewritePatternSet::freeze);
+
//----------------------------------------------------------------------------
// Mapping of the PDLResultList and PDLModule
//----------------------------------------------------------------------------
@@ -237,7 +314,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
.def(
"freeze",
[](PyPDLPatternModule &self) {
- return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
+ return PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
mlirRewritePatternSetFromPDLPatternModule(self.get())));
},
nb::keep_alive<0, 1>())
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index c15a73b991f5d..f3430e2e78978 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/PDLPatternMatch.h.inc"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
@@ -270,9 +271,9 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) {
/// RewritePatternSet and FrozenRewritePatternSet API
//===----------------------------------------------------------------------===//
-static inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
+static inline mlir::RewritePatternSet *unwrap(MlirRewritePatternSet module) {
assert(module.ptr && "unexpected null module");
- return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
+ return static_cast<mlir::RewritePatternSet *>(module.ptr);
}
static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
@@ -291,7 +292,7 @@ wrap(mlir::FrozenRewritePatternSet *module) {
}
MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
- auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op)));
+ auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(op)));
op.ptr = nullptr;
return wrap(m);
}
@@ -332,6 +333,86 @@ MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
}
+//===----------------------------------------------------------------------===//
+/// RewritePattern API
+//===----------------------------------------------------------------------===//
+
+/// Recovers the C++ pattern pointer from its C handle; asserts that the
+/// handle is not null. File-local, like the other wrap/unwrap helpers in this
+/// translation unit.
+static inline const mlir::RewritePattern *unwrap(MlirRewritePattern pattern) {
+  assert(pattern.ptr && "unexpected null pattern");
+  return static_cast<const mlir::RewritePattern *>(pattern.ptr);
+}
+
+/// Wraps a C++ pattern pointer into its C handle.
+static inline MlirRewritePattern wrap(const mlir::RewritePattern *pattern) {
+  return {pattern};
+}
+
+namespace mlir {
+
+/// A `RewritePattern` whose behavior is supplied through C callbacks.
+///
+/// `callbacks.construct` and `callbacks.destruct` (when set) bracket the
+/// lifetime of the pattern so the client can manage `userData`. Copying is
+/// disabled because a copy would run `destruct` a second time without a
+/// matching `construct`.
+class ExternalRewritePattern final : public mlir::RewritePattern {
+public:
+  ExternalRewritePattern(MlirRewritePatternCallbacks callbacks, void *userData,
+                         StringRef rootName, PatternBenefit benefit,
+                         MLIRContext *context,
+                         ArrayRef<StringRef> generatedNames)
+      : RewritePattern(rootName, benefit, context, generatedNames),
+        callbacks(callbacks), userData(userData) {
+    // Unlike construct/destruct, matchAndRewrite is invoked unconditionally.
+    assert(callbacks.matchAndRewrite && "matchAndRewrite callback is required");
+    if (callbacks.construct)
+      callbacks.construct(userData);
+  }
+
+  ExternalRewritePattern(const ExternalRewritePattern &) = delete;
+  ExternalRewritePattern &operator=(const ExternalRewritePattern &) = delete;
+
+  ~ExternalRewritePattern() {
+    if (callbacks.destruct)
+      callbacks.destruct(userData);
+  }
+
+  /// Forwards to the `matchAndRewrite` callback, passing this pattern, the
+  /// operation, the rewriter and `userData`, and returns the unwrapped result.
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    return unwrap(callbacks.matchAndRewrite(
+        wrap(static_cast<const mlir::RewritePattern *>(this)), wrap(op),
+        wrap(&rewriter), userData));
+  }
+
+private:
+  MlirRewritePatternCallbacks callbacks;
+  void *userData;
+};
+
+} // namespace mlir
+
+/// Builds an `ExternalRewritePattern` from the C-level arguments and returns
+/// it as an opaque handle. The pattern is heap-allocated.
+MlirRewritePattern mlirOpRewritePattenCreate(
+    MlirStringRef rootName, MlirPatternBenefit benefit, MlirContext context,
+    MlirRewritePatternCallbacks callbacks, void *userData,
+    size_t nGeneratedNames, MlirStringRef *generatedNames) {
+  // Convert the C string refs into StringRefs for the pattern constructor.
+  std::vector<mlir::StringRef> names;
+  names.reserve(nGeneratedNames);
+  for (size_t idx = 0; idx != nGeneratedNames; ++idx)
+    names.push_back(unwrap(generatedNames[idx]));
+
+  auto *created = new mlir::ExternalRewritePattern(
+      callbacks, userData, unwrap(rootName), PatternBenefit(benefit),
+      unwrap(context), names);
+  return wrap(created);
+}
+
+//===----------------------------------------------------------------------===//
+/// RewritePatternSet API
+//===----------------------------------------------------------------------===//
+
+/// Allocates a new `RewritePatternSet` for `context` and returns it as an
+/// opaque handle.
+MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context) {
+  return wrap(new mlir::RewritePatternSet(unwrap(context)));
+}
+
+/// Deletes the `RewritePatternSet` behind `set`; `unwrap` asserts that the
+/// handle is not null.
+void mlirRewritePatternSetDestroy(MlirRewritePatternSet set) {
+  delete unwrap(set);
+}
+
+/// Adds `pattern` to `set`, transferring ownership of the pattern object to
+/// the set. The handle is passed by value, so the caller's copy cannot be
+/// cleared here; the caller must not add or otherwise reuse it afterwards.
+void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
+                              MlirRewritePattern pattern) {
+  // The handle is const-qualified, but the set stores mutable patterns.
+  std::unique_ptr<mlir::RewritePattern> patternPtr(
+      const_cast<mlir::RewritePattern *>(unwrap(pattern)));
+  unwrap(set)->add(std::move(patternPtr));
+}
+
//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
new file mode 100644
index 0000000000000..6aed936f94d87
--- /dev/null
+++ b/mlir/test/python/rewrite.py
@@ -0,0 +1,49 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import gc, sys
+from mlir.ir import *
+from mlir.passmanager import *
+from mlir.dialects.builtin import ModuleOp
+from mlir.dialects import arith
+from mlir.rewrite import *
+
+
+def log(*args):
+    """Print ``args`` to stderr and flush the stream immediately."""
+    print(*args, file=sys.stderr)
+    sys.stderr.flush()
+
+
+def run(f):
+    """Run test function ``f`` immediately when used as a decorator.
+
+    Logs the test name, calls ``f``, forces a garbage collection and then
+    asserts that no ``Context`` is still alive.
+    """
+    log("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
+
+# CHECK-LABEL: TEST: testRewritePattern
+@run
+def testRewritePattern():
+    # Pattern callback: build an arith.muli from the matched op's first two
+    # operands at the rewriter's insertion point, then replace the op with it.
+    def to_muli(op, rewriter, pattern):
+        with rewriter.ip:
+            new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location)
+        rewriter.replace_op(op, new_op.owner)
+
+    with Context():
+        patterns = RewritePatternSet()
+        # Register the callback with arith.AddIOp as the pattern root.
+        patterns.add(arith.AddIOp, to_muli)
+        frozen = patterns.freeze()
+
+        module = ModuleOp.parse(
+            r"""
+            module {
+              func.func @add(%a: i64, %b: i64) -> i64 {
+                %sum = arith.addi %a, %b : i64
+                return %sum : i64
+              }
+            }
+            """
+        )
+
+        apply_patterns_and_fold_greedily(module, frozen)
+        # CHECK: %0 = arith.muli %arg0, %arg1 : i64
+        # CHECK: return %0 : i64
+        print(module)
More information about the Mlir-commits
mailing list