[Mlir-commits] [mlir] [MLIR][Python][Transform] Expose PatternDescriptorOpInterface to Python (PR #184331)

Rolf Morel llvmlistbot at llvm.org
Tue Mar 3 04:46:02 PST 2026


https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/184331

>From 77c7cb13a7f29ff3c9d7a1c323e5a452d8f3ef51 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 2 Mar 2026 13:43:56 -0800
Subject: [PATCH 1/2] [MLIR][Python][Transform] Expose
 PatternDescriptorOpInterface to Python

Makes it possible to include Python-defined rewrite patterns in
transform-dialect schedules, inside of `transform.apply_patterns`, which,
upon execution of the schedule, runs the patterns in a greedy rewriter.
---
 mlir/include/mlir-c/Dialect/Transform.h       |  32 ++
 mlir/include/mlir-c/Rewrite.h                 |   4 +
 mlir/lib/Bindings/Python/DialectTransform.cpp | 100 ++++++
 mlir/lib/Bindings/Python/Rewrite.cpp          | 338 +++++++++---------
 mlir/lib/Bindings/Python/Rewrite.h            |  35 ++
 mlir/lib/CAPI/Dialect/Transform.cpp           |  83 +++++
 mlir/lib/CAPI/Transforms/Rewrite.cpp          |   4 +
 ...ansform_pattern_descriptor_op_interface.py | 114 ++++++
 8 files changed, 548 insertions(+), 162 deletions(-)
 create mode 100644 mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py

diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h
index 43796a7f62727..cbda09cdbc37b 100644
--- a/mlir/include/mlir-c/Dialect/Transform.h
+++ b/mlir/include/mlir-c/Dialect/Transform.h
@@ -207,6 +207,38 @@ MLIR_CAPI_EXPORTED void mlirTransformOpInterfaceAttachFallbackModel(
     MlirContext ctx, MlirStringRef opName,
     MlirTransformOpInterfaceCallbacks callbacks);
 
+//===---------------------------------------------------------------------===//
+// PatternDescriptorOpInterface
+//===---------------------------------------------------------------------===//
+
+/// Returns the interface TypeID of the PatternDescriptorOpInterface.
+MLIR_CAPI_EXPORTED MlirTypeID mlirPatternDescriptorOpInterfaceTypeID(void);
+
+/// Callbacks for implementing PatternDescriptorOpInterface from external code.
+typedef struct {
+  /// Optional constructor for the user data.
+  /// Set to nullptr to disable it.
+  void (*construct)(void *userData);
+  /// Optional destructor for the user data, run when the FallbackModel is
+  /// destroyed. Set to nullptr to disable it.
+  void (*destruct)(void *userData);
+  /// Required callback that adds rewrite patterns to the given pattern set.
+  void (*populatePatterns)(MlirOperation op, MlirRewritePatternSet patterns,
+                           void *userData);
+  /// Optional callback to populate rewrite patterns with transform state.
+  /// Set to nullptr to use the default implementation (calls populatePatterns).
+  void (*populatePatternsWithState)(MlirOperation op,
+                                    MlirRewritePatternSet patterns,
+                                    MlirTransformState state, void *userData);
+  void *userData;
+} MlirPatternDescriptorOpInterfaceCallbacks;
+
+/// Attach a PatternDescriptorOpInterface FallbackModel backed by `callbacks`
+/// to the operation named `opName`, which must be registered in `ctx`.
+MLIR_CAPI_EXPORTED void mlirPatternDescriptorOpInterfaceAttachFallbackModel(
+    MlirContext ctx, MlirStringRef opName,
+    MlirPatternDescriptorOpInterfaceCallbacks callbacks);
+
 //===---------------------------------------------------------------------===//
 // Transform-specifc MemoryEffectsOpInterface helpers
 //===---------------------------------------------------------------------===//
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 6947158f624c0..2e12f9cabbddd 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -628,6 +628,10 @@ MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePatternCreate(
 MLIR_CAPI_EXPORTED MlirRewritePatternSet
 mlirRewritePatternSetCreate(MlirContext context);
 
+/// Returns the MlirContext that the given MlirRewritePatternSet belongs to.
+MLIR_CAPI_EXPORTED MlirContext
+mlirRewritePatternSetGetContext(MlirRewritePatternSet set);
+
 /// Destruct the given MlirRewritePatternSet.
 MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set);
 
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 157194f00e3c4..3b1b9abffb0f9 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -11,6 +11,7 @@
 #include "Rewrite.h"
 #include "mlir-c/Dialect/Transform.h"
 #include "mlir-c/IR.h"
+#include "mlir-c/Rewrite.h"
 #include "mlir-c/Support.h"
 #include "mlir/Bindings/Python/IRCore.h"
 #include "mlir/Bindings/Python/IRInterfaces.h"
@@ -246,6 +247,118 @@ class PyTransformOpInterface
   }
 };
 
