[Mlir-commits] [mlir] [MLIR][Python][Transform] Expose PatternDescriptorOpInterface to Python (PR #184331)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 3 07:48:04 PST 2026


================
@@ -249,96 +371,59 @@ class PyFrozenRewritePatternSet {
   MlirFrozenRewritePatternSet set;
 };
 
-class PyRewritePatternSet {
-public:
-  PyRewritePatternSet(MlirContext ctx)
-      : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
-  ~PyRewritePatternSet() {
-    if (set.ptr)
-      mlirRewritePatternSetDestroy(set);
-  }
-
-  void add(MlirStringRef rootName, unsigned benefit,
-           const nb::callable &matchAndRewrite) {
-    MlirRewritePatternCallbacks callbacks;
-    callbacks.construct = [](void *userData) {
-      nb::handle(static_cast<PyObject *>(userData)).inc_ref();
-    };
-    callbacks.destruct = [](void *userData) {
-      nb::handle(static_cast<PyObject *>(userData)).dec_ref();
-    };
-    callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
-                                   MlirPatternRewriter rewriter,
-                                   void *userData) -> MlirLogicalResult {
-      nb::handle f(static_cast<PyObject *>(userData));
-
-      PyMlirContextRef ctx =
-          PyMlirContext::forContext(mlirOperationGetContext(op));
-      nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
-
-      nb::object res = f(opView, PyPatternRewriter(rewriter));
-      return logicalResultFromObject(res);
-    };
-    MlirRewritePattern pattern = mlirOpRewritePatternCreate(
-        rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
-        /* nGeneratedNames */ 0,
-        /* generatedNames */ nullptr);
-    mlirRewritePatternSetAdd(set, pattern);
-  }
-
-  void addConversion(MlirStringRef rootName, unsigned benefit,
-                     const nb::callable &matchAndRewrite,
-                     PyTypeConverter &typeConverter) {
-    MlirConversionPatternCallbacks callbacks;
-    callbacks.construct = [](void *userData) {
-      nb::handle(static_cast<PyObject *>(userData)).inc_ref();
-    };
-    callbacks.destruct = [](void *userData) {
-      nb::handle(static_cast<PyObject *>(userData)).dec_ref();
-    };
-    callbacks.matchAndRewrite =
-        [](MlirConversionPattern pattern, MlirOperation op, intptr_t nOperands,
-           MlirValue *operands, MlirConversionPatternRewriter rewriter,
-           void *userData) -> MlirLogicalResult {
-      nb::handle f(static_cast<PyObject *>(userData));
-
-      PyMlirContextRef ctx =
-          PyMlirContext::forContext(mlirOperationGetContext(op));
-      nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
-
-      std::vector<MlirValue> operandsVec(operands, operands + nOperands);
-      nb::object adaptorCls =
-          PyGlobals::get()
-              .lookupOpAdaptorClass([&] {
-                MlirStringRef ref = mlirIdentifierStr(mlirOperationGetName(op));
-                return std::string_view(ref.data, ref.length);
-              }())
-              .value_or(nb::borrow(nb::type<PyOpAdaptor>()));
-
-      nb::object res = f(opView, adaptorCls(operandsVec, opView),
-                         PyConversionPattern(pattern).getTypeConverter(),
-                         PyConversionPatternRewriter(rewriter));
-      return logicalResultFromObject(res);
-    };
-    MlirConversionPattern pattern = mlirOpConversionPatternCreate(
-        rootName, benefit, ctx, typeConverter.get(), callbacks,
-        matchAndRewrite.ptr(),
-        /* nGeneratedNames */ 0,
-        /* generatedNames */ nullptr);
-    mlirRewritePatternSetAdd(set,
-                             mlirConversionPatternAsRewritePattern(pattern));
-  }
-
-  PyFrozenRewritePatternSet freeze() {
-    MlirRewritePatternSet s = set;
-    set.ptr = nullptr;
-    return mlirFreezeRewritePattern(s);
-  }
+void PyRewritePatternSet::bind(nb::module_ &m) {
+  nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
+      .def(
+          "__init__",
+          [](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
+            new (&self) PyRewritePatternSet(context.get()->get());
+          },
+          "context"_a = nb::none())
+      .def("add", &PyRewritePatternSet::add, nb::arg("root"), nb::arg("fn"),
+           nb::arg("benefit") = 1,
+           R"(Add a new rewrite pattern on the specified root operation, using
+              the provided callable for matching and rewriting, and assign it
+              the given benefit.
+
+              Args:
+                root: The root operation to which this pattern applies. This may
+                      be either an OpView subclass or an operation name.
+                fn: The callable to use for matching and rewriting, which takes
+                    an operation and a pattern rewriter. The match is considered
+                    successful iff the callable returns a falsy value.
+                benefit: The benefit of the pattern, defaulting to 1.)")
+      .def(
+          "add_conversion",
+          [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
+             PyTypeConverter &typeConverter, unsigned benefit) {
+            self.addConversion(root, benefit, fn, typeConverter);
+          },
----------------
PragmaTwice wrote:

Maybe just bind `&PyRewritePatternSet::addConversion` directly, as is done for `add`, instead of wrapping it in a lambda?

https://github.com/llvm/llvm-project/pull/184331


More information about the Mlir-commits mailing list