[Mlir-commits] [mlir] [MLIR][Python][Transform] Expose PatternDescriptorOpInterface to Python (PR #184331)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 3 04:42:40 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Rolf Morel (rolfmorel)
<details>
<summary>Changes</summary>
Makes it possible to include Python-defined rewrite patterns in transform-dialect schedules, inside `transform.apply_patterns`; when the schedule is executed, `transform.apply_patterns` runs these patterns with a greedy rewriter.
---
Patch is 34.89 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/184331.diff
8 Files Affected:
- (modified) mlir/include/mlir-c/Dialect/Transform.h (+32)
- (modified) mlir/include/mlir-c/Rewrite.h (+4)
- (modified) mlir/lib/Bindings/Python/DialectTransform.cpp (+100)
- (modified) mlir/lib/Bindings/Python/Rewrite.cpp (+176-162)
- (modified) mlir/lib/Bindings/Python/Rewrite.h (+35)
- (modified) mlir/lib/CAPI/Dialect/Transform.cpp (+83)
- (modified) mlir/lib/CAPI/Transforms/Rewrite.cpp (+4)
- (added) mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py (+114)
``````````diff
diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h
index 43796a7f62727..cbda09cdbc37b 100644
--- a/mlir/include/mlir-c/Dialect/Transform.h
+++ b/mlir/include/mlir-c/Dialect/Transform.h
@@ -207,6 +207,38 @@ MLIR_CAPI_EXPORTED void mlirTransformOpInterfaceAttachFallbackModel(
MlirContext ctx, MlirStringRef opName,
MlirTransformOpInterfaceCallbacks callbacks);
+//===---------------------------------------------------------------------===//
+// PatternDescriptorOpInterface
+//===---------------------------------------------------------------------===//
+
+/// Returns the TypeID that identifies the PatternDescriptorOpInterface.
+MLIR_CAPI_EXPORTED MlirTypeID mlirPatternDescriptorOpInterfaceTypeID(void);
+
+/// Callbacks for implementing PatternDescriptorOpInterface from external code
+/// (the Python bindings use this to forward to a Python class). Each callback
+/// receives the `userData` pointer stored in this struct as its last argument.
+typedef struct {
+ /// Optional constructor for the user data.
+ /// Set to nullptr to disable it.
+ void (*construct)(void *userData);
+ /// Optional destructor for the user data.
+ /// Set to nullptr to disable it.
+ void (*destruct)(void *userData);
+ /// Callback to populate rewrite patterns into the given pattern set.
+ /// `op` is presumably the operation instance whose interface method is
+ /// being invoked. Not marked optional -- NOTE(review): presumably required.
+ void (*populatePatterns)(MlirOperation op, MlirRewritePatternSet patterns,
+ void *userData);
+ /// Optional callback to populate rewrite patterns with transform state.
+ /// Set to nullptr to use the default implementation (calls populatePatterns).
+ void (*populatePatternsWithState)(MlirOperation op,
+ MlirRewritePatternSet patterns,
+ MlirTransformState state, void *userData);
+ /// Opaque pointer handed back to every callback above as `userData`.
+ void *userData;
+} MlirPatternDescriptorOpInterfaceCallbacks;
+
+/// Attach PatternDescriptorOpInterface to the operation with the given name
+/// using the provided callbacks. `callbacks` is taken by value; presumably the
+/// attached model keeps it (and `userData`) alive until the model is destroyed,
+/// at which point `destruct` runs if set -- TODO confirm in the implementation.
+MLIR_CAPI_EXPORTED void mlirPatternDescriptorOpInterfaceAttachFallbackModel(
+ MlirContext ctx, MlirStringRef opName,
+ MlirPatternDescriptorOpInterfaceCallbacks callbacks);
+
//===---------------------------------------------------------------------===//
// Transform-specifc MemoryEffectsOpInterface helpers
//===---------------------------------------------------------------------===//
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 6947158f624c0..2e12f9cabbddd 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -628,6 +628,10 @@ MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePatternCreate(
MLIR_CAPI_EXPORTED MlirRewritePatternSet
mlirRewritePatternSetCreate(MlirContext context);
+/// Get the context associated with a MlirRewritePatternSet -- presumably the
+/// context that was passed to mlirRewritePatternSetCreate.
+MLIR_CAPI_EXPORTED MlirContext
+mlirRewritePatternSetGetContext(MlirRewritePatternSet set);
+
/// Destruct the given MlirRewritePatternSet.
MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set);
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 157194f00e3c4..3b1b9abffb0f9 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -11,6 +11,7 @@
#include "Rewrite.h"
#include "mlir-c/Dialect/Transform.h"
#include "mlir-c/IR.h"
+#include "mlir-c/Rewrite.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/IRInterfaces.h"
@@ -246,6 +247,104 @@ class PyTransformOpInterface
}
};
+//===----------------------------------------------------------------------===//
+// PatternDescriptorOpInterface
+//===----------------------------------------------------------------------===//
+/// Python binding for PatternDescriptorOpInterface. Its `attach` classmethod
+/// implements the interface for a named operation by forwarding the interface
+/// methods to static methods of a Python class.
+class PyPatternDescriptorOpInterface
+ : public PyConcreteOpInterface<PyPatternDescriptorOpInterface> {
+public:
+ using PyConcreteOpInterface<
+ PyPatternDescriptorOpInterface>::PyConcreteOpInterface;
+
+ constexpr static const char *pyClassName = "PatternDescriptorOpInterface";
+ constexpr static GetTypeIDFunctionTy getInterfaceID =
+ &mlirPatternDescriptorOpInterfaceTypeID;
+
+ /// Attach a new PatternDescriptorOpInterface FallbackModel to the named
+ /// operation. The FallbackModel acts as a trampoline for callbacks on the
+ /// Python class.
+ ///
+ /// `target` must provide `populate_patterns(op, patterns)`. The with-state
+ /// hook is forwarded to `populate_patterns_with_state(op, patterns, state)`
+ /// only if `target` has that attribute at attach time; otherwise the C API
+ /// default is used. Both methods are looked up on `target` on every call
+ /// and receive an OpView of the operation. One reference to `target` is
+ /// held until the FallbackModel's destruct callback runs.
+ static void attach(nb::object &target, const std::string &opName,
+ DefaultingPyMlirContext ctx) {
+ // Prepare the callbacks that will be used by the FallbackModel.
+ MlirPatternDescriptorOpInterfaceCallbacks callbacks;
+ // Make the pointer to the Python class available to the callbacks.
+ callbacks.userData = target.ptr();
+ nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
+
+ // The above ref bump is all we need as initialization, no need to run the
+ // construct callback.
+ callbacks.construct = nullptr;
+ // Upon the FallbackModel's destruction, drop the ref to the Python class.
+ callbacks.destruct = [](void *userData) {
+ nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+ };
+
+ // The populatePatterns callback which calls into Python.
+ callbacks.populatePatterns =
+ [](MlirOperation op, MlirRewritePatternSet patterns, void *userData) {
+ nb::handle pyClass(static_cast<PyObject *>(userData));
+
+ auto pyPopulatePatterns =
+ nb::cast<nb::callable>(nb::getattr(pyClass, "populate_patterns"));
+
+ // Non-owning wrapper: the set belongs to the caller of this callback.
+ // NOTE(review): Python code should not retain `patterns` beyond the
+ // call, as the wrapped handle may then dangle.
+ auto pyPatterns = PyRewritePatternSet(patterns);
+
+ // Invoke `pyClass.populate_patterns(opview(op), patterns)` as a
+ // staticmethod.
+ MlirContext ctx = mlirOperationGetContext(op);
+ PyMlirContextRef context = PyMlirContext::forContext(ctx);
+ auto opview = PyOperation::forOperation(context, op)->createOpView();
+ pyPopulatePatterns(opview, pyPatterns);
+ };
+
+ // The populatePatternsWithState callback which calls into Python.
+ // Check if the Python class has populate_patterns_with_state method.
+ // This check happens once, here; later additions to `target` are ignored.
+ if (nb::hasattr(target, "populate_patterns_with_state")) {
+ callbacks.populatePatternsWithState = [](MlirOperation op,
+ MlirRewritePatternSet patterns,
+ MlirTransformState state,
+ void *userData) {
+ nb::handle pyClass(static_cast<PyObject *>(userData));
+
+ auto pyPopulatePatternsWithState = nb::cast<nb::callable>(
+ nb::getattr(pyClass, "populate_patterns_with_state"));
+
+ // Non-owning wrappers around handles provided by the caller.
+ auto pyPatterns = PyRewritePatternSet(patterns);
+ auto pyState = PyTransformState(state);
+
+ // Invoke `pyClass.populate_patterns_with_state(opview(op), patterns,
+ // state)` as a staticmethod.
+ MlirContext ctx = mlirOperationGetContext(op);
+ PyMlirContextRef context = PyMlirContext::forContext(ctx);
+ auto opview = PyOperation::forOperation(context, op)->createOpView();
+ pyPopulatePatternsWithState(opview, pyPatterns, pyState);
+ };
+ } else {
+ // Use default implementation (will call populatePatterns).
+ callbacks.populatePatternsWithState = nullptr;
+ }
+
+ // Attach a FallbackModel, which calls into Python, to the named operation.
+ mlirPatternDescriptorOpInterfaceAttachFallbackModel(
+ ctx->get(), mlirStringRefCreate(opName.c_str(), opName.size()),
+ callbacks);
+ }
+
+ /// Exposes `attach(op_name, *, target=None, context=None)` as a classmethod;
+ /// when `target` is omitted, the class `attach` is invoked on is used.
+ static void bindDerived(ClassTy &cls) {
+ cls.attr("attach") = classmethod(
+ [](const nb::object &cls, const nb::object &opName, nb::object target,
+ DefaultingPyMlirContext context) {
+ if (target.is_none())
+ target = cls;
+ return attach(target, nb::cast<std::string>(opName), context);
+ },
+ nb::arg("cls"), nb::arg("op_name"), nb::kw_only(),
+ nb::arg("target").none() = nb::none(),
+ nb::arg("context").none() = nb::none(),
+ "Attach the interface subclass to the given operation name.");
+ }
+};
+
//===-------------------------------------------------------------------===//
// AnyOpType
//===-------------------------------------------------------------------===//
@@ -444,6 +543,7 @@ static void populateDialectTransformSubmodule(nb::module_ &m) {
PyTransformResults::bind(m);
PyTransformState::bind(m);
PyTransformOpInterface::bind(m);
+ PyPatternDescriptorOpInterface::bind(m);
m.def("only_reads_handle", onlyReadsHandle,
"Mark operands as only reading handles.", nb::arg("operands"),
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 181df847f36fe..dc5fc0699702b 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -35,6 +35,75 @@ class PyPatternRewriter : public PyRewriterBase<PyPatternRewriter> {
: PyRewriterBase(mlirPatternRewriterAsBase(rewriter)) {}
};
+//===----------------------------------------------------------------------===//
+// PyRewritePatternSet
+//===----------------------------------------------------------------------===//
+
+/// Creates a fresh pattern set in `ctx`; this wrapper owns it.
+PyRewritePatternSet::PyRewritePatternSet(MlirContext ctx)
+ : patterns(mlirRewritePatternSetCreate(ctx)), owned(true) {}
+
+/// Wraps an existing pattern set without taking ownership; presumably the
+/// caller keeps it alive for as long as this wrapper is used.
+PyRewritePatternSet::PyRewritePatternSet(MlirRewritePatternSet patterns)
+ : patterns(patterns), owned(false) {}
+
+/// Destroys the underlying set only if this wrapper owns it. The null check
+/// presumably skips a set whose ownership was already handed off (e.g. when
+/// frozen) -- TODO confirm against `freeze`.
+PyRewritePatternSet::~PyRewritePatternSet() {
+ if (owned && patterns.ptr)
+ mlirRewritePatternSetDestroy(patterns);
+}
+
+/// Returns the wrapped C handle, whether owned or not.
+MlirRewritePatternSet PyRewritePatternSet::get() const { return patterns; }
+
+/// Whether this wrapper is responsible for destroying the underlying set.
+bool PyRewritePatternSet::isOwned() const { return owned; }
+
+/// Adds a rewrite pattern rooted at `root` whose match-and-rewrite logic is
+/// the Python callable `matchAndRewrite`, with the given `benefit`.
+///
+/// `root` is either a type (its `OPERATION_NAME` attribute is used) or an
+/// operation-name string; anything else raises `TypeError`. The callable is
+/// invoked as `matchAndRewrite(opview, rewriter)` and the match succeeds iff
+/// its result is falsy in the Python sense (e.g. `None`, `False`, `0`).
+/// The construct/destruct callbacks take and drop a reference on the callable.
+void PyRewritePatternSet::add(nb::handle root,
+                              const nb::callable &matchAndRewrite,
+                              unsigned benefit) {
+  std::string opName;
+  if (root.is_type()) {
+    opName = nb::cast<std::string>(root.attr("OPERATION_NAME"));
+  } else if (nb::isinstance<nb::str>(root)) {
+    opName = nb::cast<std::string>(root);
+  } else {
+    throw nb::type_error("the root argument must be a type or a string");
+  }
+
+  MlirRewritePatternCallbacks callbacks;
+  // The callable travels as `userData`; keep it alive alongside the pattern.
+  callbacks.construct = [](void *userData) {
+    nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+  };
+  callbacks.destruct = [](void *userData) {
+    nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+  };
+  callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
+                                 MlirPatternRewriter rewriter,
+                                 void *userData) -> MlirLogicalResult {
+    nb::handle f(static_cast<PyObject *>(userData));
+
+    PyMlirContextRef context =
+        PyMlirContext::forContext(mlirOperationGetContext(op));
+    nb::object opView = PyOperation::forOperation(context, op)->createOpView();
+
+    nb::object res = f(opView, PyPatternRewriter(rewriter));
+
+    // The match is considered successful iff `bool(res)` is `False`. Use
+    // Python truthiness: `nb::cast<bool>` only accepts exactly `True`/`False`
+    // and would raise a cast error for any other return value (e.g. `0`).
+    int truthy = PyObject_IsTrue(res.ptr());
+    if (truthy < 0)
+      throw nb::python_error(); // `__bool__` itself raised.
+    return truthy ? mlirLogicalResultFailure() : mlirLogicalResultSuccess();
+  };
+
+  MlirRewritePattern pattern = mlirOpRewritePatternCreate(
+      mlirStringRefCreate(opName.data(), opName.size()), benefit,
+      mlirRewritePatternSetGetContext(patterns), callbacks,
+      matchAndRewrite.ptr(),
+      /* nGeneratedNames */ 0,
+      /* generatedNames */ nullptr);
+  mlirRewritePatternSetAdd(patterns, pattern);
+}
+
+//===----------------------------------------------------------------------===//
+// PyConversionPatternRewriter
+//===----------------------------------------------------------------------===//
+
class PyConversionPatternRewriter : public PyPatternRewriter {
public:
PyConversionPatternRewriter(MlirConversionPatternRewriter rewriter)
@@ -132,6 +201,59 @@ class PyConversionPattern {
MlirConversionPattern pattern;
};
+/// Adds a conversion pattern rooted at `root` (a type exposing
+/// `OPERATION_NAME`, or an operation-name string; otherwise `TypeError`).
+/// The callable is invoked as `matchAndRewrite(opview, adaptor, converter,
+/// rewriter)` and its result is turned into a match result by
+/// `logicalResultFromObject`.
+void PyRewritePatternSet::addConversion(nb::handle root, unsigned benefit,
+ const nb::callable &matchAndRewrite,
+ PyTypeConverter &typeConverter) {
+ std::string opName;
+ if (root.is_type()) {
+ opName = nb::cast<std::string>(root.attr("OPERATION_NAME"));
+ } else if (nb::isinstance<nb::str>(root)) {
+ opName = nb::cast<std::string>(root);
+ } else {
+ throw nb::type_error("the root argument must be a type or a string");
+ }
+ // Borrows `opName`'s storage; only used before this function returns.
+ MlirStringRef rootName = mlirStringRefCreate(opName.data(), opName.size());
+
+ MlirConversionPatternCallbacks callbacks;
+ // The callable travels as `userData`; keep it alive alongside the pattern.
+ callbacks.construct = [](void *userData) {
+ nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+ };
+ callbacks.destruct = [](void *userData) {
+ nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+ };
+ callbacks.matchAndRewrite =
+ [](MlirConversionPattern pattern, MlirOperation op, intptr_t nOperands,
+ MlirValue *operands, MlirConversionPatternRewriter rewriter,
+ void *userData) -> MlirLogicalResult {
+ nb::handle f(static_cast<PyObject *>(userData));
+
+ PyMlirContextRef ctx =
+ PyMlirContext::forContext(mlirOperationGetContext(op));
+ nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
+
+ std::vector<MlirValue> operandsVec(operands, operands + nOperands);
+ // Use the op-specific adaptor class registered for this op name, falling
+ // back to the generic PyOpAdaptor when none is registered.
+ nb::object adaptorCls =
+ PyGlobals::get()
+ .lookupOpAdaptorClass([&] {
+ MlirStringRef ref = mlirIdentifierStr(mlirOperationGetName(op));
+ return std::string_view(ref.data, ref.length);
+ }())
+ .value_or(nb::borrow(nb::type<PyOpAdaptor>()));
+
+ nb::object res = f(opView, adaptorCls(operandsVec, opView),
+ PyConversionPattern(pattern).getTypeConverter(),
+ PyConversionPatternRewriter(rewriter));
+ return logicalResultFromObject(res);
+ };
+ MlirConversionPattern pattern = mlirOpConversionPatternCreate(
+ rootName, benefit, mlirRewritePatternSetGetContext(patterns),
+ typeConverter.get(), callbacks, matchAndRewrite.ptr(),
+ /* nGeneratedNames */ 0,
+ /* generatedNames */ nullptr);
+ mlirRewritePatternSetAdd(patterns,
+ mlirConversionPatternAsRewritePattern(pattern));
+}
+
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
struct PyMlirPDLResultList : MlirPDLResultList {};
@@ -249,96 +371,60 @@ class PyFrozenRewritePatternSet {
MlirFrozenRewritePatternSet set;
};
-class PyRewritePatternSet {
-public:
- PyRewritePatternSet(MlirContext ctx)
- : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
- ~PyRewritePatternSet() {
- if (set.ptr)
- mlirRewritePatternSetDestroy(set);
- }
-
- void add(MlirStringRef rootName, unsigned benefit,
- const nb::callable &matchAndRewrite) {
- MlirRewritePatternCallbacks callbacks;
- callbacks.construct = [](void *userData) {
- nb::handle(static_cast<PyObject *>(userData)).inc_ref();
- };
- callbacks.destruct = [](void *userData) {
- nb::handle(static_cast<PyObject *>(userData)).dec_ref();
- };
- callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
- MlirPatternRewriter rewriter,
- void *userData) -> MlirLogicalResult {
- nb::handle f(static_cast<PyObject *>(userData));
-
- PyMlirContextRef ctx =
- PyMlirContext::forContext(mlirOperationGetContext(op));
- nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
-
- nb::object res = f(opView, PyPatternRewriter(rewriter));
- return logicalResultFromObject(res);
- };
- MlirRewritePattern pattern = mlirOpRewritePatternCreate(
- rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
- /* nGeneratedNames */ 0,
- /* generatedNames */ nullptr);
- mlirRewritePatternSetAdd(set, pattern);
- }
-
- void addConversion(MlirStringRef rootName, unsigned benefit,
- const nb::callable &matchAndRewrite,
- PyTypeConverter &typeConverter) {
- MlirConversionPatternCallbacks callbacks;
- callbacks.construct = [](void *userData) {
- nb::handle(static_cast<PyObject *>(userData)).inc_ref();
- };
- callbacks.destruct = [](void *userData) {
- nb::handle(static_cast<PyObject *>(userData)).dec_ref();
- };
- callbacks.matchAndRewrite =
- [](MlirConversionPattern pattern, MlirOperation op, intptr_t nOperands,
- MlirValue *operands, MlirConversionPatternRewriter rewriter,
- void *userData) -> MlirLogicalResult {
- nb::handle f(static_cast<PyObject *>(userData));
-
- PyMlirContextRef ctx =
- PyMlirContext::forContext(mlirOperationGetContext(op));
- nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
-
- std::vector<MlirValue> operandsVec(operands, operands + nOperands);
- nb::object adaptorCls =
- PyGlobals::get()
- .lookupOpAdaptorClass([&] {
- MlirStringRef ref = mlirIdentifierStr(mlirOperationGetName(op));
- return std::string_view(ref.data, ref.length);
- }())
- .value_or(nb::borrow(nb::type<PyOpAdaptor>()));
-
- nb::object res = f(opView, adaptorCls(operandsVec, opView),
- PyConversionPattern(pattern).getTypeConverter(),
- PyConversionPatternRewriter(rewriter));
- return logicalResultFromObject(res);
- };
- MlirConversionPattern pattern = mlirOpConversionPatternCreate(
- rootName, benefit, ctx, typeConverter.get(), callbacks,
- matchAndRewrite.ptr(),
- /* nGeneratedNames */ 0,
- /* generatedNames */ nullptr);
- mlirRewritePatternSetAdd(set,
- mlirConversionPatternAsRewritePattern(pattern));
- }
-
- PyFrozenRewritePatternSet freeze() {
- MlirRewritePatternSet s = set;
- set.ptr = nullptr;
- return mlirFreezeRewritePattern(s);
- }
+void PyRewritePatternSet::bind(nb::module_ &m) {
+ nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
+ .def(
+ "__init__",
+ [](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
+ new (&self) PyRewritePatternSet(context.get()->get());
+ },
+ "context"_a = nb::none())
+ .def(
+ "add", &PyRewritePatternSet::add, nb::arg("root"), nb::arg("fn"),
+ nb::arg("benefit") = 1,
+ R"(Add a new rewrite pattern on the specified root operation, using
+ the provided callable for matching and rewriting, and assign it the
+ given benefit.
+
+ Args:
+ root: The root operation to which this pattern applies.
+ This may be either an OpView subclass or an operation name.
+ fn: The callable to use for matching and rewriting, which takes
+ an operation and a pattern rewriter. The match is considered
+ successful iff the callable returns a falsy value.
+ benefit: The benefit of the pattern, defaulting to 1.)")
+ .def(
+ "add_conversion",
+ [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
+ PyTypeConverter &typeConverter, unsigned benefit) {
+ self.addConversion(root, benefit, fn, typeConverter);
+ },
+ "root"_a, "fn"_a, "type_converter"_a, "benefit"_a = 1,
+ R"(
+ Add a new conversion pattern on the specified root operation,
+ using the provided callable for matching and rewriting,
+ and assign it the given benefit.
-private:
- MlirRewritePatternSet set;
- MlirContext ctx;
-};
+ Args:
+ root: The root operation to which this pattern applies.
+ This may be either an OpView subclass or an operation name.
+ fn: The callable to use for matching and rewriting, which takes an
+ operation, its adaptor, the type converter and a pattern
+ rewriter. The match is considered successful iff the callable
+ returns a falsy value.
+ type_converter: The type converter to convert types in the IR.
+ benefit: The benefit of the pattern, defaulting to 1.)")
+ .def(
+ "freeze",
+ [](PyRewritePatternSet &self) {
+ if (!self.isOwned())
+ thr...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/184331
More information about the Mlir-commits mailing list