+//===----------------------------------------------------------------------===//
+// PatternDescriptorOpInterface
+//===----------------------------------------------------------------------===//
+class PyPatternDescriptorOpInterface
+    : public PyConcreteOpInterface<PyPatternDescriptorOpInterface> {
+public:
+  using PyConcreteOpInterface<
+      PyPatternDescriptorOpInterface>::PyConcreteOpInterface;
+
+  constexpr static const char *pyClassName = "PatternDescriptorOpInterface";
+  constexpr static GetTypeIDFunctionTy getInterfaceID =
+      &mlirPatternDescriptorOpInterfaceTypeID;
+
+  /// Attach a new PatternDescriptorOpInterface FallbackModel to the named
+  /// operation. The FallbackModel acts as a trampoline for callbacks on the
+  /// Python class.
+  ///
+  /// Raises ValueError if `opName` is not a registered operation of the
+  /// context: the FallbackModel can only be attached to registered ops.
+  static void attach(nb::object &target, const std::string &opName,
+                     DefaultingPyMlirContext ctx) {
+    MlirStringRef opNameRef =
+        mlirStringRefCreate(opName.c_str(), opName.size());
+    // The C API merely asserts that the operation is registered (undefined
+    // behavior in release builds), so reject unknown names here, before the
+    // reference on `target` below is taken.
+    if (!mlirContextIsRegisteredOperation(ctx->get(), opNameRef))
+      throw nb::value_error(
+          ("operation '" + opName + "' is not registered in the context")
+              .c_str());
+
+    // Prepare the callbacks that will be used by the FallbackModel.
+    MlirPatternDescriptorOpInterfaceCallbacks callbacks{};
+    // Make the pointer to the Python class available to the callbacks.
+    callbacks.userData = target.ptr();
+    nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
+
+    // The above ref bump is all we need as initialization, no need to run the
+    // construct callback.
+    callbacks.construct = nullptr;
+    // Upon the FallbackModel's destruction, drop the ref to the Python class.
+    callbacks.destruct = [](void *userData) {
+      nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+    };
+
+    // The populatePatterns callback which calls into Python.
+    callbacks.populatePatterns =
+        [](MlirOperation op, MlirRewritePatternSet patterns, void *userData) {
+          nb::handle pyClass(static_cast<PyObject *>(userData));
+
+          auto pyPopulatePatterns =
+              nb::cast<nb::callable>(nb::getattr(pyClass, "populate_patterns"));
+
+          // Non-owning wrapper: the pattern set stays owned by the caller.
+          auto pyPatterns = PyRewritePatternSet(patterns);
+
+          // Invoke `pyClass.populate_patterns(opview(op), patterns)` as a
+          // staticmethod.
+          MlirContext ctx = mlirOperationGetContext(op);
+          PyMlirContextRef context = PyMlirContext::forContext(ctx);
+          auto opview = PyOperation::forOperation(context, op)->createOpView();
+          pyPopulatePatterns(opview, pyPatterns);
+        };
+
+    // The populatePatternsWithState callback which calls into Python.
+    // Check if the Python class has populate_patterns_with_state method.
+    if (nb::hasattr(target, "populate_patterns_with_state")) {
+      callbacks.populatePatternsWithState = [](MlirOperation op,
+                                               MlirRewritePatternSet patterns,
+                                               MlirTransformState state,
+                                               void *userData) {
+        nb::handle pyClass(static_cast<PyObject *>(userData));
+
+        auto pyPopulatePatternsWithState = nb::cast<nb::callable>(
+            nb::getattr(pyClass, "populate_patterns_with_state"));
+
+        // Wrap the caller's pattern set (non-owning) and transform state.
+        auto pyPatterns = PyRewritePatternSet(patterns);
+        auto pyState = PyTransformState(state);
+
+        // Invoke `pyClass.populate_patterns_with_state(opview(op), patterns,
+        // state)` as a staticmethod.
+        MlirContext ctx = mlirOperationGetContext(op);
+        PyMlirContextRef context = PyMlirContext::forContext(ctx);
+        auto opview = PyOperation::forOperation(context, op)->createOpView();
+        pyPopulatePatternsWithState(opview, pyPatterns, pyState);
+      };
+    } else {
+      // Use default implementation (will call populatePatterns).
+      callbacks.populatePatternsWithState = nullptr;
+    }
+
+    // Attach a FallbackModel, which calls into Python, to the named operation.
+    mlirPatternDescriptorOpInterfaceAttachFallbackModel(ctx->get(), opNameRef,
+                                                        callbacks);
+  }
+
+  static void bindDerived(ClassTy &cls) {
+    cls.attr("attach") = classmethod(
+        [](const nb::object &cls, const nb::object &opName, nb::object target,
+           DefaultingPyMlirContext context) {
+          if (target.is_none())
+            target = cls;
+          return attach(target, nb::cast<std::string>(opName), context);
+        },
+        nb::arg("cls"), nb::arg("op_name"), nb::kw_only(),
+        nb::arg("target").none() = nb::none(),
+        nb::arg("context").none() = nb::none(),
+        "Attach the interface subclass to the given operation name.");
+  }
+};
+
 //===-------------------------------------------------------------------===//
 // AnyOpType
 //===-------------------------------------------------------------------===//
