[Mlir-commits] [mlir] [MLIR][Python] Support Python-defined rewrite patterns (PR #162699)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 9 22:09:33 PDT 2025
https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/162699
>From 4689bc244c266bdabb7c8416ff06f9face681c45 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Oct 2025 00:03:10 +0800
Subject: [PATCH 01/10] [MLIR][Python] Support Python-defined rewrite patterns
---
mlir/include/mlir-c/Rewrite.h | 33 +++++++++++
mlir/lib/Bindings/Python/Rewrite.cpp | 81 +++++++++++++++++++++++++-
mlir/lib/CAPI/Transforms/Rewrite.cpp | 87 +++++++++++++++++++++++++++-
mlir/test/python/rewrite.py | 49 ++++++++++++++++
4 files changed, 245 insertions(+), 5 deletions(-)
create mode 100644 mlir/test/python/rewrite.py
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 5dd285ee076c4..68bb112404170 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -38,6 +38,7 @@ DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
DEFINE_C_API_STRUCT(MlirPatternRewriter, void);
+DEFINE_C_API_STRUCT(MlirRewritePattern, const void);
//===----------------------------------------------------------------------===//
/// RewriterBase API inherited from OpBuilder
@@ -324,6 +325,38 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
MLIR_CAPI_EXPORTED MlirRewriterBase
mlirPatternRewriterAsBase(MlirPatternRewriter rewriter);
+//===----------------------------------------------------------------------===//
+/// RewritePattern API
+//===----------------------------------------------------------------------===//
+
+typedef unsigned short MlirPatternBenefit;
+
+typedef struct {
+ void (*construct)(void *userData);
+ void (*destruct)(void *userData);
+ MlirLogicalResult (*matchAndRewrite)(MlirRewritePattern pattern,
+ MlirOperation op,
+ MlirPatternRewriter rewriter,
+ void *userData);
+} MlirRewritePatternCallbacks;
+
+MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate(
+ MlirStringRef rootName, MlirPatternBenefit benefit, MlirContext context,
+ MlirRewritePatternCallbacks callbacks, void *userData,
+ size_t nGeneratedNames, MlirStringRef *generatedNames);
+
+//===----------------------------------------------------------------------===//
+/// RewritePatternSet API
+//===----------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED MlirRewritePatternSet
+mlirRewritePatternSetCreate(MlirContext context);
+
+MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set);
+
+MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
+ MlirRewritePattern pattern);
+
//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 9e3d9703c82e8..3740c59e62001 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -45,6 +45,14 @@ class PyPatternRewriter {
return PyInsertionPoint(PyOperation::forOperation(ctx, op));
}
+ void replaceOp(MlirOperation op, MlirOperation newOp) {
+ mlirRewriterBaseReplaceOpWithOperation(base, op, newOp);
+ }
+
+ void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) {
+ mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
+ }
+
private:
MlirRewriterBase base;
PyMlirContextRef ctx;
@@ -165,13 +173,82 @@ class PyFrozenRewritePatternSet {
MlirFrozenRewritePatternSet set;
};
+class PyRewritePatternSet {
+public:
+ PyRewritePatternSet(MlirContext ctx)
+ : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
+ ~PyRewritePatternSet() { mlirRewritePatternSetDestroy(set); }
+
+ void add(MlirStringRef rootName, MlirPatternBenefit benefit,
+ const nb::callable &matchAndRewrite) {
+ MlirRewritePatternCallbacks callbacks;
+ callbacks.construct = [](void *userData) {
+ nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+ };
+ callbacks.destruct = [](void *userData) {
+ nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+ };
+ callbacks.matchAndRewrite = [](MlirRewritePattern pattern, MlirOperation op,
+ MlirPatternRewriter rewriter,
+ void *userData) -> MlirLogicalResult {
+ nb::handle f(static_cast<PyObject *>(userData));
+ nb::object res = f(op, PyPatternRewriter(rewriter), pattern);
+ return logicalResultFromObject(res);
+ };
+ MlirRewritePattern pattern = mlirOpRewritePattenCreate(
+ rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
+ /* nGeneratedNames */ 0,
+ /* generatedNames */ nullptr);
+ mlirRewritePatternSetAdd(set, pattern);
+ }
+
+ PyFrozenRewritePatternSet freeze() { return mlirFreezeRewritePattern(set); }
+
+private:
+ MlirRewritePatternSet set;
+ MlirContext ctx;
+};
+
} // namespace
/// Create the `mlir.rewrite` here.
void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
+ //----------------------------------------------------------------------------
+ // Mapping of the PatternRewriter
+ //----------------------------------------------------------------------------
nb::class_<PyPatternRewriter>(m, "PatternRewriter")
.def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
- "The current insertion point of the PatternRewriter.");
+ "The current insertion point of the PatternRewriter.")
+ .def("replace_op", [](PyPatternRewriter &self, MlirOperation op,
+ MlirOperation newOp) { self.replaceOp(op, newOp); })
+ .def("replace_op", [](PyPatternRewriter &self, MlirOperation op,
+ const std::vector<MlirValue> &values) {
+ self.replaceOp(op, values);
+ });
+
+ //----------------------------------------------------------------------------
+ // Mapping of the RewritePatternSet
+ //----------------------------------------------------------------------------
+ nb::class_<MlirRewritePattern>(m, "RewritePattern");
+ nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
+ .def(
+ "__init__",
+ [](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
+ new (&self) PyRewritePatternSet(context.get()->get());
+ },
+ "context"_a = nb::none())
+ .def(
+ "add",
+ [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
+ unsigned benefit) {
+ std::string opName =
+ nb::cast<std::string>(root.attr("OPERATION_NAME"));
+ self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
+ fn);
+ },
+ "root"_a, "fn"_a, "benefit"_a = 1)
+ .def("freeze", &PyRewritePatternSet::freeze);
+
//----------------------------------------------------------------------------
// Mapping of the PDLResultList and PDLModule
//----------------------------------------------------------------------------
@@ -237,7 +314,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
.def(
"freeze",
[](PyPDLPatternModule &self) {
- return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
+ return PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
mlirRewritePatternSetFromPDLPatternModule(self.get())));
},
nb::keep_alive<0, 1>())
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index c15a73b991f5d..f3430e2e78978 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/PDLPatternMatch.h.inc"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
@@ -270,9 +271,9 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) {
/// RewritePatternSet and FrozenRewritePatternSet API
//===----------------------------------------------------------------------===//
-static inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
+static inline mlir::RewritePatternSet *unwrap(MlirRewritePatternSet module) {
assert(module.ptr && "unexpected null module");
- return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
+ return static_cast<mlir::RewritePatternSet *>(module.ptr);
}
static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
@@ -291,7 +292,7 @@ wrap(mlir::FrozenRewritePatternSet *module) {
}
MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
- auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op)));
+ auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(op)));
op.ptr = nullptr;
return wrap(m);
}
@@ -332,6 +333,86 @@ MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
}
+//===----------------------------------------------------------------------===//
+/// RewritePattern API
+//===----------------------------------------------------------------------===//
+
+inline const mlir::RewritePattern *unwrap(MlirRewritePattern pattern) {
+ assert(pattern.ptr && "unexpected null pattern");
+ return static_cast<const mlir::RewritePattern *>(pattern.ptr);
+}
+
+inline MlirRewritePattern wrap(const mlir::RewritePattern *pattern) {
+ return {pattern};
+}
+
+namespace mlir {
+
+class ExternalRewritePattern : public mlir::RewritePattern {
+public:
+ ExternalRewritePattern(MlirRewritePatternCallbacks callbacks, void *userData,
+ StringRef rootName, PatternBenefit benefit,
+ MLIRContext *context,
+ ArrayRef<StringRef> generatedNames)
+ : RewritePattern(rootName, benefit, context, generatedNames),
+ callbacks(callbacks), userData(userData) {
+ if (callbacks.construct)
+ callbacks.construct(userData);
+ }
+
+ ~ExternalRewritePattern() {
+ if (callbacks.destruct)
+ callbacks.destruct(userData);
+ }
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ return unwrap(callbacks.matchAndRewrite(
+ wrap(static_cast<const mlir::RewritePattern *>(this)), wrap(op),
+ wrap(&rewriter), userData));
+ }
+
+private:
+ MlirRewritePatternCallbacks callbacks;
+ void *userData;
+};
+
+} // namespace mlir
+
+MlirRewritePattern mlirOpRewritePattenCreate(
+ MlirStringRef rootName, MlirPatternBenefit benefit, MlirContext context,
+ MlirRewritePatternCallbacks callbacks, void *userData,
+ size_t nGeneratedNames, MlirStringRef *generatedNames) {
+ std::vector<mlir::StringRef> generatedNamesVec;
+ generatedNamesVec.reserve(nGeneratedNames);
+ for (size_t i = 0; i < nGeneratedNames; ++i) {
+ generatedNamesVec.push_back(unwrap(generatedNames[i]));
+ }
+ return wrap(new mlir::ExternalRewritePattern(
+ callbacks, userData, unwrap(rootName), PatternBenefit(benefit),
+ unwrap(context), generatedNamesVec));
+}
+
+//===----------------------------------------------------------------------===//
+/// RewritePatternSet API
+//===----------------------------------------------------------------------===//
+
+MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context) {
+ return wrap(new mlir::RewritePatternSet(unwrap(context)));
+}
+
+void mlirRewritePatternSetDestroy(MlirRewritePatternSet set) {
+ delete unwrap(set);
+}
+
+void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
+ MlirRewritePattern pattern) {
+ std::unique_ptr<mlir::RewritePattern> patternPtr(
+ const_cast<mlir::RewritePattern *>(unwrap(pattern)));
+ pattern.ptr = nullptr;
+ unwrap(set)->add(std::move(patternPtr));
+}
+
//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
new file mode 100644
index 0000000000000..6aed936f94d87
--- /dev/null
+++ b/mlir/test/python/rewrite.py
@@ -0,0 +1,49 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import gc, sys
+from mlir.ir import *
+from mlir.passmanager import *
+from mlir.dialects.builtin import ModuleOp
+from mlir.dialects import arith
+from mlir.rewrite import *
+
+
+def log(*args):
+ print(*args, file=sys.stderr)
+ sys.stderr.flush()
+
+
+def run(f):
+ log("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+
+# CHECK-LABEL: TEST: testRewritePattern
+@run
+def testRewritePattern():
+ def to_muli(op, rewriter, pattern):
+ with rewriter.ip:
+ new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location)
+ rewriter.replace_op(op, new_op.owner)
+
+ with Context():
+ patterns = RewritePatternSet()
+ patterns.add(arith.AddIOp, to_muli)
+ frozen = patterns.freeze()
+
+ module = ModuleOp.parse(
+ r"""
+ module {
+ func.func @add(%a: i64, %b: i64) -> i64 {
+ %sum = arith.addi %a, %b : i64
+ return %sum : i64
+ }
+ }
+ """
+ )
+
+ apply_patterns_and_fold_greedily(module, frozen)
+ # CHECK: %0 = arith.muli %arg0, %arg1 : i64
+ # CHECK: return %0 : i64
+ print(module)
>From 61b87af618652e26f71400d7b238f1597a2ca364 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Oct 2025 00:56:46 +0800
Subject: [PATCH 02/10] format
---
mlir/test/python/rewrite.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index 6aed936f94d87..c7b6c1f19991e 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -19,6 +19,7 @@ def run(f):
gc.collect()
assert Context._get_live_count() == 0
+
# CHECK-LABEL: TEST: testRewritePattern
@run
def testRewritePattern():
>From 395627f8987187ca8f45dcefe6c9167b69a3f7d8 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Oct 2025 10:35:15 +0800
Subject: [PATCH 03/10] add docs for C API
---
mlir/include/mlir-c/Rewrite.h | 14 ++++++++++++++
1 file changed, 14 insertions(+)
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 68bb112404170..cc021bcfba889 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -329,17 +329,28 @@ mlirPatternRewriterAsBase(MlirPatternRewriter rewriter);
/// RewritePattern API
//===----------------------------------------------------------------------===//
+/// PatternBenefit represents the benefit of a pattern match.
typedef unsigned short MlirPatternBenefit;
+/// Callbacks to construct a rewrite pattern.
typedef struct {
+ /// Optional constructor for the user data.
+ /// Set to nullptr to disable it.
void (*construct)(void *userData);
+ /// Optional destructor for the user data.
+ /// Set to nullptr to disable it.
void (*destruct)(void *userData);
+ /// The callback function to match against code rooted at the specified
+ /// operation, and perform the rewrite if the match is successful,
+ /// corresponding to RewritePattern::matchAndRewrite.
MlirLogicalResult (*matchAndRewrite)(MlirRewritePattern pattern,
MlirOperation op,
MlirPatternRewriter rewriter,
void *userData);
} MlirRewritePatternCallbacks;
+/// Create a rewrite pattern that matches the operation
+/// with the given rootName, corresponding to mlir::OpRewritePattern.
MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate(
MlirStringRef rootName, MlirPatternBenefit benefit, MlirContext context,
MlirRewritePatternCallbacks callbacks, void *userData,
@@ -349,11 +360,14 @@ MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate(
/// RewritePatternSet API
//===----------------------------------------------------------------------===//
+/// Create an empty MlirRewritePatternSet.
MLIR_CAPI_EXPORTED MlirRewritePatternSet
mlirRewritePatternSetCreate(MlirContext context);
+/// Destruct the given MlirRewritePatternSet.
MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set);
+/// Add the given MlirRewritePattern into a MlirRewritePatternSet.
MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
MlirRewritePattern pattern);
>From 0ddd081a3eb27348c7b87058edcf8eb437c796a0 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Oct 2025 10:45:09 +0800
Subject: [PATCH 04/10] add more docs and fix some name
---
mlir/include/mlir-c/Rewrite.h | 10 ++++++++--
mlir/lib/Bindings/Python/Rewrite.cpp | 6 +++++-
mlir/lib/CAPI/Transforms/Rewrite.cpp | 13 +++++++------
3 files changed, 20 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index cc021bcfba889..66a9a5de1669d 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -303,11 +303,15 @@ MLIR_CAPI_EXPORTED void mlirIRRewriterDestroy(MlirRewriterBase rewriter);
/// FrozenRewritePatternSet API
//===----------------------------------------------------------------------===//
+/// Freeze the given MlirRewritePatternSet to a MlirFrozenRewritePatternSet.
+/// Note that the ownership of the input set is transferred into the frozen set
+/// after this call.
MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet
-mlirFreezeRewritePattern(MlirRewritePatternSet op);
+mlirFreezeRewritePattern(MlirRewritePatternSet set);
+/// Destroy the given MlirFrozenRewritePatternSet.
MLIR_CAPI_EXPORTED void
-mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
+mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set);
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(
MlirOperation op, MlirFrozenRewritePatternSet patterns,
@@ -368,6 +372,8 @@ mlirRewritePatternSetCreate(MlirContext context);
MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set);
/// Add the given MlirRewritePattern into a MlirRewritePatternSet.
+/// Note that the ownership of the pattern is transferred to the set after this
+/// call.
MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
MlirRewritePattern pattern);
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 3740c59e62001..9c99c6a4366b5 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -202,7 +202,11 @@ class PyRewritePatternSet {
mlirRewritePatternSetAdd(set, pattern);
}
- PyFrozenRewritePatternSet freeze() { return mlirFreezeRewritePattern(set); }
+ PyFrozenRewritePatternSet freeze() {
+ MlirRewritePatternSet s = set;
+ set.ptr = nullptr;
+ return mlirFreezeRewritePattern(s);
+ }
private:
MlirRewritePatternSet set;
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index f3430e2e78978..7e7a4f7715bb4 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -291,15 +291,16 @@ wrap(mlir::FrozenRewritePatternSet *module) {
return {module};
}
-MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
- auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(op)));
- op.ptr = nullptr;
+MlirFrozenRewritePatternSet
+mlirFreezeRewritePattern(MlirRewritePatternSet set) {
+ auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(set)));
+ set.ptr = nullptr;
return wrap(m);
}
-void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) {
- delete unwrap(op);
- op.ptr = nullptr;
+void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set) {
+ delete unwrap(set);
+ set.ptr = nullptr;
}
MlirLogicalResult
>From da4bb8b560b3bc49d5064281ef407c618d24787c Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Oct 2025 11:28:45 +0800
Subject: [PATCH 05/10] add nb::sigs and python api docs
---
mlir/lib/Bindings/Python/Rewrite.cpp | 55 ++++++++++++++++++++++------
1 file changed, 43 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 9c99c6a4366b5..07559457f2f2f 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -53,6 +53,8 @@ class PyPatternRewriter {
mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
}
+ void eraseOp(MlirOperation op) { mlirRewriterBaseEraseOp(base, op); }
+
private:
MlirRewriterBase base;
PyMlirContextRef ctx;
@@ -177,7 +179,10 @@ class PyRewritePatternSet {
public:
PyRewritePatternSet(MlirContext ctx)
: set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
- ~PyRewritePatternSet() { mlirRewritePatternSetDestroy(set); }
+ ~PyRewritePatternSet() {
+ if (set.ptr)
+ mlirRewritePatternSetDestroy(set);
+ }
void add(MlirStringRef rootName, MlirPatternBenefit benefit,
const nb::callable &matchAndRewrite) {
@@ -220,15 +225,37 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of the PatternRewriter
//----------------------------------------------------------------------------
- nb::class_<PyPatternRewriter>(m, "PatternRewriter")
- .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
- "The current insertion point of the PatternRewriter.")
- .def("replace_op", [](PyPatternRewriter &self, MlirOperation op,
- MlirOperation newOp) { self.replaceOp(op, newOp); })
- .def("replace_op", [](PyPatternRewriter &self, MlirOperation op,
- const std::vector<MlirValue> &values) {
- self.replaceOp(op, values);
- });
+ nb::
+ class_<PyPatternRewriter>(m, "PatternRewriter")
+ .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
+ "The current insertion point of the PatternRewriter.")
+ .def(
+ "replace_op",
+ [](PyPatternRewriter &self, MlirOperation op,
+ MlirOperation newOp) { self.replaceOp(op, newOp); },
+ "Replace an operation with a new operation.",
+ // clang-format off
+ nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation")
+ ", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
+ // clang-format on
+ )
+ .def(
+ "replace_op",
+ [](PyPatternRewriter &self, MlirOperation op,
+ const std::vector<MlirValue> &values) {
+ self.replaceOp(op, values);
+ },
+ "Replace an operation with a list of values.",
+ // clang-format off
+ nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation")
+ ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None")
+ // clang-format on
+ )
+ .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.",
+ // clang-format off
+ nb::sig("def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
+ // clang-format on
+ );
//----------------------------------------------------------------------------
// Mapping of the RewritePatternSet
@@ -250,8 +277,12 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
fn);
},
- "root"_a, "fn"_a, "benefit"_a = 1)
- .def("freeze", &PyRewritePatternSet::freeze);
+ "root"_a, "fn"_a, "benefit"_a = 1,
+ "Add a new rewrite pattern on the given root operation with the "
+ "callable as the matching and rewriting function and the given "
+ "benefit.")
+ .def("freeze", &PyRewritePatternSet::freeze,
+ "Freeze the pattern set into a frozen one.");
//----------------------------------------------------------------------------
// Mapping of the PDLResultList and PDLModule
>From 5333a6ef08a3286b83494089b568f0f4087f77a7 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Oct 2025 11:42:21 +0800
Subject: [PATCH 06/10] add more examples
---
mlir/lib/CAPI/Transforms/Rewrite.cpp | 1 -
mlir/test/python/rewrite.py | 27 +++++++++++++++++++++++++++
2 files changed, 27 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 7e7a4f7715bb4..d7c8e53f2bba6 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -17,7 +17,6 @@
#include "mlir/IR/PDLPatternMatch.h.inc"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
-#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index c7b6c1f19991e..cbc3a4043f96c 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -28,9 +28,18 @@ def to_muli(op, rewriter, pattern):
new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location)
rewriter.replace_op(op, new_op.owner)
+ def constant_1_to_2(op, rewriter, pattern):
+ c = op.attributes["value"].value
+ if c != 1:
+ return True # failed to match
+ with rewriter.ip:
+ new_op = arith.constant(op.result.type, 2, loc=op.location)
+ rewriter.replace_op(op, [new_op])
+
with Context():
patterns = RewritePatternSet()
patterns.add(arith.AddIOp, to_muli)
+ patterns.add(arith.ConstantOp, constant_1_to_2)
frozen = patterns.freeze()
module = ModuleOp.parse(
@@ -48,3 +57,21 @@ def to_muli(op, rewriter, pattern):
# CHECK: %0 = arith.muli %arg0, %arg1 : i64
# CHECK: return %0 : i64
print(module)
+
+ module = ModuleOp.parse(
+ r"""
+ module {
+ func.func @const() -> (i64, i64) {
+ %0 = arith.constant 1 : i64
+ %1 = arith.constant 3 : i64
+ return %0, %1 : i64, i64
+ }
+ }
+ """
+ )
+
+ apply_patterns_and_fold_greedily(module, frozen)
+ # CHECK: %c2_i64 = arith.constant 2 : i64
+ # CHECK: %c3_i64 = arith.constant 3 : i64
+ # CHECK: return %c2_i64, %c3_i64 : i64, i64
+ print(module)
>From a57961fc66c12529e957086869e008e835b70a54 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Oct 2025 11:47:17 +0800
Subject: [PATCH 07/10] fix format
---
mlir/test/python/rewrite.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index cbc3a4043f96c..4537068a5b9d5 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -31,7 +31,7 @@ def to_muli(op, rewriter, pattern):
def constant_1_to_2(op, rewriter, pattern):
c = op.attributes["value"].value
if c != 1:
- return True # failed to match
+ return True # failed to match
with rewriter.ip:
new_op = arith.constant(op.result.type, 2, loc=op.location)
rewriter.replace_op(op, [new_op])
>From 43da9a2cbe6d074fa863e6d500b18fc1d0a62894 Mon Sep 17 00:00:00 2001
From: Twice <twice at apache.org>
Date: Fri, 10 Oct 2025 12:59:58 +0800
Subject: [PATCH 08/10] Update mlir/lib/Bindings/Python/Rewrite.cpp
Co-authored-by: Maksim Levental <maksim.levental at gmail.com>
---
mlir/lib/Bindings/Python/Rewrite.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 07559457f2f2f..c938360756f03 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -250,7 +250,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation")
", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None")
// clang-format on
- )
+ nb::arg("op"), nb::arg("values"))
.def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.",
// clang-format off
nb::sig("def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
>From 75c2dd90ae36c92ef184dda1a27150f5ace66aaf Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Oct 2025 13:06:55 +0800
Subject: [PATCH 09/10] reformat nb::sigs and add nb::args
---
mlir/lib/Bindings/Python/Rewrite.cpp | 15 ++++++++-------
1 file changed, 8 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index c938360756f03..078593955bf9c 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -233,10 +233,10 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
"replace_op",
[](PyPatternRewriter &self, MlirOperation op,
MlirOperation newOp) { self.replaceOp(op, newOp); },
- "Replace an operation with a new operation.",
+ "Replace an operation with a new operation.", nb::arg("op"),
+ nb::arg("new_op"),
// clang-format off
- nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation")
- ", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
+ nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
// clang-format on
)
.def(
@@ -245,13 +245,14 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
const std::vector<MlirValue> &values) {
self.replaceOp(op, values);
},
- "Replace an operation with a list of values.",
+ "Replace an operation with a list of values.", nb::arg("op"),
+ nb::arg("values"),
// clang-format off
- nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation")
- ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None")
+ nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None")
// clang-format on
- nb::arg("op"), nb::arg("values"))
+ )
.def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.",
+ nb::arg("op"),
// clang-format off
nb::sig("def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
// clang-format on
>From 64d98e42960545330ee4842f8d81d12664f12784 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Oct 2025 13:09:16 +0800
Subject: [PATCH 10/10] remove log()
---
mlir/test/python/rewrite.py | 7 +------
1 file changed, 1 insertion(+), 6 deletions(-)
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index 4537068a5b9d5..546a4fb720a98 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -8,13 +8,8 @@
from mlir.rewrite import *
-def log(*args):
- print(*args, file=sys.stderr)
- sys.stderr.flush()
-
-
def run(f):
- log("\nTEST:", f.__name__)
+ print("\nTEST:", f.__name__)
f()
gc.collect()
assert Context._get_live_count() == 0
More information about the Mlir-commits mailing list