[Mlir-commits] [mlir] [MLIR][Python][Transform] Expose PatternDescriptorOpInterface to Python (PR #184331)
Rolf Morel
llvmlistbot at llvm.org
Tue Mar 3 04:46:02 PST 2026
https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/184331
>From 77c7cb13a7f29ff3c9d7a1c323e5a452d8f3ef51 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 2 Mar 2026 13:43:56 -0800
Subject: [PATCH 1/2] [MLIR][Python][Transform] Expose
PatternDescriptorOpInterface to Python
Makes it possible to include Python-defined rewrite patterns in
transform-dialect schedules, inside of `transform.apply_patterns`, which,
upon execution of the schedule, runs the patterns in a greedy rewriter.
---
mlir/include/mlir-c/Dialect/Transform.h | 32 ++
mlir/include/mlir-c/Rewrite.h | 4 +
mlir/lib/Bindings/Python/DialectTransform.cpp | 100 ++++++
mlir/lib/Bindings/Python/Rewrite.cpp | 338 +++++++++---------
mlir/lib/Bindings/Python/Rewrite.h | 35 ++
mlir/lib/CAPI/Dialect/Transform.cpp | 83 +++++
mlir/lib/CAPI/Transforms/Rewrite.cpp | 4 +
...ansform_pattern_descriptor_op_interface.py | 114 ++++++
8 files changed, 548 insertions(+), 162 deletions(-)
create mode 100644 mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py
diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h
index 43796a7f62727..cbda09cdbc37b 100644
--- a/mlir/include/mlir-c/Dialect/Transform.h
+++ b/mlir/include/mlir-c/Dialect/Transform.h
@@ -207,6 +207,38 @@ MLIR_CAPI_EXPORTED void mlirTransformOpInterfaceAttachFallbackModel(
MlirContext ctx, MlirStringRef opName,
MlirTransformOpInterfaceCallbacks callbacks);
+//===---------------------------------------------------------------------===//
+// PatternDescriptorOpInterface
+//===---------------------------------------------------------------------===//
+
+/// Returns the interface TypeID of the PatternDescriptorOpInterface.
+MLIR_CAPI_EXPORTED MlirTypeID mlirPatternDescriptorOpInterfaceTypeID(void);
+
+/// Callbacks for implementing PatternDescriptorOpInterface from external code.
+typedef struct {
+ /// Optional constructor for the user data.
+ /// Set to nullptr to disable it.
+ void (*construct)(void *userData);
+ /// Optional destructor for the user data.
+ /// Set to nullptr to disable it.
+ void (*destruct)(void *userData);
+ /// Required (must not be nullptr): populates rewrite patterns into `patterns`.
+ void (*populatePatterns)(MlirOperation op, MlirRewritePatternSet patterns,
+ void *userData);
+ /// Optional callback to populate rewrite patterns with transform state.
+ /// Set to nullptr to use the default implementation (calls populatePatterns).
+ void (*populatePatternsWithState)(MlirOperation op,
+ MlirRewritePatternSet patterns,
+ MlirTransformState state, void *userData);
+ void *userData; ///< Opaque pointer passed as the last argument of each callback.
+} MlirPatternDescriptorOpInterfaceCallbacks;
+
+/// Attach PatternDescriptorOpInterface to the registered operation with the
+/// given name in `ctx`, using the provided callbacks.
+MLIR_CAPI_EXPORTED void mlirPatternDescriptorOpInterfaceAttachFallbackModel(
+ MlirContext ctx, MlirStringRef opName,
+ MlirPatternDescriptorOpInterfaceCallbacks callbacks);
+
//===---------------------------------------------------------------------===//
// Transform-specifc MemoryEffectsOpInterface helpers
//===---------------------------------------------------------------------===//
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 6947158f624c0..2e12f9cabbddd 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -628,6 +628,10 @@ MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePatternCreate(
MLIR_CAPI_EXPORTED MlirRewritePatternSet
mlirRewritePatternSetCreate(MlirContext context);
+/// Returns the MlirContext that the given MlirRewritePatternSet belongs to.
+MLIR_CAPI_EXPORTED MlirContext
+mlirRewritePatternSetGetContext(MlirRewritePatternSet set);
+
/// Destruct the given MlirRewritePatternSet.
MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set);
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 157194f00e3c4..3b1b9abffb0f9 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -11,6 +11,7 @@
#include "Rewrite.h"
#include "mlir-c/Dialect/Transform.h"
#include "mlir-c/IR.h"
+#include "mlir-c/Rewrite.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/IRInterfaces.h"
@@ -246,6 +247,104 @@ class PyTransformOpInterface
}
};
+//===----------------------------------------------------------------------===//
+// PatternDescriptorOpInterface
+//===----------------------------------------------------------------------===//
+/// Python binding for transform::PatternDescriptorOpInterface. A Python
+/// subclass provides `populate_patterns` (and optionally
+/// `populate_patterns_with_state`) as staticmethods and is attached to an
+/// operation name via the `attach` classmethod.
+class PyPatternDescriptorOpInterface
+ : public PyConcreteOpInterface<PyPatternDescriptorOpInterface> {
+public:
+ using PyConcreteOpInterface<
+ PyPatternDescriptorOpInterface>::PyConcreteOpInterface;
+
+ constexpr static const char *pyClassName = "PatternDescriptorOpInterface";
+ constexpr static GetTypeIDFunctionTy getInterfaceID =
+ &mlirPatternDescriptorOpInterfaceTypeID;
+
+ /// Attach a new PatternDescriptorOpInterface FallbackModel to the named
+ /// operation. The FallbackModel acts as a trampoline for callbacks on the
+ /// Python class.
+ ///
+ /// Throws TypeError if `target` lacks `populate_patterns`, and ValueError if
+ /// `opName` is not a registered operation in the context. Both checks run
+ /// before any reference to `target` is taken, so a rejected call leaks
+ /// nothing and the C API's registration precondition always holds.
+ static void attach(nb::object &target, const std::string &opName,
+ DefaultingPyMlirContext ctx) {
+ MlirStringRef opNameRef = mlirStringRefCreate(opName.c_str(), opName.size());
+ if (!nb::hasattr(target, "populate_patterns"))
+ throw nb::type_error("PatternDescriptorOpInterface implementations must "
+ "define 'populate_patterns'");
+ if (!mlirContextIsRegisteredOperation(ctx->get(), opNameRef))
+ throw nb::value_error(
+ ("operation '" + opName + "' is not registered in the context")
+ .c_str());
+
+ // Prepare the callbacks that will be used by the FallbackModel.
+ MlirPatternDescriptorOpInterfaceCallbacks callbacks{};
+ // Make the pointer to the Python class available to the callbacks.
+ callbacks.userData = target.ptr();
+ nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
+
+ // The above ref bump is all we need as initialization, no need to run the
+ // construct callback.
+ callbacks.construct = nullptr;
+ // Upon the FallbackModel's destruction, drop the ref to the Python class.
+ callbacks.destruct = [](void *userData) {
+ nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+ };
+
+ // Invoke `pyClass.populate_patterns(opview(op), patterns)` as a
+ // staticmethod.
+ callbacks.populatePatterns =
+ [](MlirOperation op, MlirRewritePatternSet patterns, void *userData) {
+ nb::handle pyClass(static_cast<PyObject *>(userData));
+ auto pyPopulatePatterns =
+ nb::cast<nb::callable>(nb::getattr(pyClass, "populate_patterns"));
+ auto pyPatterns = PyRewritePatternSet(patterns);
+ pyPopulatePatterns(opViewFor(op), pyPatterns);
+ };
+
+ // Only install populatePatternsWithState when the Python class provides it;
+ // a nullptr callback makes the C API fall back to populatePatterns.
+ if (nb::hasattr(target, "populate_patterns_with_state")) {
+ callbacks.populatePatternsWithState = [](MlirOperation op,
+ MlirRewritePatternSet patterns,
+ MlirTransformState state,
+ void *userData) {
+ nb::handle pyClass(static_cast<PyObject *>(userData));
+ auto pyPopulatePatternsWithState = nb::cast<nb::callable>(
+ nb::getattr(pyClass, "populate_patterns_with_state"));
+ auto pyPatterns = PyRewritePatternSet(patterns);
+ auto pyState = PyTransformState(state);
+ // Invoke `pyClass.populate_patterns_with_state(opview(op), patterns,
+ // state)` as a staticmethod.
+ pyPopulatePatternsWithState(opViewFor(op), pyPatterns, pyState);
+ };
+ } else {
+ callbacks.populatePatternsWithState = nullptr;
+ }
+
+ // Attach a FallbackModel, which calls into Python, to the named operation.
+ mlirPatternDescriptorOpInterfaceAttachFallbackModel(ctx->get(), opNameRef,
+ callbacks);
+ }
+
+ static void bindDerived(ClassTy &cls) {
+ cls.attr("attach") = classmethod(
+ [](const nb::object &cls, const nb::object &opName, nb::object target,
+ DefaultingPyMlirContext context) {
+ if (target.is_none())
+ target = cls;
+ return attach(target, nb::cast<std::string>(opName), context);
+ },
+ nb::arg("cls"), nb::arg("op_name"), nb::kw_only(),
+ nb::arg("target").none() = nb::none(),
+ nb::arg("context").none() = nb::none(),
+ "Attach the interface subclass to the given operation name.");
+ }
+
+private:
+ /// Wraps `op` as a Python object via PyOperation::createOpView, using the
+ /// PyMlirContext that corresponds to the op's MlirContext.
+ static nb::object opViewFor(MlirOperation op) {
+ PyMlirContextRef context =
+ PyMlirContext::forContext(mlirOperationGetContext(op));
+ return PyOperation::forOperation(context, op)->createOpView();
+ }
+};
+
//===-------------------------------------------------------------------===//
// AnyOpType
//===-------------------------------------------------------------------===//
@@ -444,6 +543,7 @@ static void populateDialectTransformSubmodule(nb::module_ &m) {
PyTransformResults::bind(m);
PyTransformState::bind(m);
PyTransformOpInterface::bind(m);
+ PyPatternDescriptorOpInterface::bind(m);
m.def("only_reads_handle", onlyReadsHandle,
"Mark operands as only reading handles.", nb::arg("operands"),
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 181df847f36fe..dc5fc0699702b 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -35,6 +35,75 @@ class PyPatternRewriter : public PyRewriterBase<PyPatternRewriter> {
: PyRewriterBase(mlirPatternRewriterAsBase(rewriter)) {}
};
+//===----------------------------------------------------------------------===//
+// PyRewritePatternSet
+//===----------------------------------------------------------------------===//
+
+PyRewritePatternSet::PyRewritePatternSet(MlirContext ctx)
+ : patterns(mlirRewritePatternSetCreate(ctx)), owned(true) {}
+
+PyRewritePatternSet::PyRewritePatternSet(MlirRewritePatternSet patterns)
+ : patterns(patterns), owned(false) {}
+
+PyRewritePatternSet::~PyRewritePatternSet() {
+ if (owned && patterns.ptr)
+ mlirRewritePatternSetDestroy(patterns);
+}
+
+MlirRewritePatternSet PyRewritePatternSet::get() const { return patterns; }
+
+bool PyRewritePatternSet::isOwned() const { return owned; }
+
+/// Adds a Python-callable rewrite pattern rooted at `root`, which may be an
+/// OpView subclass (its OPERATION_NAME is used) or an operation-name string.
+/// Throws TypeError for any other kind of `root`.
+void PyRewritePatternSet::add(nb::handle root,
+ const nb::callable &matchAndRewrite,
+ unsigned benefit) {
+ std::string opName;
+ if (root.is_type()) {
+ opName = nb::cast<std::string>(root.attr("OPERATION_NAME"));
+ } else if (nb::isinstance<nb::str>(root)) {
+ opName = nb::cast<std::string>(root);
+ } else {
+ throw nb::type_error("the root argument must be a type or a string");
+ }
+
+ MlirRewritePatternCallbacks callbacks;
+ // The pattern keeps the Python callable alive for its own lifetime.
+ callbacks.construct = [](void *userData) {
+ nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+ };
+ callbacks.destruct = [](void *userData) {
+ nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+ };
+ callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
+ MlirPatternRewriter rewriter,
+ void *userData) -> MlirLogicalResult {
+ nb::handle f(static_cast<PyObject *>(userData));
+
+ PyMlirContextRef context =
+ PyMlirContext::forContext(mlirOperationGetContext(op));
+ nb::object opView = PyOperation::forOperation(context, op)->createOpView();
+
+ nb::object res = f(opView, PyPatternRewriter(rewriter));
+
+ // The match is considered successful iff `bool(res)` is False (e.g.
+ // `None`). Evaluate Python truthiness directly: nanobind's bool caster
+ // only accepts the True/False singletons, so nb::cast<bool> would throw
+ // for other documented-as-valid values such as 0, "" or a list.
+ int truthy = PyObject_IsTrue(res.ptr());
+ if (truthy < 0)
+ throw nb::python_error();
+ return truthy ? mlirLogicalResultFailure() : mlirLogicalResultSuccess();
+ };
+
+ MlirRewritePattern pattern = mlirOpRewritePatternCreate(
+ mlirStringRefCreate(opName.data(), opName.size()), benefit,
+ mlirRewritePatternSetGetContext(patterns), callbacks,
+ matchAndRewrite.ptr(),
+ /* nGeneratedNames */ 0,
+ /* generatedNames */ nullptr);
+ mlirRewritePatternSetAdd(patterns, pattern);
+}
+
+//===----------------------------------------------------------------------===//
+// PyConversionPatternRewriter
+//===----------------------------------------------------------------------===//
+
class PyConversionPatternRewriter : public PyPatternRewriter {
public:
PyConversionPatternRewriter(MlirConversionPatternRewriter rewriter)
@@ -132,6 +201,59 @@ class PyConversionPattern {
MlirConversionPattern pattern;
};
+/// Adds a Python-callable conversion pattern rooted at `root` (an OpView
+/// subclass or an operation-name string) that uses `typeConverter`.
+void PyRewritePatternSet::addConversion(nb::handle root, unsigned benefit,
+ const nb::callable &matchAndRewrite,
+ PyTypeConverter &typeConverter) {
+ // Reject anything that is neither a type nor a string up front.
+ if (!root.is_type() && !nb::isinstance<nb::str>(root))
+ throw nb::type_error("the root argument must be a type or a string");
+ std::string opName = root.is_type()
+ ? nb::cast<std::string>(root.attr("OPERATION_NAME"))
+ : nb::cast<std::string>(root);
+
+ MlirConversionPatternCallbacks callbacks;
+ callbacks.construct = [](void *userData) {
+ nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+ };
+ callbacks.destruct = [](void *userData) {
+ nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+ };
+ callbacks.matchAndRewrite =
+ [](MlirConversionPattern pattern, MlirOperation op, intptr_t nOperands,
+ MlirValue *operands, MlirConversionPatternRewriter rewriter,
+ void *userData) -> MlirLogicalResult {
+ nb::handle callable(static_cast<PyObject *>(userData));
+
+ PyMlirContextRef context =
+ PyMlirContext::forContext(mlirOperationGetContext(op));
+ nb::object opView = PyOperation::forOperation(context, op)->createOpView();
+
+ std::vector<MlirValue> operandValues(operands, operands + nOperands);
+
+ // Prefer the op-specific adaptor class; fall back to the generic one.
+ MlirStringRef nameRef = mlirIdentifierStr(mlirOperationGetName(op));
+ auto registeredAdaptor = PyGlobals::get().lookupOpAdaptorClass(
+ std::string_view(nameRef.data, nameRef.length));
+ nb::object adaptorCls =
+ registeredAdaptor.value_or(nb::borrow(nb::type<PyOpAdaptor>()));
+
+ nb::object adaptor = adaptorCls(operandValues, opView);
+ nb::object result =
+ callable(opView, adaptor, PyConversionPattern(pattern).getTypeConverter(),
+ PyConversionPatternRewriter(rewriter));
+ return logicalResultFromObject(result);
+ };
+
+ MlirContext setContext = mlirRewritePatternSetGetContext(patterns);
+ MlirConversionPattern pattern = mlirOpConversionPatternCreate(
+ mlirStringRefCreate(opName.data(), opName.size()), benefit, setContext,
+ typeConverter.get(), callbacks, matchAndRewrite.ptr(),
+ /* nGeneratedNames */ 0,
+ /* generatedNames */ nullptr);
+ mlirRewritePatternSetAdd(patterns,
+ mlirConversionPatternAsRewritePattern(pattern));
+}
+
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
struct PyMlirPDLResultList : MlirPDLResultList {};
@@ -249,96 +371,60 @@ class PyFrozenRewritePatternSet {
MlirFrozenRewritePatternSet set;
};
-class PyRewritePatternSet {
-public:
- PyRewritePatternSet(MlirContext ctx)
- : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
- ~PyRewritePatternSet() {
- if (set.ptr)
- mlirRewritePatternSetDestroy(set);
- }
-
- void add(MlirStringRef rootName, unsigned benefit,
- const nb::callable &matchAndRewrite) {
- MlirRewritePatternCallbacks callbacks;
- callbacks.construct = [](void *userData) {
- nb::handle(static_cast<PyObject *>(userData)).inc_ref();
- };
- callbacks.destruct = [](void *userData) {
- nb::handle(static_cast<PyObject *>(userData)).dec_ref();
- };
- callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
- MlirPatternRewriter rewriter,
- void *userData) -> MlirLogicalResult {
- nb::handle f(static_cast<PyObject *>(userData));
-
- PyMlirContextRef ctx =
- PyMlirContext::forContext(mlirOperationGetContext(op));
- nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
-
- nb::object res = f(opView, PyPatternRewriter(rewriter));
- return logicalResultFromObject(res);
- };
- MlirRewritePattern pattern = mlirOpRewritePatternCreate(
- rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
- /* nGeneratedNames */ 0,
- /* generatedNames */ nullptr);
- mlirRewritePatternSetAdd(set, pattern);
- }
-
- void addConversion(MlirStringRef rootName, unsigned benefit,
- const nb::callable &matchAndRewrite,
- PyTypeConverter &typeConverter) {
- MlirConversionPatternCallbacks callbacks;
- callbacks.construct = [](void *userData) {
- nb::handle(static_cast<PyObject *>(userData)).inc_ref();
- };
- callbacks.destruct = [](void *userData) {
- nb::handle(static_cast<PyObject *>(userData)).dec_ref();
- };
- callbacks.matchAndRewrite =
- [](MlirConversionPattern pattern, MlirOperation op, intptr_t nOperands,
- MlirValue *operands, MlirConversionPatternRewriter rewriter,
- void *userData) -> MlirLogicalResult {
- nb::handle f(static_cast<PyObject *>(userData));
-
- PyMlirContextRef ctx =
- PyMlirContext::forContext(mlirOperationGetContext(op));
- nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
-
- std::vector<MlirValue> operandsVec(operands, operands + nOperands);
- nb::object adaptorCls =
- PyGlobals::get()
- .lookupOpAdaptorClass([&] {
- MlirStringRef ref = mlirIdentifierStr(mlirOperationGetName(op));
- return std::string_view(ref.data, ref.length);
- }())
- .value_or(nb::borrow(nb::type<PyOpAdaptor>()));
-
- nb::object res = f(opView, adaptorCls(operandsVec, opView),
- PyConversionPattern(pattern).getTypeConverter(),
- PyConversionPatternRewriter(rewriter));
- return logicalResultFromObject(res);
- };
- MlirConversionPattern pattern = mlirOpConversionPatternCreate(
- rootName, benefit, ctx, typeConverter.get(), callbacks,
- matchAndRewrite.ptr(),
- /* nGeneratedNames */ 0,
- /* generatedNames */ nullptr);
- mlirRewritePatternSetAdd(set,
- mlirConversionPatternAsRewritePattern(pattern));
- }
-
- PyFrozenRewritePatternSet freeze() {
- MlirRewritePatternSet s = set;
- set.ptr = nullptr;
- return mlirFreezeRewritePattern(s);
- }
+/// Registers the Python `RewritePatternSet` class on module `m`.
+void PyRewritePatternSet::bind(nb::module_ &m) {
+ nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
+ .def(
+ "__init__",
+ [](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
+ new (&self) PyRewritePatternSet(context.get()->get());
+ },
+ "context"_a = nb::none())
+ .def(
+ "add", &PyRewritePatternSet::add, nb::arg("root"), nb::arg("fn"),
+ nb::arg("benefit") = 1,
+ R"(Add a new rewrite pattern on the specified root operation, using
+ the provided callable for matching and rewriting, and assign it the
+ given benefit.
+
+ Args:
+ root: The root operation to which this pattern applies.
+ This may be either an OpView subclass or an operation name.
+ fn: The callable to use for matching and rewriting, which takes
+ an operation and a pattern rewriter. The match is considered
+ successful iff the callable returns a falsy value.
+ benefit: The benefit of the pattern, defaulting to 1.)")
+ .def(
+ "add_conversion",
+ [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
+ PyTypeConverter &typeConverter, unsigned benefit) {
+ self.addConversion(root, benefit, fn, typeConverter);
+ },
+ "root"_a, "fn"_a, "type_converter"_a, "benefit"_a = 1,
+ R"(
+ Add a new conversion pattern on the specified root operation,
+ using the provided callable for matching and rewriting,
+ and assign it the given benefit.
-private:
- MlirRewritePatternSet set;
- MlirContext ctx;
-};
+ Args:
+ root: The root operation to which this pattern applies.
+ This may be either an OpView subclass or an operation name.
+ fn: The callable to use for matching and rewriting, which takes an
+ operation, its adaptor, the type converter and a pattern
+ rewriter. The match is considered successful iff the callable
+ returns a falsy value.
+ type_converter: The type converter to convert types in the IR.
+ benefit: The benefit of the pattern, defaulting to 1.)")
+ .def(
+ "freeze",
+ [](PyRewritePatternSet &self) {
+ if (!self.isOwned())
+ throw std::runtime_error(
+ "cannot freeze a non-owning pattern set");
+ if (!self.patterns.ptr)
+ throw std::runtime_error("pattern set has already been frozen");
+ // mlirFreezeRewritePattern moves the patterns out of the set and
+ // takes over ownership, so clear our handle: the destructor must not
+ // destroy the transferred set, and it must not be frozen twice.
+ MlirRewritePatternSet s = self.patterns;
+ self.patterns.ptr = nullptr;
+ return PyFrozenRewritePatternSet(mlirFreezeRewritePattern(s));
+ },
+ "Freeze the pattern set into a frozen one.");
+}
enum class PyGreedyRewriteStrictness : std::underlying_type_t<
MlirGreedyRewriteStrictness> {
@@ -505,79 +591,7 @@ void populateRewriteSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of the RewritePatternSet
//----------------------------------------------------------------------------
- nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
- .def(
- "__init__",
- [](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
- new (&self) PyRewritePatternSet(context.get()->get());
- },
- "context"_a = nb::none())
- .def(
- "add",
- [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
- unsigned benefit) {
- std::string opName;
- if (root.is_type()) {
- opName = nb::cast<std::string>(root.attr("OPERATION_NAME"));
- } else if (nb::isinstance<nb::str>(root)) {
- opName = nb::cast<std::string>(root);
- } else {
- throw nb::type_error(
- "the root argument must be a type or a string");
- }
- self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
- fn);
- },
- "root"_a, "fn"_a, "benefit"_a = 1,
- // clang-format off
- nb::sig("def add(self, root: type | str, fn: typing.Callable[[" MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", PatternRewriter], typing.Any], benefit: int = 1) -> None"),
- // clang-format on
- R"(
- Add a new rewrite pattern on the specified root operation, using the provided callable
- for matching and rewriting, and assign it the given benefit.
-
- Args:
- root: The root operation to which this pattern applies.
- This may be either an OpView subclass (e.g., ``arith.AddIOp``) or
- an operation name string (e.g., ``"arith.addi"``).
- fn: The callable to use for matching and rewriting,
- which takes an operation and a pattern rewriter as arguments.
- The match is considered successful iff the callable returns
- a value where ``bool(value)`` is ``False`` (e.g. ``None``).
- If possible, the operation is cast to its corresponding OpView subclass
- before being passed to the callable.
- benefit: The benefit of the pattern, defaulting to 1.)")
- .def(
- "add_conversion",
- [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
- PyTypeConverter &typeConverter, unsigned benefit) {
- std::string opName =
- nb::cast<std::string>(root.attr("OPERATION_NAME"));
- self.addConversion(
- mlirStringRefCreate(opName.data(), opName.size()), benefit, fn,
- typeConverter);
- },
- "root"_a, "fn"_a, "type_converter"_a, "benefit"_a = 1,
- R"(
- Add a new conversion pattern on the specified root operation,
- using the provided callable for matching and rewriting,
- and assign it the given benefit.
-
- Args:
- root: The root operation to which this pattern applies.
- This may be either an OpView subclass (e.g., ``arith.AddIOp``) or
- an operation name string (e.g., ``"arith.addi"``).
- fn: The callable to use for matching and rewriting,
- which takes an operation, its adaptor,
- the type converter and a pattern rewriter as arguments.
- The match is considered successful iff the callable returns
- a value where ``bool(value)`` is ``False`` (e.g. ``None``).
- If possible, the operation is cast to its corresponding OpView subclass
- before being passed to the callable.
- type_converter: The type converter to convert types in the IR.
- benefit: The benefit of the pattern, defaulting to 1.)")
- .def("freeze", &PyRewritePatternSet::freeze,
- "Freeze the pattern set into a frozen one.");
+ PyRewritePatternSet::bind(m);
nb::class_<PyConversionPatternRewriter, PyPatternRewriter>(
m, "ConversionPatternRewriter")
diff --git a/mlir/lib/Bindings/Python/Rewrite.h b/mlir/lib/Bindings/Python/Rewrite.h
index 32d53f505c145..11a910dbf0bad 100644
--- a/mlir/lib/Bindings/Python/Rewrite.h
+++ b/mlir/lib/Bindings/Python/Rewrite.h
@@ -75,6 +75,41 @@ class MLIR_PYTHON_API_EXPORTED PyRewriterBase {
PyMlirContextRef ctx;
};
+class PyTypeConverter;
+
+/// Wrapper around MlirRewritePatternSet.
+/// The constructor taking an MlirContext creates an owned pattern set that is
+/// destroyed in the destructor. The constructor taking an
+/// MlirRewritePatternSet creates a non-owning reference, which must not be
+/// used once the referenced set has been destroyed.
+///
+/// NOTE(review): the implicitly generated copy constructor copies both the
+/// handle and the ownership flag, so copying an *owned* instance results in
+/// the set being destroyed twice; only non-owning instances should be copied.
+class MLIR_PYTHON_API_EXPORTED PyRewritePatternSet {
+public:
+ /// Create an owned pattern set in `ctx`.
+ explicit PyRewritePatternSet(MlirContext ctx);
+
+ /// Create a non-owning reference to an existing pattern set.
+ explicit PyRewritePatternSet(MlirRewritePatternSet patterns);
+
+ ~PyRewritePatternSet();
+
+ /// Returns the underlying C API handle.
+ MlirRewritePatternSet get() const;
+
+ /// Returns true if this wrapper destroys the pattern set on destruction.
+ bool isOwned() const;
+
+ /// Add a new rewrite pattern to the pattern set.
+ void add(nanobind::handle root, const nanobind::callable &matchAndRewrite,
+ unsigned benefit);
+
+ /// Add a new conversion pattern to the pattern set.
+ void addConversion(nanobind::handle root, unsigned benefit,
+ const nanobind::callable &matchAndRewrite,
+ PyTypeConverter &typeConverter);
+
+ /// Register the Python `RewritePatternSet` class on `m`.
+ static void bind(nanobind::module_ &m);
+
+private:
+ MlirRewritePatternSet patterns;
+ bool owned;
+};
+
void MLIR_PYTHON_API_EXPORTED populateRewriteSubmodule(nanobind::module_ &m);
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp
index 816e5df67e407..eb7b65467f519 100644
--- a/mlir/lib/CAPI/Dialect/Transform.cpp
+++ b/mlir/lib/CAPI/Dialect/Transform.cpp
@@ -298,6 +298,89 @@ void mlirTransformOpInterfaceAttachFallbackModel(
model->setCallbacks(callbacks);
}
+//===---------------------------------------------------------------------===//
+// PatternDescriptorOpInterface
+//===---------------------------------------------------------------------===//
+
+/// Returns the TypeID identifying transform::PatternDescriptorOpInterface.
+MlirTypeID mlirPatternDescriptorOpInterfaceTypeID(void) {
+ TypeID interfaceID =
+ transform::PatternDescriptorOpInterface::getInterfaceID();
+ return wrap(interfaceID);
+}
+
+/// Fallback model for the PatternDescriptorOpInterface that uses C API
+/// callbacks.
+class PatternDescriptorOpInterfaceFallbackModel
+ : public mlir::transform::PatternDescriptorOpInterface::FallbackModel<
+ PatternDescriptorOpInterfaceFallbackModel> {
+public:
+ /// Sets the callbacks that this FallbackModel will use.
+ /// NB: the callbacks can only be set through this method as the
+ /// RegisteredOperationName::attachInterface mechanism default-constructs
+ /// the FallbackModel without being able to provide arguments.
+ ///
+ /// User data installed by an earlier call (presumably when the interface is
+ /// attached to the same operation more than once) is released through its
+ /// `destruct` callback before being replaced, and the optional `construct`
+ /// callback is run on the new user data.
+ void setCallbacks(MlirPatternDescriptorOpInterfaceCallbacks callbacks) {
+ if (this->callbacks.destruct)
+ this->callbacks.destruct(this->callbacks.userData);
+ this->callbacks = callbacks;
+ if (this->callbacks.construct)
+ this->callbacks.construct(this->callbacks.userData);
+ }
+
+ ~PatternDescriptorOpInterfaceFallbackModel() {
+ // NOTE(review): interface models may be released by the InterfaceMap with
+ // free() rather than delete, in which case this destructor (and thus the
+ // `destruct` callback) never runs. Verify against InterfaceSupport.h.
+ if (callbacks.destruct)
+ callbacks.destruct(callbacks.userData);
+ }
+
+ static TypeID getInterfaceID() {
+ return transform::PatternDescriptorOpInterface::getInterfaceID();
+ }
+
+ static bool classof(
+ const mlir::transform::detail::PatternDescriptorOpInterfaceInterfaceTraits::
+ Concept *op) {
+ // Enable casting back to the FallbackModel from the Interface. This is
+ // necessary as attachInterface(...) default-constructs the FallbackModel
+ // without being able to pass in the callbacks and returns just the Concept.
+ return true;
+ }
+
+ void populatePatterns(Operation *op, RewritePatternSet &patterns) const {
+ assert(callbacks.populatePatterns && "populatePatterns callback not set");
+ callbacks.populatePatterns(wrap(op), wrap(&patterns), callbacks.userData);
+ }
+
+ void populatePatternsWithState(Operation *op, RewritePatternSet &patterns,
+ transform::TransformState &state) const {
+ if (callbacks.populatePatternsWithState) {
+ callbacks.populatePatternsWithState(wrap(op), wrap(&patterns),
+ wrap(&state), callbacks.userData);
+ } else {
+ // Default implementation: call populatePatterns without state.
+ populatePatterns(op, patterns);
+ }
+ }
+
+private:
+ /// Zero-initialized so that a model without installed callbacks has no
+ /// user data to construct or destruct.
+ MlirPatternDescriptorOpInterfaceCallbacks callbacks{};
+};
+
+/// Attach a PatternDescriptorOpInterface FallbackModel to the given named
+/// operation. The FallbackModel uses the provided callbacks to implement the
+/// interface. Precondition: `opName` must be registered in `ctx`.
+void mlirPatternDescriptorOpInterfaceAttachFallbackModel(
+ MlirContext ctx, MlirStringRef opName,
+ MlirPatternDescriptorOpInterfaceCallbacks callbacks) {
+ // Look up the registered operation; callers must ensure it exists.
+ std::optional<RegisteredOperationName> opInfo =
+ RegisteredOperationName::lookup(unwrap(opName), unwrap(ctx));
+
+ assert(opInfo.has_value() && "operation not found in context"); // Only checked when asserts are enabled.
+
+ // NB: the following default-constructs the FallbackModel _without_ being able
+ // to provide arguments.
+ opInfo->attachInterface<PatternDescriptorOpInterfaceFallbackModel>();
+ // Cast to get the underlying FallbackModel and set the callbacks.
+ auto *model = cast<PatternDescriptorOpInterfaceFallbackModel>(
+ opInfo->getInterface<PatternDescriptorOpInterfaceFallbackModel>());
+
+ assert(model && "Failed to get PatternDescriptorOpInterfaceFallbackModel");
+ model->setCallbacks(callbacks);
+}
+
//===---------------------------------------------------------------------===//
// MemoryEffectsOpInterface helpers
//===---------------------------------------------------------------------===//
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 5a6ade352d760..4f75cc758bc48 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -728,6 +728,10 @@ MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context) {
return wrap(new mlir::RewritePatternSet(unwrap(context)));
}
+MlirContext mlirRewritePatternSetGetContext(MlirRewritePatternSet set) {
+ // The pattern set records the context it was created with.
+ mlir::MLIRContext *context = unwrap(set)->getContext();
+ return wrap(context);
+}
+
void mlirRewritePatternSetDestroy(MlirRewritePatternSet set) {
delete unwrap(set);
}
diff --git a/mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py b/mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py
new file mode 100644
index 0000000000000..470c679179b03
--- /dev/null
+++ b/mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py
@@ -0,0 +1,114 @@
+# RUN: env PYTHONUNBUFFERED=1 %PYTHON %s 2>&1 | FileCheck %s
+
+from contextlib import contextmanager
+
+from mlir import ir, rewrite
+from mlir.dialects import transform, func, arith, ext
+from mlir.dialects.transform import AnyOpType, structured
+
+
+ at ext.register_dialect
+class MyPatternDescriptors(ext.Dialect, name="my_pattern_descriptors"):
+ pass
+
+
+def run(emit_schedule):
+ print(f"Test: {emit_schedule.__name__}")
+ with ir.Context(), ir.Location.unknown():
+ payload = emit_payload()
+
+ MyPatternDescriptors.load(register=False, reload=True)
+
+ # NB: Pattern descriptor ops have their interfaces attached
+ # in their respective test functions.
+ schedule = emit_schedule()
+
+ (_named_seq := schedule.body.operations[0]).apply(payload)
+
+ print(payload)
+
+
+# Payload used by all tests.
+def emit_payload():
+ payload_module = ir.Module.create()
+ with ir.InsertionPoint(payload_module.body):
+ i32 = ir.IntegerType.get_signless(32)
+
+ @func.FuncOp.from_py_func(i32, i32)
+ def test_func(a, b):
+ c = arith.addi(a, b)
+ d = arith.subi(c, b)
+ return d
+
+ return payload_module
+
+
+ at contextmanager
+def schedule_boilerplate():
+ schedule = ir.Module.create()
+ schedule.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
+ with ir.InsertionPoint(schedule.body):
+ named_sequence = transform.NamedSequenceOp(
+ "__transform_main",
+ [AnyOpType.get()],
+ [AnyOpType.get()],
+ arg_attrs=[{"transform.consumed": ir.UnitAttr.get()}],
+ )
+ with ir.InsertionPoint(named_sequence.body):
+ yield schedule, named_sequence
+
+
+ at ext.register_operation(MyPatternDescriptors)
+class SubiAddiRewritePatternOp(MyPatternDescriptors.Operation, name="add_pattern"):
+ @classmethod
+ def attach_interface_impls(cls, ctx=None):
+ cls.PatternDescriptorOpInterfaceFallbackModel.attach(
+ cls.OPERATION_NAME, context=ctx
+ )
+
+ class PatternDescriptorOpInterfaceFallbackModel(
+ transform.PatternDescriptorOpInterface
+ ):
+ @staticmethod
+ def populate_patterns(
+ op: "SubiAddiRewritePatternOp",
+ patterns: rewrite.RewritePatternSet,
+ ) -> None:
+ # Define a pattern that rewrites subi(addi(a, b), b) -> a
+ def match_and_rewrite(subi, rewriter):
+ if not isinstance(addi := subi.lhs.owner, arith.AddiOp):
+ return True # Failed match, return truthy value
+ if subi.rhs != addi.rhs:
+ return True
+ # Replace subi's result with addi's lhs
+ rewriter.replace_op(subi, [addi.lhs])
+ return None # Success
+
+ # Add the pattern to the pattern set.
+ patterns.add("arith.subi", match_and_rewrite, benefit=1)
+
+
+# CHECK-LABEL: Test: test_pattern_descriptor_add_pattern
+ at run
+def test_pattern_descriptor_add_pattern():
+ """Tests python-defined rewrite pattern via PatternDescriptorOpInterface on AddPatternOp"""
+
+ SubiAddiRewritePatternOp.attach_interface_impls()
+
+ with schedule_boilerplate() as (schedule, named_seq):
+ func_handle = structured.MatchOp.match_op_names(
+ named_seq.bodyTarget, ["func.func"]
+ ).result
+
+ # After pattern application, check that subi is removed and func returns
+ # the first argument directly:
+ # CHECK: func.func @test_func(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
+ # CHECK: return %[[ARG0]] : i32
+ apply_patterns_op = transform.ApplyPatternsOp(func_handle)
+ with ir.InsertionPoint(apply_patterns_op.patterns):
+ SubiAddiRewritePatternOp()
+
+ transform.yield_([func_handle])
+ named_seq.verify()
+
+ return schedule
>From 79534374769b2cac9e26dd5c58850f81a9b7e5e8 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Tue, 3 Mar 2026 04:45:38 -0800
Subject: [PATCH 2/2] Formatting
---
mlir/lib/Bindings/Python/Rewrite.cpp | 27 +++++++++++++--------------
mlir/lib/CAPI/Dialect/Transform.cpp | 6 +++---
2 files changed, 16 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index dc5fc0699702b..a1940b619e24d 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -379,20 +379,19 @@ void PyRewritePatternSet::bind(nb::module_ &m) {
new (&self) PyRewritePatternSet(context.get()->get());
},
"context"_a = nb::none())
- .def(
- "add", &PyRewritePatternSet::add, nb::arg("root"), nb::arg("fn"),
- nb::arg("benefit") = 1,
- R"(Add a new rewrite pattern on the specified root operation, using
- the provided callable for matching and rewriting, and assign it the
- given benefit.
-
- Args:
- root: The root operation to which this pattern applies.
- This may be either an OpView subclass or an operation name.
- fn: The callable to use for matching and rewriting, which takes
- an operation and a pattern rewriter. The match is considered
- successful iff the callable returns a falsy value.
- benefit: The benefit of the pattern, defaulting to 1.)")
+ .def("add", &PyRewritePatternSet::add, nb::arg("root"), nb::arg("fn"),
+ nb::arg("benefit") = 1,
+ R"(Add a new rewrite pattern on the specified root operation, using
+ the provided callable for matching and rewriting, and assign it
+ the given benefit.
+
+ Args:
+ root: The root operation to which this pattern applies. This may
+ be either an OpView subclass or an operation name.
+ fn: The callable to use for matching and rewriting, which takes
+ an operation and a pattern rewriter. The match is considered
+ successful iff the callable returns a falsy value.
+ benefit: The benefit of the pattern, defaulting to 1.)")
.def(
"add_conversion",
[](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp
index eb7b65467f519..1ed14255bf5e0 100644
--- a/mlir/lib/CAPI/Dialect/Transform.cpp
+++ b/mlir/lib/CAPI/Dialect/Transform.cpp
@@ -329,9 +329,9 @@ class PatternDescriptorOpInterfaceFallbackModel
return transform::PatternDescriptorOpInterface::getInterfaceID();
}
- static bool classof(
- const mlir::transform::detail::PatternDescriptorOpInterfaceInterfaceTraits::
- Concept *op) {
+ static bool
+ classof(const mlir::transform::detail::
+ PatternDescriptorOpInterfaceInterfaceTraits::Concept *op) {
// Enable casting back to the FallbackModel from the Interface. This is
// necessary as attachInterface(...) default-constructs the FallbackModel
// without being able to pass in the callbacks and returns just the Concept.
More information about the Mlir-commits
mailing list