@@ -444,6 +543,7 @@ static void populateDialectTransformSubmodule(nb::module_ &m) {
   PyTransformResults::bind(m);
   PyTransformState::bind(m);
   PyTransformOpInterface::bind(m);
+  PyPatternDescriptorOpInterface::bind(m);
 
   m.def("only_reads_handle", onlyReadsHandle,
         "Mark operands as only reading handles.", nb::arg("operands"),
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 181df847f36fe..dc5fc0699702b 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -35,6 +35,84 @@ class PyPatternRewriter : public PyRewriterBase<PyPatternRewriter> {
       : PyRewriterBase(mlirPatternRewriterAsBase(rewriter)) {}
 };
 
+//===----------------------------------------------------------------------===//
+// PyRewritePatternSet
+//===----------------------------------------------------------------------===//
+
+// Owning constructor: allocates a fresh pattern set in `ctx`.
+PyRewritePatternSet::PyRewritePatternSet(MlirContext ctx)
+    : patterns(mlirRewritePatternSetCreate(ctx)), owned(true) {}
+
+// Borrowing constructor: wraps `existing` without taking ownership of it.
+PyRewritePatternSet::PyRewritePatternSet(MlirRewritePatternSet existing)
+    : patterns(existing), owned(false) {}
+
+// Only an owning wrapper that still holds a set destroys it.
+PyRewritePatternSet::~PyRewritePatternSet() {
+  if (!owned || !patterns.ptr)
+    return;
+  mlirRewritePatternSetDestroy(patterns);
+}
+
+MlirRewritePatternSet PyRewritePatternSet::get() const {
+  return patterns;
+}
+
+bool PyRewritePatternSet::isOwned() const {
+  return owned;
+}
+
+// Registers `matchAndRewrite` as a rewrite pattern rooted at `root`, which is
+// either an OpView subclass (its OPERATION_NAME is used) or an op name string.
+void PyRewritePatternSet::add(nb::handle root,
+                              const nb::callable &matchAndRewrite,
+                              unsigned benefit) {
+  const std::string rootOpName = [&]() -> std::string {
+    if (root.is_type())
+      return nb::cast<std::string>(root.attr("OPERATION_NAME"));
+    if (nb::isinstance<nb::str>(root))
+      return nb::cast<std::string>(root);
+    throw nb::type_error("the root argument must be a type or a string");
+  }();
+
+  // The Python callable is the pattern's user data; construct/destruct take
+  // and drop a reference on it.
+  MlirRewritePatternCallbacks callbacks;
+  callbacks.construct = [](void *userData) {
+    nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+  };
+  callbacks.destruct = [](void *userData) {
+    nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+  };
+  callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
+                                 MlirPatternRewriter rewriter,
+                                 void *userData) -> MlirLogicalResult {
+    nb::handle pyFn(static_cast<PyObject *>(userData));
+
+    PyMlirContextRef pyCtx =
+        PyMlirContext::forContext(mlirOperationGetContext(op));
+    nb::object view = PyOperation::forOperation(pyCtx, op)->createOpView();
+    nb::object outcome = pyFn(view, PyPatternRewriter(rewriter));
+
+    // `None` or `False` reports success; any other result goes through
+    // nb::cast<bool> and a true value reports failure.
+    const bool succeeded = outcome.is_none() || !nb::cast<bool>(outcome);
+    return succeeded ? mlirLogicalResultSuccess() : mlirLogicalResultFailure();
+  };
+
+  MlirRewritePattern pattern = mlirOpRewritePatternCreate(
+      mlirStringRefCreate(rootOpName.data(), rootOpName.size()), benefit,
+      mlirRewritePatternSetGetContext(patterns), callbacks,
+      matchAndRewrite.ptr(),
+      /* nGeneratedNames */ 0,
+      /* generatedNames */ nullptr);
+  mlirRewritePatternSetAdd(patterns, pattern);
+}
+
+//===----------------------------------------------------------------------===//
+// PyConversionPatternRewriter
+//===----------------------------------------------------------------------===//
+
 class PyConversionPatternRewriter : public PyPatternRewriter {
 public:
   PyConversionPatternRewriter(MlirConversionPatternRewriter rewriter)
@@ -132,6 +201,67 @@ class PyConversionPattern {
   MlirConversionPattern pattern;
 };
 
+/// Resolves the root operation name of a conversion pattern: `root` is either
+/// an OpView subclass (its OPERATION_NAME is used) or an operation name.
+static std::string resolveConversionRootName(nb::handle root) {
+  if (root.is_type())
+    return nb::cast<std::string>(root.attr("OPERATION_NAME"));
+  if (nb::isinstance<nb::str>(root))
+    return nb::cast<std::string>(root);
+  throw nb::type_error("the root argument must be a type or a string");
+}
+
+/// Registers `matchAndRewrite` as a conversion pattern rooted at `root`. The
+/// callable receives the op, its adaptor, the type converter and a rewriter.
+void PyRewritePatternSet::addConversion(nb::handle root, unsigned benefit,
+                                        const nb::callable &matchAndRewrite,
+                                        PyTypeConverter &typeConverter) {
+  const std::string opName = resolveConversionRootName(root);
+
+  // The Python callable is the pattern's user data; construct/destruct take
+  // and drop a reference on it.
+  MlirConversionPatternCallbacks callbacks;
+  callbacks.construct = [](void *userData) {
+    nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+  };
+  callbacks.destruct = [](void *userData) {
+    nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+  };
+  callbacks.matchAndRewrite =
+      [](MlirConversionPattern pattern, MlirOperation op, intptr_t nOperands,
+         MlirValue *operands, MlirConversionPatternRewriter rewriter,
+         void *userData) -> MlirLogicalResult {
+    nb::handle pyFn(static_cast<PyObject *>(userData));
+
+    PyMlirContextRef pyCtx =
+        PyMlirContext::forContext(mlirOperationGetContext(op));
+    nb::object view = PyOperation::forOperation(pyCtx, op)->createOpView();
+    std::vector<MlirValue> adaptorOperands(operands, operands + nOperands);
+
+    // Use the adaptor class registered for this op name if there is one,
+    // otherwise the generic OpAdaptor.
+    MlirStringRef nameRef = mlirIdentifierStr(mlirOperationGetName(op));
+    std::string_view opNameView(nameRef.data, nameRef.length);
+    nb::object adaptorType =
+        PyGlobals::get().lookupOpAdaptorClass(opNameView).value_or(
+            nb::borrow(nb::type<PyOpAdaptor>()));
+
+    nb::object outcome = pyFn(view, adaptorType(adaptorOperands, view),
+                              PyConversionPattern(pattern).getTypeConverter(),
+                              PyConversionPatternRewriter(rewriter));
+    return logicalResultFromObject(outcome);
+  };
+
+  MlirConversionPattern pattern = mlirOpConversionPatternCreate(
+      mlirStringRefCreate(opName.data(), opName.size()), benefit,
+      mlirRewritePatternSetGetContext(patterns), typeConverter.get(),
+      callbacks, matchAndRewrite.ptr(),
+      /* nGeneratedNames */ 0,
+      /* generatedNames */ nullptr);
+  mlirRewritePatternSetAdd(patterns,
+                           mlirConversionPatternAsRewritePattern(pattern));
+}
+
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
 struct PyMlirPDLResultList : MlirPDLResultList {};
 
@@ -249,96 +371,66 @@ class PyFrozenRewritePatternSet {
   MlirFrozenRewritePatternSet set;
 };
 
-class PyRewritePatternSet {
-public:
-  PyRewritePatternSet(MlirContext ctx)
-      : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
-  ~PyRewritePatternSet() {
-    if (set.ptr)
-      mlirRewritePatternSetDestroy(set);
-  }
-
-  void add(MlirStringRef rootName, unsigned benefit,
-           const nb::callable &matchAndRewrite) {
-    MlirRewritePatternCallbacks callbacks;
-    callbacks.construct = [](void *userData) {
-      nb::handle(static_cast<PyObject *>(userData)).inc_ref();
-    };
-    callbacks.destruct = [](void *userData) {
-      nb::handle(static_cast<PyObject *>(userData)).dec_ref();
-    };
-    callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
-                                   MlirPatternRewriter rewriter,
-                                   void *userData) -> MlirLogicalResult {
-      nb::handle f(static_cast<PyObject *>(userData));
-
-      PyMlirContextRef ctx =
-          PyMlirContext::forContext(mlirOperationGetContext(op));
-      nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
-
-      nb::object res = f(opView, PyPatternRewriter(rewriter));
-      return logicalResultFromObject(res);
-    };
-    MlirRewritePattern pattern = mlirOpRewritePatternCreate(
-        rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
-        /* nGeneratedNames */ 0,
-        /* generatedNames */ nullptr);
-    mlirRewritePatternSetAdd(set, pattern);
-  }
-
-  void addConversion(MlirStringRef rootName, unsigned benefit,
-                     const nb::callable &matchAndRewrite,
-                     PyTypeConverter &typeConverter) {
-    MlirConversionPatternCallbacks callbacks;
-    callbacks.construct = [](void *userData) {
-      nb::handle(static_cast<PyObject *>(userData)).inc_ref();
-    };
-    callbacks.destruct = [](void *userData) {
-      nb::handle(static_cast<PyObject *>(userData)).dec_ref();
-    };
-    callbacks.matchAndRewrite =
-        [](MlirConversionPattern pattern, MlirOperation op, intptr_t nOperands,
-           MlirValue *operands, MlirConversionPatternRewriter rewriter,
-           void *userData) -> MlirLogicalResult {
-      nb::handle f(static_cast<PyObject *>(userData));
-
-      PyMlirContextRef ctx =
-          PyMlirContext::forContext(mlirOperationGetContext(op));
-      nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
-
-      std::vector<MlirValue> operandsVec(operands, operands + nOperands);
-      nb::object adaptorCls =
-          PyGlobals::get()
-              .lookupOpAdaptorClass([&] {
-                MlirStringRef ref = mlirIdentifierStr(mlirOperationGetName(op));
-                return std::string_view(ref.data, ref.length);
-              }())
-              .value_or(nb::borrow(nb::type<PyOpAdaptor>()));
-
-      nb::object res = f(opView, adaptorCls(operandsVec, opView),
-                         PyConversionPattern(pattern).getTypeConverter(),
-                         PyConversionPatternRewriter(rewriter));
-      return logicalResultFromObject(res);
-    };
-    MlirConversionPattern pattern = mlirOpConversionPatternCreate(
-        rootName, benefit, ctx, typeConverter.get(), callbacks,
-        matchAndRewrite.ptr(),
-        /* nGeneratedNames */ 0,
-        /* generatedNames */ nullptr);
-    mlirRewritePatternSetAdd(set,
-                             mlirConversionPatternAsRewritePattern(pattern));
-  }
-
-  PyFrozenRewritePatternSet freeze() {
-    MlirRewritePatternSet s = set;
-    set.ptr = nullptr;
-    return mlirFreezeRewritePattern(s);
-  }
+void PyRewritePatternSet::bind(nb::module_ &m) {
+  nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
+      .def(
+          "__init__",
+          [](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
+            new (&self) PyRewritePatternSet(context.get()->get());
+          },
+          "context"_a = nb::none())
+      .def(
+          "add", &PyRewritePatternSet::add, nb::arg("root"), nb::arg("fn"),
+          nb::arg("benefit") = 1,
+          R"(Add a new rewrite pattern on the specified root operation, using
+             the provided callable for matching and rewriting, and assign it the
+             given benefit.
+
+             Args:
+               root: The root operation to which this pattern applies.
+                     This may be either an OpView subclass or an operation name.
+               fn: The callable to use for matching and rewriting, which takes
+                   an operation and a pattern rewriter. The match is considered
+                   successful iff the callable returns a falsy value.
+               benefit: The benefit of the pattern, defaulting to 1.)")
+      .def(
+          "add_conversion",
+          [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
+             PyTypeConverter &typeConverter, unsigned benefit) {
+            self.addConversion(root, benefit, fn, typeConverter);
+          },
+          "root"_a, "fn"_a, "type_converter"_a, "benefit"_a = 1,
+          R"(
+            Add a new conversion pattern on the specified root operation,
+            using the provided callable for matching and rewriting,
+            and assign it the given benefit.
 
-private:
-  MlirRewritePatternSet set;
-  MlirContext ctx;
-};
+            Args:
+              root: The root operation to which this pattern applies.
+                    This may be either an OpView subclass or an operation name.
+              fn: The callable to use for matching and rewriting, which takes an
+                  operation, its adaptor, the type converter and a pattern
+                  rewriter. The match is considered successful iff the callable
+                  returns a falsy value.
+              type_converter: The type converter to convert types in the IR.
+              benefit: The benefit of the pattern, defaulting to 1.)")
+      .def(
+          "freeze",
+          [](PyRewritePatternSet &self) {
+            if (!self.isOwned())
+              throw std::runtime_error(
+                  "cannot freeze a non-owning pattern set");
+            // Freezing hands the set over to the frozen set, so release it
+            // here: otherwise ~PyRewritePatternSet would destroy it as well.
+            // Dropping ownership also makes a second freeze raise instead of
+            // freezing an already-consumed set.
+            MlirRewritePatternSet s = self.get();
+            self.patterns.ptr = nullptr;
+            self.owned = false;
+            return PyFrozenRewritePatternSet(mlirFreezeRewritePattern(s));
+          },
+          "Freeze the pattern set into a frozen one.");
+}
 
 enum class PyGreedyRewriteStrictness : std::underlying_type_t<
     MlirGreedyRewriteStrictness> {
@@ -505,79 +591,7 @@ void populateRewriteSubmodule(nb::module_ &m) {
   //----------------------------------------------------------------------------
   // Mapping of the RewritePatternSet
   //----------------------------------------------------------------------------
-  nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
-      .def(
-          "__init__",
-          [](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
-            new (&self) PyRewritePatternSet(context.get()->get());
-          },
-          "context"_a = nb::none())
-      .def(
-          "add",
-          [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
-             unsigned benefit) {
-            std::string opName;
-            if (root.is_type()) {
-              opName = nb::cast<std::string>(root.attr("OPERATION_NAME"));
-            } else if (nb::isinstance<nb::str>(root)) {
-              opName = nb::cast<std::string>(root);
-            } else {
-              throw nb::type_error(
-                  "the root argument must be a type or a string");
-            }
-            self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
-                     fn);
-          },
-          "root"_a, "fn"_a, "benefit"_a = 1,
-          // clang-format off
-          nb::sig("def add(self, root: type | str, fn: typing.Callable[[" MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", PatternRewriter], typing.Any], benefit: int = 1) -> None"),
-          // clang-format on
-          R"(
-            Add a new rewrite pattern on the specified root operation, using the provided callable
-            for matching and rewriting, and assign it the given benefit.
-
-            Args:
-              root: The root operation to which this pattern applies.
-                    This may be either an OpView subclass (e.g., ``arith.AddIOp``) or
-                    an operation name string (e.g., ``"arith.addi"``).
-              fn: The callable to use for matching and rewriting,
-                  which takes an operation and a pattern rewriter as arguments.
-                  The match is considered successful iff the callable returns
-                  a value where ``bool(value)`` is ``False`` (e.g. ``None``).
-                  If possible, the operation is cast to its corresponding OpView subclass
-                  before being passed to the callable.
-              benefit: The benefit of the pattern, defaulting to 1.)")
-      .def(
-          "add_conversion",
-          [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
-             PyTypeConverter &typeConverter, unsigned benefit) {
-            std::string opName =
-                nb::cast<std::string>(root.attr("OPERATION_NAME"));
-            self.addConversion(
-                mlirStringRefCreate(opName.data(), opName.size()), benefit, fn,
-                typeConverter);
-          },
-          "root"_a, "fn"_a, "type_converter"_a, "benefit"_a = 1,
-          R"(
-            Add a new conversion pattern on the specified root operation,
-            using the provided callable for matching and rewriting,
-            and assign it the given benefit.
-
-            Args:
-              root: The root operation to which this pattern applies.
-                    This may be either an OpView subclass (e.g., ``arith.AddIOp``) or
-                    an operation name string (e.g., ``"arith.addi"``).
-              fn: The callable to use for matching and rewriting,
-                  which takes an operation, its adaptor,
-                  the type converter and a pattern rewriter as arguments.
-                  The match is considered successful iff the callable returns
-                  a value where ``bool(value)`` is ``False`` (e.g. ``None``).
-                  If possible, the operation is cast to its corresponding OpView subclass
-                  before being passed to the callable.
-              type_converter: The type converter to convert types in the IR.
-              benefit: The benefit of the pattern, defaulting to 1.)")
-      .def("freeze", &PyRewritePatternSet::freeze,
-           "Freeze the pattern set into a frozen one.");
+  PyRewritePatternSet::bind(m);
 
   nb::class_<PyConversionPatternRewriter, PyPatternRewriter>(
       m, "ConversionPatternRewriter")
diff --git a/mlir/lib/Bindings/Python/Rewrite.h b/mlir/lib/Bindings/Python/Rewrite.h
index 32d53f505c145..11a910dbf0bad 100644
--- a/mlir/lib/Bindings/Python/Rewrite.h
+++ b/mlir/lib/Bindings/Python/Rewrite.h
@@ -75,6 +75,46 @@ class MLIR_PYTHON_API_EXPORTED PyRewriterBase {
   PyMlirContextRef ctx;
 };
 
+class PyTypeConverter;
+
+/// Wrapper around MlirRewritePatternSet.
+/// The constructor taking an MlirContext creates an owned pattern set that is
+/// destroyed in the destructor. The constructor taking an
+/// MlirRewritePatternSet creates a non-owning reference.
+/// NB: the implicit copy constructor copies the ownership flag too, so an
+/// owning instance must not be copied (both copies would destroy the set).
+class MLIR_PYTHON_API_EXPORTED PyRewritePatternSet {
+public:
+  /// Create an owned pattern set.
+  PyRewritePatternSet(MlirContext ctx);
+
+  /// Create a non-owning reference to an existing pattern set.
+  PyRewritePatternSet(MlirRewritePatternSet patterns);
+
+  ~PyRewritePatternSet();
+
+  /// Returns the underlying C handle; ownership is not transferred.
+  MlirRewritePatternSet get() const;
+
+  /// Returns true if this wrapper owns (and hence destroys) the pattern set.
+  bool isOwned() const;
+
+  /// Add a new rewrite pattern to the pattern set.
+  void add(nanobind::handle root, const nanobind::callable &matchAndRewrite,
+           unsigned benefit);
+
+  /// Add a new conversion pattern to the pattern set.
+  void addConversion(nanobind::handle root, unsigned benefit,
+                     const nanobind::callable &matchAndRewrite,
+                     PyTypeConverter &typeConverter);
+
+  static void bind(nanobind::module_ &m);
+
+private:
+  MlirRewritePatternSet patterns;
+  bool owned;
+};
+
 void MLIR_PYTHON_API_EXPORTED populateRewriteSubmodule(nanobind::module_ &m);
 } // namespace MLIR_BINDINGS_PYTHON_DOMAIN
 } // namespace python
diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp
index 816e5df67e407..eb7b65467f519 100644
--- a/mlir/lib/CAPI/Dialect/Transform.cpp
+++ b/mlir/lib/CAPI/Dialect/Transform.cpp
@@ -298,6 +298,106 @@ void mlirTransformOpInterfaceAttachFallbackModel(
   model->setCallbacks(callbacks);
 }
 
+//===---------------------------------------------------------------------===//
+// PatternDescriptorOpInterface
+//===---------------------------------------------------------------------===//
+
+MlirTypeID mlirPatternDescriptorOpInterfaceTypeID(void) {
+  return wrap(transform::PatternDescriptorOpInterface::getInterfaceID());
+}
+
+/// Fallback model for the PatternDescriptorOpInterface that uses C API
+/// callbacks.
+class PatternDescriptorOpInterfaceFallbackModel
+    : public mlir::transform::PatternDescriptorOpInterface::FallbackModel<
+          PatternDescriptorOpInterfaceFallbackModel> {
+public:
+  /// Sets the callbacks that this FallbackModel will use.
+  /// NB: the callbacks can only be set through this method as the
+  /// RegisteredOperationName::attachInterface mechanism default-constructs
+  /// the FallbackModel without being able to provide arguments.
+  ///
+  /// User data of previously set callbacks is released first so that it is
+  /// not leaked; the optional `construct` callback then runs on the new data.
+  void setCallbacks(MlirPatternDescriptorOpInterfaceCallbacks callbacks) {
+    if (this->callbacks.destruct)
+      this->callbacks.destruct(this->callbacks.userData);
+    this->callbacks = callbacks;
+    if (this->callbacks.construct)
+      this->callbacks.construct(this->callbacks.userData);
+  }
+
+  ~PatternDescriptorOpInterfaceFallbackModel() {
+    if (callbacks.destruct)
+      callbacks.destruct(callbacks.userData);
+  }
+
+  static TypeID getInterfaceID() {
+    return transform::PatternDescriptorOpInterface::getInterfaceID();
+  }
+
+  static bool classof(
+      const mlir::transform::detail::PatternDescriptorOpInterfaceInterfaceTraits::
+          Concept *op) {
+    // Enable casting back to the FallbackModel from the Interface. This is
+    // necessary as attachInterface(...) default-constructs the FallbackModel
+    // without being able to pass in the callbacks and returns just the Concept.
+    return true;
+  }
+
+  void populatePatterns(Operation *op, RewritePatternSet &patterns) const {
+    assert(callbacks.populatePatterns && "populatePatterns callback not set");
+    callbacks.populatePatterns(wrap(op), wrap(&patterns), callbacks.userData);
+  }
+
+  void populatePatternsWithState(Operation *op, RewritePatternSet &patterns,
+                                 transform::TransformState &state) const {
+    if (callbacks.populatePatternsWithState) {
+      callbacks.populatePatternsWithState(wrap(op), wrap(&patterns),
+                                          wrap(&state), callbacks.userData);
+    } else {
+      // Default implementation: call populatePatterns without state.
+      populatePatterns(op, patterns);
+    }
+  }
+
+private:
+  /// Value-initialized (all null) so that the destructor and setCallbacks
+  /// never read indeterminate values when no callbacks have been set yet.
+  MlirPatternDescriptorOpInterfaceCallbacks callbacks = {};
+};
+
+/// Attach a PatternDescriptorOpInterface FallbackModel to the given named
+/// operation. The FallbackModel uses the provided callbacks to implement the
+/// interface. If the operation is not registered (an assertion failure when
+/// assertions are enabled), the `destruct` callback releases the user data.
+void mlirPatternDescriptorOpInterfaceAttachFallbackModel(
+    MlirContext ctx, MlirStringRef opName,
+    MlirPatternDescriptorOpInterfaceCallbacks callbacks) {
+  // Look up the operation definition in the context.
+  std::optional<RegisteredOperationName> opInfo =
+      RegisteredOperationName::lookup(unwrap(opName), unwrap(ctx));
+
+  assert(opInfo.has_value() && "operation not found in context");
+  if (!opInfo) {
+    // Without asserts, don't dereference the empty optional below. The model
+    // would have taken over the user data, so release it instead of leaking.
+    if (callbacks.destruct)
+      callbacks.destruct(callbacks.userData);
+    return;
+  }
+
+  // NB: the following default-constructs the FallbackModel _without_ being able
+  // to provide arguments.
+  opInfo->attachInterface<PatternDescriptorOpInterfaceFallbackModel>();
+  // Cast to get the underlying FallbackModel and set the callbacks.
+  auto *model = cast<PatternDescriptorOpInterfaceFallbackModel>(
+      opInfo->getInterface<PatternDescriptorOpInterfaceFallbackModel>());
+
+  assert(model && "Failed to get PatternDescriptorOpInterfaceFallbackModel");
+  model->setCallbacks(callbacks);
+}
+
 //===---------------------------------------------------------------------===//
 // MemoryEffectsOpInterface helpers
 //===---------------------------------------------------------------------===//
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 5a6ade352d760..4f75cc758bc48 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -728,6 +728,10 @@ MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context) {
   return wrap(new mlir::RewritePatternSet(unwrap(context)));
 }
 
