[Mlir-commits] [mlir] [MLIR][Python][Transform] Expose PatternDescriptorOpInterface to Python (PR #184331)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 3 07:52:42 PST 2026
================
@@ -249,96 +371,59 @@ class PyFrozenRewritePatternSet {
MlirFrozenRewritePatternSet set;
};
-class PyRewritePatternSet {
-public:
- PyRewritePatternSet(MlirContext ctx)
- : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
- ~PyRewritePatternSet() {
- if (set.ptr)
- mlirRewritePatternSetDestroy(set);
- }
-
- void add(MlirStringRef rootName, unsigned benefit,
- const nb::callable &matchAndRewrite) {
- MlirRewritePatternCallbacks callbacks;
- callbacks.construct = [](void *userData) {
- nb::handle(static_cast<PyObject *>(userData)).inc_ref();
- };
- callbacks.destruct = [](void *userData) {
- nb::handle(static_cast<PyObject *>(userData)).dec_ref();
- };
- callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
- MlirPatternRewriter rewriter,
- void *userData) -> MlirLogicalResult {
- nb::handle f(static_cast<PyObject *>(userData));
-
- PyMlirContextRef ctx =
- PyMlirContext::forContext(mlirOperationGetContext(op));
- nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
-
- nb::object res = f(opView, PyPatternRewriter(rewriter));
- return logicalResultFromObject(res);
- };
- MlirRewritePattern pattern = mlirOpRewritePatternCreate(
- rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
- /* nGeneratedNames */ 0,
- /* generatedNames */ nullptr);
- mlirRewritePatternSetAdd(set, pattern);
- }
-
- void addConversion(MlirStringRef rootName, unsigned benefit,
- const nb::callable &matchAndRewrite,
- PyTypeConverter &typeConverter) {
- MlirConversionPatternCallbacks callbacks;
- callbacks.construct = [](void *userData) {
- nb::handle(static_cast<PyObject *>(userData)).inc_ref();
- };
- callbacks.destruct = [](void *userData) {
- nb::handle(static_cast<PyObject *>(userData)).dec_ref();
- };
- callbacks.matchAndRewrite =
- [](MlirConversionPattern pattern, MlirOperation op, intptr_t nOperands,
- MlirValue *operands, MlirConversionPatternRewriter rewriter,
- void *userData) -> MlirLogicalResult {
- nb::handle f(static_cast<PyObject *>(userData));
-
- PyMlirContextRef ctx =
- PyMlirContext::forContext(mlirOperationGetContext(op));
- nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
-
- std::vector<MlirValue> operandsVec(operands, operands + nOperands);
- nb::object adaptorCls =
- PyGlobals::get()
- .lookupOpAdaptorClass([&] {
- MlirStringRef ref = mlirIdentifierStr(mlirOperationGetName(op));
- return std::string_view(ref.data, ref.length);
- }())
- .value_or(nb::borrow(nb::type<PyOpAdaptor>()));
-
- nb::object res = f(opView, adaptorCls(operandsVec, opView),
- PyConversionPattern(pattern).getTypeConverter(),
- PyConversionPatternRewriter(rewriter));
- return logicalResultFromObject(res);
- };
- MlirConversionPattern pattern = mlirOpConversionPatternCreate(
- rootName, benefit, ctx, typeConverter.get(), callbacks,
- matchAndRewrite.ptr(),
- /* nGeneratedNames */ 0,
- /* generatedNames */ nullptr);
- mlirRewritePatternSetAdd(set,
- mlirConversionPatternAsRewritePattern(pattern));
- }
-
- PyFrozenRewritePatternSet freeze() {
- MlirRewritePatternSet s = set;
- set.ptr = nullptr;
- return mlirFreezeRewritePattern(s);
- }
+void PyRewritePatternSet::bind(nb::module_ &m) {
+ nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
+ .def(
+ "__init__",
+ [](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
+ new (&self) PyRewritePatternSet(context.get()->get());
+ },
+ "context"_a = nb::none())
+ .def("add", &PyRewritePatternSet::add, nb::arg("root"), nb::arg("fn"),
+ nb::arg("benefit") = 1,
+ R"(Add a new rewrite pattern on the specified root operation, using
+ the provided callable for matching and rewriting, and assign it
+ the given benefit.
+
+ Args:
+ root: The root operation to which this pattern applies. This may
+ be either an OpView subclass or an operation name.
+ fn: The callable to use for matching and rewriting, which takes
+ an operation and a pattern rewriter. The match is considered
+ successful iff the callable returns a falsy value.
+ benefit: The benefit of the pattern, defaulting to 1.)")
+ .def(
+ "add_conversion",
+ [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
+ PyTypeConverter &typeConverter, unsigned benefit) {
+ self.addConversion(root, benefit, fn, typeConverter);
+ },
----------------
PragmaTwice wrote:
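Ah, is the parameter order different here? The Python signature is (root, fn, type_converter, benefit), but the lambda forwards to `addConversion(root, benefit, fn, typeConverter)`.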
https://github.com/llvm/llvm-project/pull/184331
More information about the Mlir-commits mailing list