[Mlir-commits] [mlir] [MLIR][Python] Impl XOpInterface(s) from Python, with X=Transform and X=MemoryEffects (PR #176920)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 3 01:01:28 PST 2026
================
@@ -22,6 +24,219 @@ namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
namespace transform {
+
+//===----------------------------------------------------------------------===//
+// TransformRewriter
+//===----------------------------------------------------------------------===//
+/// Python-visible rewriter wrapping an MlirTransformRewriter. All rewriting
+/// functionality comes from PyRewriterBase; this class only converts the
+/// transform-specific handle into the generic base rewriter handle.
+class PyTransformRewriter : public PyRewriterBase<PyTransformRewriter> {
+public:
+  constexpr static const char *pyClassName = "TransformRewriter";
+
+  PyTransformRewriter(MlirTransformRewriter transformRewriter)
+      : PyRewriterBase(mlirTransformRewriterAsBase(transformRewriter)) {}
+};
+
+//===----------------------------------------------------------------------===//
+// TransformResults
+//===----------------------------------------------------------------------===//
+/// Python-visible wrapper around MlirTransformResults, used to record the
+/// payload (operations, values, or parameters) associated with each result of
+/// a transform operation.
+class PyTransformResults {
+public:
+  PyTransformResults(MlirTransformResults results) : results(results) {}
+
+  MlirTransformResults get() const { return results; }
+
+  /// Set the payload operations for transform result `result`. Every element
+  /// of `ops` must be convertible to MlirOperation.
+  void setOps(MlirValue result, const nb::list &ops) {
+    std::vector<MlirOperation> opsVec = castAll<MlirOperation>(ops);
+    mlirTransformResultsSetOps(results, result, opsVec.size(), opsVec.data());
+  }
+
+  /// Set the payload values for transform result `result`. Every element of
+  /// `values` must be convertible to MlirValue.
+  void setValues(MlirValue result, const nb::list &values) {
+    std::vector<MlirValue> valuesVec = castAll<MlirValue>(values);
+    mlirTransformResultsSetValues(results, result, valuesVec.size(),
+                                  valuesVec.data());
+  }
+
+  /// Set the parameters for transform result `result`. Every element of
+  /// `params` must be convertible to MlirAttribute.
+  void setParams(MlirValue result, const nb::list &params) {
+    std::vector<MlirAttribute> paramsVec = castAll<MlirAttribute>(params);
+    mlirTransformResultsSetParams(results, result, paramsVec.size(),
+                                  paramsVec.data());
+  }
+
+  static void bind(nb::module_ &m) {
+    nb::class_<PyTransformResults>(m, "TransformResults")
+        .def(nb::init<MlirTransformResults>())
+        .def("set_ops", &PyTransformResults::setOps,
+             "Set the payload operations for a transform result.",
+             nb::arg("result"), nb::arg("ops"))
+        .def("set_values", &PyTransformResults::setValues,
+             "Set the payload values for a transform result.",
+             nb::arg("result"), nb::arg("values"))
+        .def("set_params", &PyTransformResults::setParams,
+             "Set the parameters for a transform result.", nb::arg("result"),
+             nb::arg("params"));
+  }
+
+private:
+  /// Convert each element of `items` to `T`, preserving order. The whole list
+  /// is converted before the caller hands it to the C API, so a conversion
+  /// failure from nb::cast propagates without any partial update.
+  template <typename T>
+  static std::vector<T> castAll(const nb::list &items) {
+    std::vector<T> converted;
+    converted.reserve(items.size());
+    for (auto item : items)
+      converted.push_back(nb::cast<T>(item));
+    return converted;
+  }
+
+  MlirTransformResults results;
+};
+
+//===----------------------------------------------------------------------===//
+// TransformState
+//===----------------------------------------------------------------------===//
+/// Python-visible wrapper around MlirTransformState, giving read access to
+/// the payload operations, payload values, and parameters associated with
+/// transform IR values.
+class PyTransformState {
+public:
+  PyTransformState(MlirTransformState state) : state(state) {}
+
+  MlirTransformState get() const { return state; }
+
+  /// Return the payload operations associated with the transform IR value
+  /// `value`, each wrapped in its Python OpView.
+  nb::list getPayloadOps(MlirValue value) const {
+    ListCollector collector;
+    mlirTransformStateForEachPayloadOp(
+        state, value,
+        [](MlirOperation op, void *userData) {
+          auto *sink = static_cast<ListCollector *>(userData);
+          sink->guarded([&] {
+            PyMlirContextRef context =
+                PyMlirContext::forContext(mlirOperationGetContext(op));
+            sink->items.append(
+                PyOperation::forOperation(context, op)->createOpView());
+          });
+        },
+        &collector);
+    return collector.finish();
+  }
+
+  /// Return the payload values associated with the transform IR value
+  /// `value`.
+  nb::list getPayloadValues(MlirValue value) const {
+    ListCollector collector;
+    mlirTransformStateForEachPayloadValue(
+        state, value,
+        [](MlirValue val, void *userData) {
+          auto *sink = static_cast<ListCollector *>(userData);
+          sink->guarded([&] { sink->items.append(val); });
+        },
+        &collector);
+    return collector.finish();
+  }
+
+  /// Return the parameters (attributes) associated with the transform IR
+  /// value `value`.
+  nb::list getParams(MlirValue value) const {
+    ListCollector collector;
+    mlirTransformStateForEachParam(
+        state, value,
+        [](MlirAttribute attr, void *userData) {
+          auto *sink = static_cast<ListCollector *>(userData);
+          sink->guarded([&] { sink->items.append(attr); });
+        },
+        &collector);
+    return collector.finish();
+  }
+
+  static void bind(nb::module_ &m) {
+    nb::class_<PyTransformState>(m, "TransformState")
+        .def(nb::init<MlirTransformState>())
+        .def("get_payload_ops", &PyTransformState::getPayloadOps,
+             "Get the payload operations associated with a transform IR value.",
+             nb::arg("operand"))
+        .def("get_payload_values", &PyTransformState::getPayloadValues,
+             "Get the payload values associated with a transform IR value.",
+             nb::arg("operand"))
+        .def("get_params", &PyTransformState::getParams,
+             "Get the parameters (attributes) associated with a transform IR "
+             "value.",
+             nb::arg("operand"));
+  }
+
+private:
+  /// Collects the elements produced by a C API "for each" callback into a
+  /// Python list.
+  ///
+  /// C++ exceptions should not escape such callbacks, because they would
+  /// unwind through C API / MLIR library frames, which are typically built
+  /// without exception support. Instead:
+  /// - a Python error raised while producing an element is moved back into
+  ///   the interpreter's error indicator;
+  /// - the remaining elements are skipped;
+  /// - `finish()` re-raises the error after the iteration has returned.
+  ///
+  /// NOTE(review): only nb::python_error is contained here; any other C++
+  /// exception still propagates out of the callback.
+  struct ListCollector {
+    nb::list items;
+    bool failed = false;
+
+    /// Run `produce` unless an earlier element has already failed.
+    template <typename Fn> void guarded(Fn &&produce) {
+      if (failed)
+        return;
+      try {
+        produce();
+      } catch (nb::python_error &e) {
+        e.restore();
+        failed = true;
+      }
+    }
+
+    /// Return the collected list, or re-raise the stashed Python error.
+    nb::list finish() const {
+      if (failed)
+        throw nb::python_error();
+      return items;
+    }
+  };
+
+  MlirTransformState state;
+};
+
+//===----------------------------------------------------------------------===//
+// TransformOpInterface
+//===----------------------------------------------------------------------===//
+class PyTransformOpInterface
+ : public PyConcreteOpInterface<PyTransformOpInterface> {
+public:
+ using PyConcreteOpInterface<PyTransformOpInterface>::PyConcreteOpInterface;
+
+ constexpr static const char *pyClassName = "TransformOpInterface";
+ constexpr static GetTypeIDFunctionTy getInterfaceID =
+ &mlirTransformOpInterfaceTypeID;
+
+ /// Attach a new TransformOpInterface FallbackModel to the named operation.
+ /// The FallbackModel acts as a trampoline for callbacks on the Python class.
+  /// Attach a new TransformOpInterface FallbackModel to the named operation.
+  ///
+  /// The FallbackModel acts as a trampoline. On every interface call it looks
+  /// up the corresponding attribute on `pyClass` and invokes it as a
+  /// staticmethod:
+  /// - `pyClass.apply(op, rewriter, results, state)`
+  /// - `pyClass.allow_repeated_handle_operands(op)`
+  ///
+  /// \param pyClass Python class implementing the interface. A strong
+  ///                reference to it is held until the FallbackModel is
+  ///                destroyed.
+  /// \param opName  Name of the operation the model is attached to.
+  /// \param ctx     Context in which to attach the model.
+  ///
+  /// NOTE(review): the callbacks below call into Python, including the
+  /// dec_ref in `destruct`. They appear to assume the GIL is held by whichever
+  /// thread invokes them. Also, a Python exception raised by `apply` unwinds
+  /// through the C API as a C++ exception. Confirm both are acceptable for
+  /// the interpreter's call sites.
+  static void attach(nb::object &pyClass, const std::string &opName,
+                     DefaultingPyMlirContext ctx) {
+    // Value-initialize so any callback slot not assigned below is null rather
+    // than indeterminate.
+    MlirTransformOpInterfaceCallbacks callbacks{};
+    // Make the Python class available to the callbacks and keep it alive for
+    // as long as the FallbackModel exists; the reference is released in
+    // `destruct`.
+    callbacks.userData = pyClass.ptr();
+    pyClass.inc_ref();
+
+    // The above ref bump is all the initialization needed, so no construct
+    // callback is registered.
+    callbacks.construct = nullptr;
+    // Upon the FallbackModel's destruction, drop the ref to the Python class.
+    callbacks.destruct = [](void *userData) {
+      nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+    };
+
+    // The apply callback, which calls into Python.
+    callbacks.apply = [](MlirOperation op, MlirTransformRewriter rewriter,
+                         MlirTransformResults results, MlirTransformState state,
+                         void *userData) -> MlirDiagnosedSilenceableFailure {
+      nb::handle pyClass(static_cast<PyObject *>(userData));
+
+      auto pyApply = nb::cast<nb::callable>(nb::getattr(pyClass, "apply"));
+
+      auto pyRewriter = PyTransformRewriter(rewriter);
+      auto pyResults = PyTransformResults(results);
+      auto pyState = PyTransformState(state);
+
+      // Invoke `pyClass.apply(opview(op), rewriter, results, state)` as a
+      // staticmethod.
+      PyMlirContextRef context =
+          PyMlirContext::forContext(mlirOperationGetContext(op));
+      auto opview = PyOperation::forOperation(context, op)->createOpView();
+      nb::object res = pyApply(opview, pyRewriter, pyResults, pyState);
+
+      return nb::cast<MlirDiagnosedSilenceableFailure>(res);
+    };
+
+    // The allowsRepeatedHandleOperands callback, which calls the Python
+    // attribute `allow_repeated_handle_operands`.
+    callbacks.allowsRepeatedHandleOperands = [](MlirOperation op,
+                                                void *userData) -> bool {
+      nb::handle pyClass(static_cast<PyObject *>(userData));
+
+      auto pyAllowRepeatedHandleOperands = nb::cast<nb::callable>(
+          nb::getattr(pyClass, "allow_repeated_handle_operands"));
+
+      // Invoke `pyClass.allow_repeated_handle_operands(opview(op))` as a
+      // staticmethod.
+      PyMlirContextRef context =
+          PyMlirContext::forContext(mlirOperationGetContext(op));
+      auto opview = PyOperation::forOperation(context, op)->createOpView();
+      nb::object res = pyAllowRepeatedHandleOperands(opview);
+
+      return nb::cast<bool>(res);
+    };
+
+    // Attach a FallbackModel, which calls into Python, to the named operation.
+    // StringRef is built from the std::string directly, using its stored
+    // length instead of rescanning a NUL-terminated copy.
+    mlirTransformOpInterfaceAttachFallbackModel(
+        ctx->get(), wrap(StringRef(opName)), callbacks);
+  }
+
+  /// Install `attach(cls, op_name, ctx=None)` on the bound Python class, with
+  /// `attach` wrapped by the file's `classmethod` helper.
+  static void bindDerived(ClassTy &transformOpInterfaceClass) {
+    auto attachMethod =
+        classmethod(&PyTransformOpInterface::attach, nb::arg("cls"),
+                    nb::arg("op_name"), nb::arg("ctx") = nb::none());
+    transformOpInterfaceClass.attr("attach") = attachMethod;
+  }
----------------
PragmaTwice wrote:
I'm wondering what the difference is between using class objects here and using normal objects, i.e. this approach vs.:
```python
class XXXInterface:
    def allow_repeated_handle_operands(self, op):
        pass
XXXInterface(...).attach(XXXOp)
```
https://github.com/llvm/llvm-project/pull/176920
More information about the Mlir-commits
mailing list