+/// Returns the MLIR context associated with the given rewrite pattern set.
+MlirContext mlirRewritePatternSetGetContext(MlirRewritePatternSet set) {
+  mlir::RewritePatternSet *patternSet = unwrap(set);
+  return wrap(patternSet->getContext());
+}
+
 void mlirRewritePatternSetDestroy(MlirRewritePatternSet set) {
   delete unwrap(set);
 }
diff --git a/mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py b/mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py
new file mode 100644
index 0000000000000..470c679179b03
--- /dev/null
+++ b/mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py
@@ -0,0 +1,114 @@
+# RUN: env PYTHONUNBUFFERED=1 %PYTHON %s 2>&1 | FileCheck %s
+
+from contextlib import contextmanager
+
+from mlir import ir, rewrite
+from mlir.dialects import transform, func, arith, ext
+from mlir.dialects.transform import AnyOpType, structured
+
+
+ at ext.register_dialect
+class MyPatternDescriptors(ext.Dialect, name="my_pattern_descriptors"):
+    pass
+
+
+def run(emit_schedule):
+    print(f"Test: {emit_schedule.__name__}")
+    with ir.Context(), ir.Location.unknown():
+        payload = emit_payload()
+
+        MyPatternDescriptors.load(register=False, reload=True)
+
+        # NB: Pattern descriptor ops have their interfaces attached
+        #     in their respective test functions.
+        schedule = emit_schedule()
+
+        (_named_seq := schedule.body.operations[0]).apply(payload)
+
+        print(payload)
+
+
+# Payload used by all tests.
+def emit_payload():
+    payload_module = ir.Module.create()
+    with ir.InsertionPoint(payload_module.body):
+        i32 = ir.IntegerType.get_signless(32)
+
+        @func.FuncOp.from_py_func(i32, i32)
+        def test_func(a, b):
+            c = arith.addi(a, b)
+            d = arith.subi(c, b)
+            return d
+
+    return payload_module
+
+
+ at contextmanager
+def schedule_boilerplate():
+    schedule = ir.Module.create()
+    schedule.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
+    with ir.InsertionPoint(schedule.body):
+        named_sequence = transform.NamedSequenceOp(
+            "__transform_main",
+            [AnyOpType.get()],
+            [AnyOpType.get()],
+            arg_attrs=[{"transform.consumed": ir.UnitAttr.get()}],
+        )
+        with ir.InsertionPoint(named_sequence.body):
+            yield schedule, named_sequence
+
+
+ at ext.register_operation(MyPatternDescriptors)
+class SubiAddiRewritePatternOp(MyPatternDescriptors.Operation, name="add_pattern"):
+    @classmethod
+    def attach_interface_impls(cls, ctx=None):
+        cls.PatternDescriptorOpInterfaceFallbackModel.attach(
+            cls.OPERATION_NAME, context=ctx
+        )
+
+    class PatternDescriptorOpInterfaceFallbackModel(
+        transform.PatternDescriptorOpInterface
+    ):
+        @staticmethod
+        def populate_patterns(
+            op: "SubiAddiRewritePatternOp",
+            patterns: rewrite.RewritePatternSet,
+        ) -> None:
+            # Define a pattern that rewrites subi(addi(a, b), b) -> a
+            def match_and_rewrite(subi, rewriter):
+                if not isinstance(addi := subi.lhs.owner, arith.AddiOp):
+                    return True  # Failed match, return truthy value
+                if subi.rhs != addi.rhs:
+                    return True
+                # Replace subi's result with addi's lhs
+                rewriter.replace_op(subi, [addi.lhs])
+                return None  # Success
+
+            # Add the pattern to the pattern set.
+            patterns.add("arith.subi", match_and_rewrite, benefit=1)
+
+
+# CHECK-LABEL: Test: test_pattern_descriptor_add_pattern
+ at run
+def test_pattern_descriptor_add_pattern():
+    """Tests python-defined rewrite pattern via PatternDescriptorOpInterface on AddPatternOp"""
+
+    SubiAddiRewritePatternOp.attach_interface_impls()
+
+    with schedule_boilerplate() as (schedule, named_seq):
+        func_handle = structured.MatchOp.match_op_names(
+            named_seq.bodyTarget, ["func.func"]
+        ).result
+
+        # After pattern application, check that subi is removed and func returns
+        # the first argument directly:
+        # CHECK: func.func @test_func(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
+        # CHECK: return %[[ARG0]] : i32
+        apply_patterns_op = transform.ApplyPatternsOp(func_handle)
+        with ir.InsertionPoint(apply_patterns_op.patterns):
+            SubiAddiRewritePatternOp()
+
+        transform.yield_([func_handle])
+        named_seq.verify()
+
+    return schedule

