[Mlir-commits] [mlir] [MLIR][Python][Transform] Expose PatternDescriptorOpInterface to Python (PR #184331)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 3 04:42:40 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Rolf Morel (rolfmorel)

<details>
<summary>Changes</summary>

Makes it possible to include Python-defined rewrite patterns in transform-dialect schedules inside `transform.apply_patterns`, which, upon execution of the schedule, runs the patterns in a greedy rewriter.

---

Patch is 34.89 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/184331.diff


8 Files Affected:

- (modified) mlir/include/mlir-c/Dialect/Transform.h (+32) 
- (modified) mlir/include/mlir-c/Rewrite.h (+4) 
- (modified) mlir/lib/Bindings/Python/DialectTransform.cpp (+100) 
- (modified) mlir/lib/Bindings/Python/Rewrite.cpp (+176-162) 
- (modified) mlir/lib/Bindings/Python/Rewrite.h (+35) 
- (modified) mlir/lib/CAPI/Dialect/Transform.cpp (+83) 
- (modified) mlir/lib/CAPI/Transforms/Rewrite.cpp (+4) 
- (added) mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py (+114) 


``````````diff
diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h
index 43796a7f62727..cbda09cdbc37b 100644
--- a/mlir/include/mlir-c/Dialect/Transform.h
+++ b/mlir/include/mlir-c/Dialect/Transform.h
@@ -207,6 +207,38 @@ MLIR_CAPI_EXPORTED void mlirTransformOpInterfaceAttachFallbackModel(
     MlirContext ctx, MlirStringRef opName,
     MlirTransformOpInterfaceCallbacks callbacks);
 
+//===---------------------------------------------------------------------===//
+// PatternDescriptorOpInterface
+//===---------------------------------------------------------------------===//
+
+/// Returns the interface TypeID of the PatternDescriptorOpInterface.
+MLIR_CAPI_EXPORTED MlirTypeID mlirPatternDescriptorOpInterfaceTypeID(void);
+
+/// Callbacks for implementing PatternDescriptorOpInterface from external code.
+typedef struct {
+  /// Optional constructor for the user data.
+  /// Set to nullptr to disable it.
+  void (*construct)(void *userData);
+  /// Optional destructor for the user data.
+  /// Set to nullptr to disable it.
+  void (*destruct)(void *userData);
+  /// Callback to populate rewrite patterns into the given pattern set.
+  void (*populatePatterns)(MlirOperation op, MlirRewritePatternSet patterns,
+                           void *userData);
+  /// Optional callback to populate rewrite patterns with transform state.
+  /// Set to nullptr to use the default implementation (calls populatePatterns).
+  void (*populatePatternsWithState)(MlirOperation op,
+                                    MlirRewritePatternSet patterns,
+                                    MlirTransformState state, void *userData);
+  void *userData;
+} MlirPatternDescriptorOpInterfaceCallbacks;
+
+/// Attach PatternDescriptorOpInterface to the operation with the given name
+/// using the provided callbacks.
+MLIR_CAPI_EXPORTED void mlirPatternDescriptorOpInterfaceAttachFallbackModel(
+    MlirContext ctx, MlirStringRef opName,
+    MlirPatternDescriptorOpInterfaceCallbacks callbacks);
+
 //===---------------------------------------------------------------------===//
 // Transform-specifc MemoryEffectsOpInterface helpers
 //===---------------------------------------------------------------------===//
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 6947158f624c0..2e12f9cabbddd 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -628,6 +628,10 @@ MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePatternCreate(
 MLIR_CAPI_EXPORTED MlirRewritePatternSet
 mlirRewritePatternSetCreate(MlirContext context);
 
+/// Get the context associated with a MlirRewritePatternSet.
+MLIR_CAPI_EXPORTED MlirContext
+mlirRewritePatternSetGetContext(MlirRewritePatternSet set);
+
 /// Destruct the given MlirRewritePatternSet.
 MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set);
 
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 157194f00e3c4..3b1b9abffb0f9 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -11,6 +11,7 @@
 #include "Rewrite.h"
 #include "mlir-c/Dialect/Transform.h"
 #include "mlir-c/IR.h"
+#include "mlir-c/Rewrite.h"
 #include "mlir-c/Support.h"
 #include "mlir/Bindings/Python/IRCore.h"
 #include "mlir/Bindings/Python/IRInterfaces.h"
@@ -246,6 +247,104 @@ class PyTransformOpInterface
   }
 };
 
