[Mlir-commits] [mlir] [MLIR][Python] Support Python-defined rewrite patterns (PR #162699)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 9 19:36:01 PDT 2025
https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/162699
From 4689bc244c266bdabb7c8416ff06f9face681c45 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Oct 2025 00:03:10 +0800
Subject: [PATCH 1/3] [MLIR][Python] Support Python-defined rewrite patterns
---
mlir/include/mlir-c/Rewrite.h | 33 +++++++++++
mlir/lib/Bindings/Python/Rewrite.cpp | 81 +++++++++++++++++++++++++-
mlir/lib/CAPI/Transforms/Rewrite.cpp | 87 +++++++++++++++++++++++++++-
mlir/test/python/rewrite.py | 49 ++++++++++++++++
4 files changed, 245 insertions(+), 5 deletions(-)
create mode 100644 mlir/test/python/rewrite.py
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 5dd285ee076c4..68bb112404170 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -38,6 +38,7 @@ DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
DEFINE_C_API_STRUCT(MlirPatternRewriter, void);
+/// Handle to a (const) mlir::RewritePattern. A pattern returned by
+/// mlirOpRewritePattenCreate is owned by the caller until it is handed to
+/// mlirRewritePatternSetAdd, which takes ownership of it.
+DEFINE_C_API_STRUCT(MlirRewritePattern, const void);
//===----------------------------------------------------------------------===//
/// RewriterBase API inherited from OpBuilder
@@ -324,6 +325,38 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
MLIR_CAPI_EXPORTED MlirRewriterBase
mlirPatternRewriterAsBase(MlirPatternRewriter rewriter);
+//===----------------------------------------------------------------------===//
+/// RewritePattern API
+//===----------------------------------------------------------------------===//
+
+/// Benefit of a pattern match; converted to mlir::PatternBenefit when the
+/// pattern is created.
+typedef unsigned short MlirPatternBenefit;
+
+/// Callbacks implementing a rewrite pattern defined outside of C++. Each
+/// callback receives the `userData` pointer given at pattern creation.
+typedef struct {
+ /// Optional (may be NULL). Invoked once with `userData` when the pattern is
+ /// created.
+ void (*construct)(void *userData);
+ /// Optional (may be NULL). Invoked once with `userData` when the pattern is
+ /// destroyed.
+ void (*destruct)(void *userData);
+ /// Required (always invoked; must not be NULL). Matches `op` and performs
+ /// the rewrite through `rewriter`. `pattern` is the pattern being applied;
+ /// the callback does not own it.
+ MlirLogicalResult (*matchAndRewrite)(MlirRewritePattern pattern,
+ MlirOperation op,
+ MlirPatternRewriter rewriter,
+ void *userData);
+} MlirRewritePatternCallbacks;
+
+/// Creates a pattern rooted at operations named `rootName` whose behavior is
+/// supplied by `callbacks` and `userData`. `generatedNames` may be NULL when
+/// `nGeneratedNames` is 0; the names are forwarded to the underlying
+/// mlir::RewritePattern constructor. The result is heap-allocated and owned by
+/// the caller until it is passed to mlirRewritePatternSetAdd.
+MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate(
+ MlirStringRef rootName, MlirPatternBenefit benefit, MlirContext context,
+ MlirRewritePatternCallbacks callbacks, void *userData,
+ size_t nGeneratedNames, MlirStringRef *generatedNames);
+
+//===----------------------------------------------------------------------===//
+/// RewritePatternSet API
+//===----------------------------------------------------------------------===//
+
+/// Allocates an empty pattern set for `context`. Release it with
+/// mlirRewritePatternSetDestroy.
+MLIR_CAPI_EXPORTED MlirRewritePatternSet
+mlirRewritePatternSetCreate(MlirContext context);
+
+/// Deletes the pattern set, including any patterns it still owns.
+MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set);
+
+/// Adds `pattern` to `set`, transferring ownership of the pattern to the set.
+/// The caller must not use or add the `pattern` handle again afterwards.
+MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
+ MlirRewritePattern pattern);
+
//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 9e3d9703c82e8..3740c59e62001 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -45,6 +45,14 @@ class PyPatternRewriter {
return PyInsertionPoint(PyOperation::forOperation(ctx, op));
}
+  /// Replaces `op` with `newOp` by forwarding to
+  /// mlirRewriterBaseReplaceOpWithOperation on this rewriter.
+  void replaceOp(MlirOperation op, MlirOperation newOp) {
+    mlirRewriterBaseReplaceOpWithOperation(this->base, op, newOp);
+  }
+
+  /// Replaces `op` with `values` by forwarding them, as a (count, pointer)
+  /// pair, to mlirRewriterBaseReplaceOpWithValues.
+  void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) {
+    const auto count = values.size();
+    const MlirValue *first = values.data();
+    mlirRewriterBaseReplaceOpWithValues(this->base, op, count, first);
+  }
+
private:
MlirRewriterBase base;
PyMlirContextRef ctx;
@@ -165,13 +173,82 @@ class PyFrozenRewritePatternSet {
MlirFrozenRewritePatternSet set;
};
+/// Python-facing owner of an MlirRewritePatternSet whose patterns are
+/// implemented by Python callables.
+class PyRewritePatternSet {
+public:
+  explicit PyRewritePatternSet(MlirContext ctx)
+      : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
+  ~PyRewritePatternSet() { mlirRewritePatternSetDestroy(set); }
+
+  // `set` is owned exclusively; a copy would destroy it twice.
+  PyRewritePatternSet(const PyRewritePatternSet &) = delete;
+  PyRewritePatternSet &operator=(const PyRewritePatternSet &) = delete;
+
+  /// Adds a pattern rooted at operations named `rootName`. For each candidate
+  /// operation, `matchAndRewrite(op, rewriter, pattern)` is called and its
+  /// result converted with logicalResultFromObject. The callable's reference
+  /// count is incremented when the pattern is created and decremented when it
+  /// is destroyed, keeping it alive for the pattern's lifetime.
+  void add(MlirStringRef rootName, MlirPatternBenefit benefit,
+           const nb::callable &matchAndRewrite) {
+    MlirRewritePatternCallbacks callbacks;
+    callbacks.construct = [](void *userData) {
+      nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+    };
+    callbacks.destruct = [](void *userData) {
+      nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+    };
+    callbacks.matchAndRewrite = [](MlirRewritePattern pattern, MlirOperation op,
+                                   MlirPatternRewriter rewriter,
+                                   void *userData) -> MlirLogicalResult {
+      nb::handle f(static_cast<PyObject *>(userData));
+      // This runs inside the C++ pattern driver, reached through the C API.
+      // Do not let C++ exceptions unwind through those frames (the MLIR
+      // libraries may be built without exception support, in which case
+      // their cleanups are skipped). Report the error and fail the match.
+      try {
+        nb::object res = f(op, PyPatternRewriter(rewriter), pattern);
+        return logicalResultFromObject(res);
+      } catch (nb::python_error &e) {
+        e.discard_as_unraisable(f);
+        return mlirLogicalResultFailure();
+      } catch (const std::exception &e) {
+        PyErr_SetString(PyExc_RuntimeError, e.what());
+        PyErr_WriteUnraisable(f.ptr());
+        return mlirLogicalResultFailure();
+      }
+    };
+    MlirRewritePattern pattern = mlirOpRewritePattenCreate(
+        rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
+        /* nGeneratedNames */ 0,
+        /* generatedNames */ nullptr);
+    // Ownership of `pattern` passes to `set`.
+    mlirRewritePatternSetAdd(set, pattern);
+  }
+
+  /// Moves the patterns collected so far into a frozen pattern set. `set`
+  /// itself stays allocated and is still released by the destructor.
+  PyFrozenRewritePatternSet freeze() { return mlirFreezeRewritePattern(set); }
+
+private:
+  MlirRewritePatternSet set;
+  MlirContext ctx;
+};
+
} // namespace
/// Create the `mlir.rewrite` here.
void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
+  //----------------------------------------------------------------------------
+  // Mapping of the PatternRewriter
+  //----------------------------------------------------------------------------
nb::class_<PyPatternRewriter>(m, "PatternRewriter")
.def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
- "The current insertion point of the PatternRewriter.");
+                   "The current insertion point of the PatternRewriter.")
+      .def(
+          "replace_op",
+          [](PyPatternRewriter &self, MlirOperation op, MlirOperation newOp) {
+            self.replaceOp(op, newOp);
+          },
+          "op"_a, "new_op"_a,
+          "Replace `op` with the operation `new_op` "
+          "(forwards to mlirRewriterBaseReplaceOpWithOperation).")
+      .def(
+          "replace_op",
+          [](PyPatternRewriter &self, MlirOperation op,
+             const std::vector<MlirValue> &values) {
+            self.replaceOp(op, values);
+          },
+          "op"_a, "values"_a,
+          "Replace `op` with the given list of values "
+          "(forwards to mlirRewriterBaseReplaceOpWithValues).");
+
+  //----------------------------------------------------------------------------
+  // Mapping of the RewritePatternSet
+  //----------------------------------------------------------------------------
+  nb::class_<MlirRewritePattern>(m, "RewritePattern");
+  nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
+      .def(
+          "__init__",
+          [](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
+            new (&self) PyRewritePatternSet(context.get()->get());
+          },
+          "context"_a = nb::none())
+      .def(
+          "add",
+          // `benefit` is received as the 16-bit MlirPatternBenefit so that
+          // out-of-range values are rejected at the call boundary instead of
+          // being silently truncated by an unsigned -> unsigned short
+          // conversion.
+          [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
+             MlirPatternBenefit benefit) {
+            std::string opName =
+                nb::cast<std::string>(root.attr("OPERATION_NAME"));
+            self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
+                     fn);
+          },
+          "root"_a, "fn"_a, "benefit"_a = 1,
+          "Add a pattern rooted at operations of `root` (an op class exposing "
+          "OPERATION_NAME). `fn(op, rewriter, pattern)` implements "
+          "match-and-rewrite; `benefit` must fit in 16 bits.")
+      .def("freeze", &PyRewritePatternSet::freeze,
+           "Move the collected patterns into a FrozenRewritePatternSet.");
+
//----------------------------------------------------------------------------
// Mapping of the PDLResultList and PDLModule
//----------------------------------------------------------------------------
@@ -237,7 +314,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
.def(
"freeze",
[](PyPDLPatternModule &self) {
- return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
+ return PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
mlirRewritePatternSetFromPDLPatternModule(self.get())));
},
nb::keep_alive<0, 1>())
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index c15a73b991f5d..f3430e2e78978 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/PDLPatternMatch.h.inc"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
@@ -270,9 +271,9 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) {
/// RewritePatternSet and FrozenRewritePatternSet API
//===----------------------------------------------------------------------===//
-static inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
+/// Returns the RewritePatternSet behind the C handle. It returns a pointer
+/// (rather than a reference) so that callers such as
+/// mlirRewritePatternSetDestroy can `delete` it directly.
+/// NOTE(review): the parameter is a pattern set, not a module; consider
+/// renaming it.
+static inline mlir::RewritePatternSet *unwrap(MlirRewritePatternSet module) {
assert(module.ptr && "unexpected null module");
- return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
+ return static_cast<mlir::RewritePatternSet *>(module.ptr);
}
static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
@@ -291,7 +292,7 @@ wrap(mlir::FrozenRewritePatternSet *module) {
}
MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
- auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(op)));
+ // Moves the patterns out of the set into a new frozen set. The
+ // RewritePatternSet object itself is not freed here.
+ // NOTE(review): `op` is a by-value copy, so the `op.ptr = nullptr` below
+ // does not clear the caller's handle.
+ auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(op)));
op.ptr = nullptr;
return wrap(m);
}
@@ -332,6 +333,86 @@ MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
}
+//===----------------------------------------------------------------------===//
+/// RewritePattern API
+//===----------------------------------------------------------------------===//
+
+/// Returns the RewritePattern behind the C handle; asserts it is non-null.
+/// `static` gives internal linkage, matching the other wrap/unwrap helpers in
+/// this file and avoiding ODR clashes with same-named helpers elsewhere.
+static inline const mlir::RewritePattern *unwrap(MlirRewritePattern pattern) {
+  assert(pattern.ptr && "unexpected null pattern");
+  return static_cast<const mlir::RewritePattern *>(pattern.ptr);
+}
+
+/// Wraps a RewritePattern pointer into its C handle (no ownership change).
+/// `static` gives internal linkage, matching the other wrap/unwrap helpers in
+/// this file.
+static inline MlirRewritePattern wrap(const mlir::RewritePattern *pattern) {
+  return {pattern};
+}
+
+namespace mlir {
+
+/// A RewritePattern whose matchAndRewrite is implemented by C callbacks.
+/// `userData` is passed to every callback. `construct` runs once in the
+/// constructor and `destruct` once in the destructor, so user data can track
+/// the pattern's lifetime (e.g. a reference count).
+class ExternalRewritePattern final : public mlir::RewritePattern {
+public:
+  ExternalRewritePattern(MlirRewritePatternCallbacks callbacks, void *userData,
+                         StringRef rootName, PatternBenefit benefit,
+                         MLIRContext *context,
+                         ArrayRef<StringRef> generatedNames)
+      : RewritePattern(rootName, benefit, context, generatedNames),
+        callbacks(callbacks), userData(userData) {
+    if (callbacks.construct)
+      callbacks.construct(userData);
+  }
+
+  // A copy would invoke `destruct` without a matching `construct`,
+  // unbalancing whatever the user data tracks.
+  ExternalRewritePattern(const ExternalRewritePattern &) = delete;
+  ExternalRewritePattern &operator=(const ExternalRewritePattern &) = delete;
+
+  ~ExternalRewritePattern() override {
+    if (callbacks.destruct)
+      callbacks.destruct(userData);
+  }
+
+  /// Delegates to the user's matchAndRewrite callback, passing this pattern's
+  /// handle, the operation, the rewriter and the user data.
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    return unwrap(callbacks.matchAndRewrite(
+        wrap(static_cast<const mlir::RewritePattern *>(this)), wrap(op),
+        wrap(&rewriter), userData));
+  }
+
+private:
+  MlirRewritePatternCallbacks callbacks;
+  void *userData;
+};
+
+} // namespace mlir
+
+/// Heap-allocates an ExternalRewritePattern driven by `callbacks`. The caller
+/// owns the result until it is passed to mlirRewritePatternSetAdd.
+MlirRewritePattern mlirOpRewritePattenCreate(
+    MlirStringRef rootName, MlirPatternBenefit benefit, MlirContext context,
+    MlirRewritePatternCallbacks callbacks, void *userData,
+    size_t nGeneratedNames, MlirStringRef *generatedNames) {
+  // Translate the C string references into StringRefs for the C++ ctor.
+  std::vector<mlir::StringRef> names(nGeneratedNames);
+  for (size_t idx = 0; idx != nGeneratedNames; ++idx)
+    names[idx] = unwrap(generatedNames[idx]);
+
+  auto *created = new mlir::ExternalRewritePattern(
+      callbacks, userData, unwrap(rootName), PatternBenefit(benefit),
+      unwrap(context), names);
+  return wrap(created);
+}
+
+//===----------------------------------------------------------------------===//
+/// RewritePatternSet API
+//===----------------------------------------------------------------------===//
+
+/// Allocates an empty RewritePatternSet bound to `context`; release it with
+/// mlirRewritePatternSetDestroy.
+MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context) {
+  mlir::MLIRContext *ctx = unwrap(context);
+  auto *patterns = new mlir::RewritePatternSet(ctx);
+  return wrap(patterns);
+}
+
+/// Deletes the RewritePatternSet along with the patterns it owns.
+void mlirRewritePatternSetDestroy(MlirRewritePatternSet set) {
+  mlir::RewritePatternSet *patterns = unwrap(set);
+  delete patterns;
+}
+
+/// Transfers ownership of `pattern` into `set`. `pattern` must be a handle
+/// returned by mlirOpRewritePattenCreate that has not been added before. The
+/// handle passed to a matchAndRewrite callback is not valid here: it refers
+/// to a pattern already owned by a set.
+void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
+                              MlirRewritePattern pattern) {
+  // The handle stores a const pointer, but the object was allocated non-const
+  // (`new` in mlirOpRewritePattenCreate), so casting away const to take
+  // ownership is well-defined. `pattern` is passed by value, so clearing its
+  // `ptr` here would not be visible to the caller; no such dead store is made.
+  std::unique_ptr<mlir::RewritePattern> owned(
+      const_cast<mlir::RewritePattern *>(unwrap(pattern)));
+  unwrap(set)->add(std::move(owned));
+}
+
//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
new file mode 100644
index 0000000000000..6aed936f94d87
--- /dev/null
+++ b/mlir/test/python/rewrite.py
@@ -0,0 +1,49 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import gc, sys
+from mlir.ir import *
+from mlir.passmanager import *
+from mlir.dialects.builtin import ModuleOp
+from mlir.dialects import arith
+from mlir.rewrite import *
+
+
+def log(*args):
+ print(*args, file=sys.stderr)
+ sys.stderr.flush()
+
+
+def run(f):
+ log("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+
+# CHECK-LABEL: TEST: testRewritePattern
+ at run
+def testRewritePattern():
+ def to_muli(op, rewriter, pattern):
+ with rewriter.ip:
+ new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location)
+ rewriter.replace_op(op, new_op.owner)
+
+ with Context():
+ patterns = RewritePatternSet()
+ patterns.add(arith.AddIOp, to_muli)
+ frozen = patterns.freeze()
+
+ module = ModuleOp.parse(
+ r"""
+ module {
+ func.func @add(%a: i64, %b: i64) -> i64 {
+ %sum = arith.addi %a, %b : i64
+ return %sum : i64
+ }
+ }
+ """
+ )
+
+ apply_patterns_and_fold_greedily(module, frozen)
+ # CHECK: %0 = arith.muli %arg0, %arg1 : i64
+ # CHECK: return %0 : i64
+ print(module)
From 61b87af618652e26f71400d7b238f1597a2ca364 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Oct 2025 00:56:46 +0800
Subject: [PATCH 2/3] format
---
mlir/test/python/rewrite.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index 6aed936f94d87..c7b6c1f19991e 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -19,6 +19,7 @@ def run(f):
gc.collect()
assert Context._get_live_count() == 0
+
# CHECK-LABEL: TEST: testRewritePattern
@run
def testRewritePattern():
From 395627f8987187ca8f45dcefe6c9167b69a3f7d8 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Oct 2025 10:35:15 +0800
Subject: [PATCH 3/3] add docs for C API
---
mlir/include/mlir-c/Rewrite.h | 14 ++++++++++++++
1 file changed, 14 insertions(+)
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 68bb112404170..cc021bcfba889 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -329,17 +329,28 @@ mlirPatternRewriterAsBase(MlirPatternRewriter rewriter);
/// RewritePattern API
//===----------------------------------------------------------------------===//
+/// Benefit of a pattern match; converted to mlir::PatternBenefit when the
+/// pattern is created.
typedef unsigned short MlirPatternBenefit;
+/// Callbacks implementing a rewrite pattern defined outside of C++. Every
+/// callback receives the `userData` pointer supplied at pattern creation.
typedef struct {
+ /// Optional; set to NULL to skip. Called once with the user data when the
+ /// pattern is created (e.g. to acquire a reference to it).
void (*construct)(void *userData);
+ /// Optional; set to NULL to skip. Called once with the user data when the
+ /// pattern is destroyed (e.g. to release it).
void (*destruct)(void *userData);
+ /// Required. Matches against code rooted at `op` and performs the rewrite
+ /// through `rewriter` if the match succeeds, corresponding to
+ /// RewritePattern::matchAndRewrite. `pattern` is the pattern being applied;
+ /// the callback does not own it.
MlirLogicalResult (*matchAndRewrite)(MlirRewritePattern pattern,
MlirOperation op,
MlirPatternRewriter rewriter,
void *userData);
} MlirRewritePatternCallbacks;
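+/// Create a rewrite pattern rooted at operations named `rootName`, analogous
+/// to an mlir::RewritePattern constructed with a root operation name. The
+/// returned pattern is owned by the caller until it is passed to
+/// mlirRewritePatternSetAdd.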
MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate(
MlirStringRef rootName, MlirPatternBenefit benefit, MlirContext context,
MlirRewritePatternCallbacks callbacks, void *userData,
@@ -349,11 +360,14 @@ MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate(
/// RewritePatternSet API
//===----------------------------------------------------------------------===//
+/// Create an empty MlirRewritePatternSet for `context`. Release it with
+/// mlirRewritePatternSetDestroy.
MLIR_CAPI_EXPORTED MlirRewritePatternSet
mlirRewritePatternSetCreate(MlirContext context);
+/// Destroy the given MlirRewritePatternSet, including any patterns it still
+/// owns.
MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set);
+/// Add the given MlirRewritePattern into a MlirRewritePatternSet. Ownership of
+/// `pattern` moves to `set`; do not reuse the handle afterwards.
MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
MlirRewritePattern pattern);
More information about the Mlir-commits
mailing list