>From 79534374769b2cac9e26dd5c58850f81a9b7e5e8 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Tue, 3 Mar 2026 04:45:38 -0800
Subject: [PATCH 2/2] Formatting

---
 mlir/lib/Bindings/Python/Rewrite.cpp | 27 +++++++++++++--------------
 mlir/lib/CAPI/Dialect/Transform.cpp  |  6 +++---
 2 files changed, 16 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index dc5fc0699702b..a1940b619e24d 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -379,20 +379,19 @@ void PyRewritePatternSet::bind(nb::module_ &m) {
             new (&self) PyRewritePatternSet(context.get()->get());
           },
           "context"_a = nb::none())
-      .def(
-          "add", &PyRewritePatternSet::add, nb::arg("root"), nb::arg("fn"),
-          nb::arg("benefit") = 1,
-          R"(Add a new rewrite pattern on the specified root operation, using
-             the provided callable for matching and rewriting, and assign it the
-             given benefit.
-
-             Args:
-               root: The root operation to which this pattern applies.
-                     This may be either an OpView subclass or an operation name.
-               fn: The callable to use for matching and rewriting, which takes
-                   an operation and a pattern rewriter. The match is considered
-                   successful iff the callable returns a falsy value.
-               benefit: The benefit of the pattern, defaulting to 1.)")
+      .def("add", &PyRewritePatternSet::add, nb::arg("root"), nb::arg("fn"),
+           nb::arg("benefit") = 1,
+           R"(Add a new rewrite pattern on the specified root operation, using
+              the provided callable for matching and rewriting, and assign it
+              the given benefit.
+
+              Args:
+                root: The root operation to which this pattern applies. This may
+                      be either an OpView subclass or an operation name.
+                fn: The callable to use for matching and rewriting, which takes
+                    an operation and a pattern rewriter. The match is considered
+                    successful iff the callable returns a falsy value.
+                benefit: The benefit of the pattern, defaulting to 1.)")
       .def(
           "add_conversion",
           [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp
index eb7b65467f519..1ed14255bf5e0 100644
--- a/mlir/lib/CAPI/Dialect/Transform.cpp
+++ b/mlir/lib/CAPI/Dialect/Transform.cpp
@@ -329,9 +329,9 @@ class PatternDescriptorOpInterfaceFallbackModel
     return transform::PatternDescriptorOpInterface::getInterfaceID();
   }
 
-  static bool classof(
-      const mlir::transform::detail::PatternDescriptorOpInterfaceInterfaceTraits::
-          Concept *op) {
+  static bool
+  classof(const mlir::transform::detail::
+              PatternDescriptorOpInterfaceInterfaceTraits::Concept *op) {
     // Enable casting back to the FallbackModel from the Interface. This is
     // necessary as attachInterface(...) default-constructs the FallbackModel
     // without being able to pass in the callbacks and returns just the Concept.



More information about the Mlir-commits mailing list