[Mlir-commits] [mlir] [MLIR][Python][Transform] Expose PatternDescriptorOpInterface to Python (PR #184331)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 3 07:47:42 PST 2026


================
@@ -35,6 +35,75 @@ class PyPatternRewriter : public PyRewriterBase<PyPatternRewriter> {
       : PyRewriterBase(mlirPatternRewriterAsBase(rewriter)) {}
 };
 
+//===----------------------------------------------------------------------===//
+// PyRewritePatternSet
+//===----------------------------------------------------------------------===//
+
+/// Creates a fresh, empty rewrite pattern set in `ctx` via
+/// `mlirRewritePatternSetCreate`. The wrapper owns the new set and destroys
+/// it in its destructor.
+PyRewritePatternSet::PyRewritePatternSet(MlirContext ctx)
+    : patterns(mlirRewritePatternSetCreate(ctx)), owned(true) {}
+
+/// Wraps an existing pattern set handle without taking ownership
+/// (`owned == false`), so this wrapper will never destroy it.
+/// NOTE(review): presumably the producer of `patterns` manages its lifetime;
+/// confirm at call sites.
+PyRewritePatternSet::PyRewritePatternSet(MlirRewritePatternSet patterns)
+    : patterns(patterns), owned(false) {}
+
+/// Destroys the underlying pattern set only if this wrapper owns it and the
+/// handle is non-null. Non-owning wrappers leave the handle untouched.
+PyRewritePatternSet::~PyRewritePatternSet() {
+  // The null check skips destruction when `patterns.ptr` has been cleared.
+  if (owned && patterns.ptr)
+    mlirRewritePatternSetDestroy(patterns);
+}
+
+/// Returns the underlying C API handle (ownership is not transferred).
+MlirRewritePatternSet PyRewritePatternSet::get() const { return patterns; }
+
+/// Returns true if this wrapper created (and will destroy) the pattern set.
+bool PyRewritePatternSet::isOwned() const { return owned; }
+
+void PyRewritePatternSet::add(nb::handle root,
+                              const nb::callable &matchAndRewrite,
+                              unsigned benefit) {
+  std::string opName;
+  if (root.is_type()) {
+    opName = nb::cast<std::string>(root.attr("OPERATION_NAME"));
+  } else if (nb::isinstance<nb::str>(root)) {
+    opName = nb::cast<std::string>(root);
+  } else {
+    throw nb::type_error("the root argument must be a type or a string");
+  }
+
+  MlirRewritePatternCallbacks callbacks;
+  callbacks.construct = [](void *userData) {
+    nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+  };
+  callbacks.destruct = [](void *userData) {
+    nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+  };
+  callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
+                                 MlirPatternRewriter rewriter,
+                                 void *userData) -> MlirLogicalResult {
+    nb::handle f(static_cast<PyObject *>(userData));
+
+    PyMlirContextRef context =
+        PyMlirContext::forContext(mlirOperationGetContext(op));
+    nb::object opView = PyOperation::forOperation(context, op)->createOpView();
+
+    nb::object res = f(opView, PyPatternRewriter(rewriter));
+
+    // The match is considered successful iff the callable returns
+    // a value where `bool(value)` is `False` (e.g. `None`).
+    if (res.is_none() || !nb::cast<bool>(res))
+      return mlirLogicalResultSuccess();
+    return mlirLogicalResultFailure();
----------------
PragmaTwice wrote:

I think we previously called `logicalResultFromObject` here instead of this? It's better to reuse `logicalResultFromObject` where possible, for maintainability.

https://github.com/llvm/llvm-project/pull/184331


More information about the Mlir-commits mailing list