[Mlir-commits] [mlir] [MLIR][Python] Support Python-defined rewrite patterns (PR #162699)
Maksim Levental
llvmlistbot at llvm.org
Fri Oct 10 10:58:39 PDT 2025
================
@@ -165,13 +175,116 @@ class PyFrozenRewritePatternSet {
MlirFrozenRewritePatternSet set;
};
+class PyRewritePatternSet {
+public:
+ PyRewritePatternSet(MlirContext ctx)
+ : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
+ ~PyRewritePatternSet() {
+ if (set.ptr)
+ mlirRewritePatternSetDestroy(set);
+ }
+
+ void add(MlirStringRef rootName, unsigned benefit,
+ const nb::callable &matchAndRewrite) {
+ MlirRewritePatternCallbacks callbacks;
+ callbacks.construct = [](void *userData) {
+ nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+ };
+ callbacks.destruct = [](void *userData) {
+ nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+ };
+ callbacks.matchAndRewrite = [](MlirRewritePattern pattern, MlirOperation op,
+ MlirPatternRewriter rewriter,
+ void *userData) -> MlirLogicalResult {
+ nb::handle f(static_cast<PyObject *>(userData));
+ nb::object res = f(op, PyPatternRewriter(rewriter), pattern);
----------------
makslevental wrote:
So is there value in passing `pattern` to the callback? Even in C++ I've rarely seen any of the `Pattern` members called (e.g. `getGeneratedOps`).
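
For concreteness, here is a rough sketch of what dropping the argument could look like. It assumes stand-in types: `Op`, `Rewriter`, `Pattern`, `Callbacks`, `UserFn`, `trampoline`, and `runOnce` are all hypothetical names for illustration, not the MLIR C API or the bindings in this PR. The C-level callback still receives the pattern handle, but the trampoline does not forward it to the user's callable.

```cpp
#include <cassert>
#include <functional>
#include <string>

// Hypothetical stand-ins for illustration only; not the MLIR C API types.
struct Op { std::string name; };
struct Rewriter { int rewrites = 0; };
struct Pattern { unsigned benefit; };

// C-style callback table with an opaque userData pointer, loosely mirroring
// the shape of the callbacks struct in the diff.
struct Callbacks {
  bool (*matchAndRewrite)(const Pattern *pattern, Op *op, Rewriter *rewriter,
                          void *userData);
};

// User-facing callable: sees only the op and the rewriter.
using UserFn = std::function<bool(Op &, Rewriter &)>;

// Trampoline: the pattern handle arrives from the C layer but is dropped here.
inline bool trampoline(const Pattern * /*pattern*/, Op *op, Rewriter *rewriter,
                       void *userData) {
  assert(userData && "userData must point at a UserFn");
  auto *fn = static_cast<UserFn *>(userData);
  return (*fn)(*op, *rewriter);
}

// Tiny driver standing in for whatever applies the pattern to one op.
inline bool runOnce(const Callbacks &cb, const Pattern &pattern, Op &op,
                    Rewriter &rewriter, void *userData) {
  return cb.matchAndRewrite(&pattern, &op, &rewriter, userData);
}
```

If a user ever does need the pattern, it could be added back later as an optional trailing argument; that is a design guess, not something this PR commits to.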
https://github.com/llvm/llvm-project/pull/162699