+//===----------------------------------------------------------------------===//
+// PatternDescriptorOpInterface
+//===----------------------------------------------------------------------===//
+class PyPatternDescriptorOpInterface
+    : public PyConcreteOpInterface<PyPatternDescriptorOpInterface> {
+public:
+  using PyConcreteOpInterface<
+      PyPatternDescriptorOpInterface>::PyConcreteOpInterface;
+
+  constexpr static const char *pyClassName = "PatternDescriptorOpInterface";
+  constexpr static GetTypeIDFunctionTy getInterfaceID =
+      &mlirPatternDescriptorOpInterfaceTypeID;
+
+  /// Attach a new PatternDescriptorOpInterface FallbackModel to the named
+  /// operation. The FallbackModel acts as a trampoline for callbacks on the
+  /// Python class.
+  static void attach(nb::object &target, const std::string &opName,
+                     DefaultingPyMlirContext ctx) {
+    // Prepare the callbacks that will be used by the FallbackModel.
+    MlirPatternDescriptorOpInterfaceCallbacks callbacks;
+    // Make the pointer to the Python class available to the callbacks.
+    callbacks.userData = target.ptr();
+    nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
+
+    // The above ref bump is all we need as initialization, no need to run the
+    // construct callback.
+    callbacks.construct = nullptr;
+    // Upon the FallbackModel's destruction, drop the ref to the Python class.
+    callbacks.destruct = [](void *userData) {
+      nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+    };
+
+    // The populatePatterns callback which calls into Python.
+    callbacks.populatePatterns =
+        [](MlirOperation op, MlirRewritePatternSet patterns, void *userData) {
+          nb::handle pyClass(static_cast<PyObject *>(userData));
+
+          auto pyPopulatePatterns =
+              nb::cast<nb::callable>(nb::getattr(pyClass, "populate_patterns"));
+
+          auto pyPatterns = PyRewritePatternSet(patterns);
+
+          // Invoke `pyClass.populate_patterns(opview(op), patterns)` as a
+          // staticmethod.
+          MlirContext ctx = mlirOperationGetContext(op);
+          PyMlirContextRef context = PyMlirContext::forContext(ctx);
+          auto opview = PyOperation::forOperation(context, op)->createOpView();
+          pyPopulatePatterns(opview, pyPatterns);
+        };
+
+    // The populatePatternsWithState callback which calls into Python.
+    // Check if the Python class has populate_patterns_with_state method.
+    if (nb::hasattr(target, "populate_patterns_with_state")) {
+      callbacks.populatePatternsWithState = [](MlirOperation op,
+                                               MlirRewritePatternSet patterns,
+                                               MlirTransformState state,
+                                               void *userData) {
+        nb::handle pyClass(static_cast<PyObject *>(userData));
+
+        auto pyPopulatePatternsWithState = nb::cast<nb::callable>(
+            nb::getattr(pyClass, "populate_patterns_with_state"));
+
+        auto pyPatterns = PyRewritePatternSet(patterns);
+        auto pyState = PyTransformState(state);
+
+        // Invoke `pyClass.populate_patterns_with_state(opview(op), patterns,
+        // state)` as a staticmethod.
+        MlirContext ctx = mlirOperationGetContext(op);
+        PyMlirContextRef context = PyMlirContext::forContext(ctx);
+        auto opview = PyOperation::forOperation(context, op)->createOpView();
+        pyPopulatePatternsWithState(opview, pyPatterns, pyState);
+      };
+    } else {
+      // Use default implementation (will call populatePatterns).
+      callbacks.populatePatternsWithState = nullptr;
+    }
+
+    // Attach a FallbackModel, which calls into Python, to the named operation.
+    mlirPatternDescriptorOpInterfaceAttachFallbackModel(
+        ctx->get(), mlirStringRefCreate(opName.c_str(), opName.size()),
+        callbacks);
+  }
+
+  static void bindDerived(ClassTy &cls) {
+    cls.attr("attach") = classmethod(
+        [](const nb::object &cls, const nb::object &opName, nb::object target,
+           DefaultingPyMlirContext context) {
+          if (target.is_none())
+            target = cls;
+          return attach(target, nb::cast<std::string>(opName), context);
+        },
+        nb::arg("cls"), nb::arg("op_name"), nb::kw_only(),
+        nb::arg("target").none() = nb::none(),
+        nb::arg("context").none() = nb::none(),
+        "Attach the interface subclass to the given operation name.");
+  }
+};
+
 //===-------------------------------------------------------------------===//
 // AnyOpType
 //===-------------------------------------------------------------------===//
