[Mlir-commits] [mlir] [MLIR][Python][Transform] Expose PatternDescriptorOpInterface to Python (PR #184331)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 3 07:47:42 PST 2026
================
@@ -35,6 +35,75 @@ class PyPatternRewriter : public PyRewriterBase<PyPatternRewriter> {
: PyRewriterBase(mlirPatternRewriterAsBase(rewriter)) {}
};
+//===----------------------------------------------------------------------===//
+// PyRewritePatternSet
+//===----------------------------------------------------------------------===//
+
+PyRewritePatternSet::PyRewritePatternSet(MlirContext ctx)
+ : patterns(mlirRewritePatternSetCreate(ctx)), owned(true) {}
+
+PyRewritePatternSet::PyRewritePatternSet(MlirRewritePatternSet patterns)
+ : patterns(patterns), owned(false) {}
+
+/// Destroys the underlying pattern set, but only when this wrapper owns it
+/// and the handle is non-null. Borrowed or null handles are left untouched.
+PyRewritePatternSet::~PyRewritePatternSet() {
+  if (!owned || !patterns.ptr)
+    return;
+  mlirRewritePatternSetDestroy(patterns);
+}
+
+/// Returns the wrapped C handle. Calling this does not change ownership.
+MlirRewritePatternSet PyRewritePatternSet::get() const {
+  return this->patterns;
+}
+
+/// Reports whether this wrapper is responsible for destroying the pattern set.
+bool PyRewritePatternSet::isOwned() const {
+  return this->owned;
+}
+
+void PyRewritePatternSet::add(nb::handle root,
+ const nb::callable &matchAndRewrite,
+ unsigned benefit) {
+ std::string opName;
+ if (root.is_type()) {
+ opName = nb::cast<std::string>(root.attr("OPERATION_NAME"));
+ } else if (nb::isinstance<nb::str>(root)) {
+ opName = nb::cast<std::string>(root);
+ } else {
+ throw nb::type_error("the root argument must be a type or a string");
+ }
+
+ MlirRewritePatternCallbacks callbacks;
+ callbacks.construct = [](void *userData) {
+ nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+ };
+ callbacks.destruct = [](void *userData) {
+ nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+ };
+ callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
+ MlirPatternRewriter rewriter,
+ void *userData) -> MlirLogicalResult {
+ nb::handle f(static_cast<PyObject *>(userData));
+
+ PyMlirContextRef context =
+ PyMlirContext::forContext(mlirOperationGetContext(op));
+ nb::object opView = PyOperation::forOperation(context, op)->createOpView();
+
+ nb::object res = f(opView, PyPatternRewriter(rewriter));
+
+ // The match is considered successful iff the callable returns
+ // a value where `bool(value)` is `False` (e.g. `None`).
+ if (res.is_none() || !nb::cast<bool>(res))
+ return mlirLogicalResultSuccess();
+ return mlirLogicalResultFailure();
----------------
PragmaTwice wrote:
I think we previously called `logicalResultFromObject` here instead of this? It's better to reuse `logicalResultFromObject` where possible, for maintainability.
https://github.com/llvm/llvm-project/pull/184331
More information about the Mlir-commits
mailing list