@@ -444,6 +543,7 @@ static void populateDialectTransformSubmodule(nb::module_ &m) {
   PyTransformResults::bind(m);
   PyTransformState::bind(m);
   PyTransformOpInterface::bind(m);
+  PyPatternDescriptorOpInterface::bind(m);
 
   m.def("only_reads_handle", onlyReadsHandle,
         "Mark operands as only reading handles.", nb::arg("operands"),
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 181df847f36fe..dc5fc0699702b 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -35,6 +35,75 @@ class PyPatternRewriter : public PyRewriterBase<PyPatternRewriter> {
       : PyRewriterBase(mlirPatternRewriterAsBase(rewriter)) {}
 };
 
+//===----------------------------------------------------------------------===//
+// PyRewritePatternSet
+//===----------------------------------------------------------------------===//
+
+PyRewritePatternSet::PyRewritePatternSet(MlirContext ctx)
+    : patterns(mlirRewritePatternSetCreate(ctx)), owned(true) {}
+
+PyRewritePatternSet::PyRewritePatternSet(MlirRewritePatternSet patterns)
+    : patterns(patterns), owned(false) {}
+
+PyRewritePatternSet::~PyRewritePatternSet() {
+  if (owned && patterns.ptr)
+    mlirRewritePatternSetDestroy(patterns);
+}
+
+MlirRewritePatternSet PyRewritePatternSet::get() const { return patterns; }
+
+bool PyRewritePatternSet::isOwned() const { return owned; }
+
+void PyRewritePatternSet::add(nb::handle root,
+                              const nb::callable &matchAndRewrite,
+                              unsigned benefit) {
+  std::string opName;
+  if (root.is_type()) {
+    opName = nb::cast<std::string>(root.attr("OPERATION_NAME"));
+  } else if (nb::isinstance<nb::str>(root)) {
+    opName = nb::cast<std::string>(root);
+  } else {
+    throw nb::type_error("the root argument must be a type or a string");
+  }
+
+  MlirRewritePatternCallbacks callbacks;
+  callbacks.construct = [](void *userData) {
+    nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+  };
+  callbacks.destruct = [](void *userData) {
+    nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+  };
+  callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
+                                 MlirPatternRewriter rewriter,
+                                 void *userData) -> MlirLogicalResult {
+    nb::handle f(static_cast<PyObject *>(userData));
+
+    PyMlirContextRef context =
+        PyMlirContext::forContext(mlirOperationGetContext(op));
+    nb::object opView = PyOperation::forOperation(context, op)->createOpView();
+
+    nb::object res = f(opView, PyPatternRewriter(rewriter));
+
+    // The match is considered successful iff the callable returns
+    // a value where `bool(value)` is `False` (e.g. `None`).
+    if (res.is_none() || !nb::cast<bool>(res))
+      return mlirLogicalResultSuccess();
+    return mlirLogicalResultFailure();
+  };
+
+  MlirRewritePattern pattern = mlirOpRewritePatternCreate(
+      mlirStringRefCreate(opName.data(), opName.size()), benefit,
+      mlirRewritePatternSetGetContext(patterns), callbacks,
+      matchAndRewrite.ptr(),
+      /* nGeneratedNames */ 0,
+      /* generatedNames */ nullptr);
+  mlirRewritePatternSetAdd(patterns, pattern);
+}
+
+//===----------------------------------------------------------------------===//
+// PyConversionPatternRewriter
+//===----------------------------------------------------------------------===//
+
 class PyConversionPatternRewriter : public PyPatternRewriter {
 public:
   PyConversionPatternRewriter(MlirConversionPatternRewriter rewriter)
@@ -132,6 +201,59 @@ class PyConversionPattern {
   MlirConversionPattern pattern;
 };
 
+void PyRewritePatternSet::addConversion(nb::handle root, unsigned benefit,
+                                        const nb::callable &matchAndRewrite,
+                                        PyTypeConverter &typeConverter) {
+  std::string opName;
+  if (root.is_type()) {
+    opName = nb::cast<std::string>(root.attr("OPERATION_NAME"));
+  } else if (nb::isinstance<nb::str>(root)) {
+    opName = nb::cast<std::string>(root);
+  } else {
+    throw nb::type_error("the root argument must be a type or a string");
+  }
+  MlirStringRef rootName = mlirStringRefCreate(opName.data(), opName.size());
+
+  MlirConversionPatternCallbacks callbacks;
+  callbacks.construct = [](void *userData) {
+    nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+  };
+  callbacks.destruct = [](void *userData) {
+    nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+  };
+  callbacks.matchAndRewrite =
+      [](MlirConversionPattern pattern, MlirOperation op, intptr_t nOperands,
+         MlirValue *operands, MlirConversionPatternRewriter rewriter,
+         void *userData) -> MlirLogicalResult {
+    nb::handle f(static_cast<PyObject *>(userData));
+
+    PyMlirContextRef ctx =
+        PyMlirContext::forContext(mlirOperationGetContext(op));
+    nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
+
+    std::vector<MlirValue> operandsVec(operands, operands + nOperands);
+    nb::object adaptorCls =
+        PyGlobals::get()
+            .lookupOpAdaptorClass([&] {
+              MlirStringRef ref = mlirIdentifierStr(mlirOperationGetName(op));
+              return std::string_view(ref.data, ref.length);
+            }())
+            .value_or(nb::borrow(nb::type<PyOpAdaptor>()));
+
+    nb::object res = f(opView, adaptorCls(operandsVec, opView),
+                       PyConversionPattern(pattern).getTypeConverter(),
+                       PyConversionPatternRewriter(rewriter));
+    return logicalResultFromObject(res);
+  };
+  MlirConversionPattern pattern = mlirOpConversionPatternCreate(
+      rootName, benefit, mlirRewritePatternSetGetContext(patterns),
+      typeConverter.get(), callbacks, matchAndRewrite.ptr(),
+      /* nGeneratedNames */ 0,
+      /* generatedNames */ nullptr);
+  mlirRewritePatternSetAdd(patterns,
+                           mlirConversionPatternAsRewritePattern(pattern));
+}
+
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
 struct PyMlirPDLResultList : MlirPDLResultList {};
 
@@ -249,96 +371,60 @@ class PyFrozenRewritePatternSet {
   MlirFrozenRewritePatternSet set;
 };
 
-class PyRewritePatternSet {
-public:
-  PyRewritePatternSet(MlirContext ctx)
-      : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
-  ~PyRewritePatternSet() {
-    if (set.ptr)
-      mlirRewritePatternSetDestroy(set);
-  }
-
-  void add(MlirStringRef rootName, unsigned benefit,
-           const nb::callable &matchAndRewrite) {
-    MlirRewritePatternCallbacks callbacks;
-    callbacks.construct = [](void *userData) {
-      nb::handle(static_cast<PyObject *>(userData)).inc_ref();
-    };
-    callbacks.destruct = [](void *userData) {
-      nb::handle(static_cast<PyObject *>(userData)).dec_ref();
-    };
-    callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
-                                   MlirPatternRewriter rewriter,
-                                   void *userData) -> MlirLogicalResult {
-      nb::handle f(static_cast<PyObject *>(userData));
-
-      PyMlirContextRef ctx =
-          PyMlirContext::forContext(mlirOperationGetContext(op));
-      nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
-
-      nb::object res = f(opView, PyPatternRewriter(rewriter));
-      return logicalResultFromObject(res);
-    };
-    MlirRewritePattern pattern = mlirOpRewritePatternCreate(
-        rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
-        /* nGeneratedNames */ 0,
-        /* generatedNames */ nullptr);
-    mlirRewritePatternSetAdd(set, pattern);
-  }
-
-  void addConversion(MlirStringRef rootName, unsigned benefit,
-                     const nb::callable &matchAndRewrite,
-                     PyTypeConverter &typeConverter) {
-    MlirConversionPatternCallbacks callbacks;
-    callbacks.construct = [](void *userData) {
-      nb::handle(static_cast<PyObject *>(userData)).inc_ref();
-    };
-    callbacks.destruct = [](void *userData) {
-      nb::handle(static_cast<PyObject *>(userData)).dec_ref();
-    };
-    callbacks.matchAndRewrite =
-        [](MlirConversionPattern pattern, MlirOperation op, intptr_t nOperands,
-           MlirValue *operands, MlirConversionPatternRewriter rewriter,
-           void *userData) -> MlirLogicalResult {
-      nb::handle f(static_cast<PyObject *>(userData));
-
-      PyMlirContextRef ctx =
-          PyMlirContext::forContext(mlirOperationGetContext(op));
-      nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
-
-      std::vector<MlirValue> operandsVec(operands, operands + nOperands);
-      nb::object adaptorCls =
-          PyGlobals::get()
-              .lookupOpAdaptorClass([&] {
-                MlirStringRef ref = mlirIdentifierStr(mlirOperationGetName(op));
-                return std::string_view(ref.data, ref.length);
-              }())
-              .value_or(nb::borrow(nb::type<PyOpAdaptor>()));
-
-      nb::object res = f(opView, adaptorCls(operandsVec, opView),
-                         PyConversionPattern(pattern).getTypeConverter(),
-                         PyConversionPatternRewriter(rewriter));
-      return logicalResultFromObject(res);
-    };
-    MlirConversionPattern pattern = mlirOpConversionPatternCreate(
-        rootName, benefit, ctx, typeConverter.get(), callbacks,
-        matchAndRewrite.ptr(),
-        /* nGeneratedNames */ 0,
-        /* generatedNames */ nullptr);
-    mlirRewritePatternSetAdd(set,
-                             mlirConversionPatternAsRewritePattern(pattern));
-  }
-
-  PyFrozenRewritePatternSet freeze() {
-    MlirRewritePatternSet s = set;
-    set.ptr = nullptr;
-    return mlirFreezeRewritePattern(s);
-  }
+void PyRewritePatternSet::bind(nb::module_ &m) {
+  nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
+      .def(
+          "__init__",
+          [](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
+            new (&self) PyRewritePatternSet(context.get()->get());
+          },
+          "context"_a = nb::none())
+      .def(
+          "add", &PyRewritePatternSet::add, nb::arg("root"), nb::arg("fn"),
+          nb::arg("benefit") = 1,
+          R"(Add a new rewrite pattern on the specified root operation, using
+             the provided callable for matching and rewriting, and assign it the
+             given benefit.
+
+             Args:
+               root: The root operation to which this pattern applies.
+                     This may be either an OpView subclass or an operation name.
+               fn: The callable to use for matching and rewriting, which takes
+                   an operation and a pattern rewriter. The match is considered
+                   successful iff the callable returns a falsy value.
+               benefit: The benefit of the pattern, defaulting to 1.)")
+      .def(
+          "add_conversion",
+          [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
+             PyTypeConverter &typeConverter, unsigned benefit) {
+            self.addConversion(root, benefit, fn, typeConverter);
+          },
+          "root"_a, "fn"_a, "type_converter"_a, "benefit"_a = 1,
+          R"(
+            Add a new conversion pattern on the specified root operation,
+            using the provided callable for matching and rewriting,
+            and assign it the given benefit.
 
-private:
-  MlirRewritePatternSet set;
-  MlirContext ctx;
-};
+            Args:
+              root: The root operation to which this pattern applies.
+                    This may be either an OpView subclass or an operation name.
+              fn: The callable to use for matching and rewriting, which takes an
+                  operation, its adaptor, the type converter and a pattern
+                  rewriter. The match is considered successful iff the callable
+                  returns a falsy value.
+              type_converter: The type converter to convert types in the IR.
+              benefit: The benefit of the pattern, defaulting to 1.)")
+      .def(
+          "freeze",
+          [](PyRewritePatternSet &self) {
+            if (!self.isOwned())
+              thr...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/184331


More information about the Mlir-commits mailing list