[Mlir-commits] [mlir] [MLIR][Python] Impl XOpInterface(s) from Python, with X=Transform and X=MemoryEffects (PR #176920)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 3 01:01:28 PST 2026


================
@@ -22,6 +24,219 @@ namespace mlir {
 namespace python {
 namespace MLIR_BINDINGS_PYTHON_DOMAIN {
 namespace transform {
+
+//===----------------------------------------------------------------------===//
+// TransformRewriter
+//===----------------------------------------------------------------------===//
+/// Python-facing rewriter handed to transform ops implemented in Python. It
+/// reuses the generic rewriter bindings of PyRewriterBase by converting the
+/// transform rewriter handle into the base rewriter handle.
+class PyTransformRewriter : public PyRewriterBase<PyTransformRewriter> {
+public:
+  /// Wrap the C-API transform rewriter `transformRewriter`, upcast through
+  /// mlirTransformRewriterAsBase.
+  PyTransformRewriter(MlirTransformRewriter transformRewriter)
+      : PyRewriterBase<PyTransformRewriter>(
+            mlirTransformRewriterAsBase(transformRewriter)) {}
+
+  /// Class name used for the Python binding (presumably read by
+  /// PyRewriterBase when the class is registered).
+  static constexpr const char *pyClassName = "TransformRewriter";
+};
+
+//===----------------------------------------------------------------------===//
+// TransformResults
+//===----------------------------------------------------------------------===//
+/// Python wrapper around the C-API `MlirTransformResults` handle, which a
+/// transform op's `apply` uses to record what each of its results maps to.
+class PyTransformResults {
+public:
+  PyTransformResults(MlirTransformResults results) : results(results) {}
+
+  /// Return the wrapped C-API handle.
+  MlirTransformResults get() const { return results; }
+
+  /// Map the transform result `result` to the payload operations in `ops`.
+  void setOps(MlirValue result, const nanobind::list &ops) {
+    std::vector<MlirOperation> payload = castElements<MlirOperation>(ops);
+    mlirTransformResultsSetOps(results, result, payload.size(),
+                               payload.data());
+  }
+
+  /// Map the transform result `result` to the payload values in `values`.
+  void setValues(MlirValue result, const nanobind::list &values) {
+    std::vector<MlirValue> payload = castElements<MlirValue>(values);
+    mlirTransformResultsSetValues(results, result, payload.size(),
+                                  payload.data());
+  }
+
+  /// Map the transform result `result` to the attributes in `params`.
+  void setParams(MlirValue result, const nanobind::list &params) {
+    std::vector<MlirAttribute> payload = castElements<MlirAttribute>(params);
+    mlirTransformResultsSetParams(results, result, payload.size(),
+                                  payload.data());
+  }
+
+  /// Register the `TransformResults` Python class on module `m`.
+  static void bind(nanobind::module_ &m) {
+    auto cls = nb::class_<PyTransformResults>(m, "TransformResults");
+    cls.def(nb::init<MlirTransformResults>());
+    cls.def("set_ops", &PyTransformResults::setOps,
+            "Set the payload operations for a transform result.",
+            nb::arg("result"), nb::arg("ops"));
+    cls.def("set_values", &PyTransformResults::setValues,
+            "Set the payload values for a transform result.",
+            nb::arg("result"), nb::arg("values"));
+    cls.def("set_params", &PyTransformResults::setParams,
+            "Set the parameters for a transform result.", nb::arg("result"),
+            nb::arg("params"));
+  }
+
+private:
+  /// Convert every element of `items` to the C-API type `T` with nb::cast.
+  /// A failed conversion propagates the exception thrown by nb::cast.
+  template <typename T>
+  static std::vector<T> castElements(const nanobind::list &items) {
+    std::vector<T> converted;
+    converted.reserve(items.size());
+    for (auto element : items)
+      converted.push_back(nb::cast<T>(element));
+    return converted;
+  }
+
+  MlirTransformResults results;
+};
+
+//===----------------------------------------------------------------------===//
+// TransformState
+//===----------------------------------------------------------------------===//
+/// Python wrapper around the C-API `MlirTransformState` handle. It gives
+/// Python code read access to the payload ops, payload values and parameters
+/// that transform IR values are mapped to.
+class PyTransformState {
+public:
+  PyTransformState(MlirTransformState state) : state(state) {}
+
+  /// Return the wrapped C-API handle.
+  MlirTransformState get() const { return state; }
+
+  /// Return the payload operations associated with the transform IR value
+  /// `value`, each wrapped as an OpView.
+  nanobind::list getPayloadOps(MlirValue value) {
+    ListCollector collector;
+    mlirTransformStateForEachPayloadOp(
+        state, value,
+        [](MlirOperation op, void *userData) {
+          ListCollector::guardedAppend(userData, [op](nanobind::list &items) {
+            PyMlirContextRef context =
+                PyMlirContext::forContext(mlirOperationGetContext(op));
+            auto opview =
+                PyOperation::forOperation(context, op)->createOpView();
+            items.append(opview);
+          });
+        },
+        &collector);
+    return collector.finish();
+  }
+
+  /// Return the payload values associated with the transform IR value
+  /// `value`.
+  nanobind::list getPayloadValues(MlirValue value) {
+    ListCollector collector;
+    mlirTransformStateForEachPayloadValue(
+        state, value,
+        [](MlirValue val, void *userData) {
+          ListCollector::guardedAppend(
+              userData, [val](nanobind::list &items) { items.append(val); });
+        },
+        &collector);
+    return collector.finish();
+  }
+
+  /// Return the parameter attributes associated with the transform IR value
+  /// `value`.
+  nanobind::list getParams(MlirValue value) {
+    ListCollector collector;
+    mlirTransformStateForEachParam(
+        state, value,
+        [](MlirAttribute attr, void *userData) {
+          ListCollector::guardedAppend(
+              userData, [attr](nanobind::list &items) { items.append(attr); });
+        },
+        &collector);
+    return collector.finish();
+  }
+
+  /// Register the `TransformState` Python class on module `m`.
+  static void bind(nanobind::module_ &m) {
+    nb::class_<PyTransformState>(m, "TransformState")
+        .def(nb::init<MlirTransformState>())
+        .def("get_payload_ops", &PyTransformState::getPayloadOps,
+             "Get the payload operations associated with a transform IR value.",
+             nb::arg("operand"))
+        .def("get_payload_values", &PyTransformState::getPayloadValues,
+             "Get the payload values associated with a transform IR value.",
+             nb::arg("operand"))
+        .def("get_params", &PyTransformState::getParams,
+             "Get the parameters (attributes) associated with a transform IR "
+             "value.",
+             nb::arg("operand"));
+  }
+
+private:
+  /// Accumulates the elements reported by a C-API `ForEach*` iteration.
+  ///
+  /// The iteration callbacks are plain function pointers invoked from inside
+  /// the C API. A C++ exception thrown while building a Python object must
+  /// therefore not unwind through those frames. Instead, a raised Python error
+  /// is moved into the Python error indicator, the remaining elements are
+  /// skipped, and finish() re-raises the error once control is back in the
+  /// bindings.
+  struct ListCollector {
+    nanobind::list items;
+    bool failed = false;
+
+    /// Run `appendFn(items)` on the collector behind `userData`, unless an
+    /// earlier element has already failed.
+    template <typename AppendFn>
+    static void guardedAppend(void *userData, AppendFn &&appendFn) {
+      auto *self = static_cast<ListCollector *>(userData);
+      if (self->failed)
+        return;
+      try {
+        appendFn(self->items);
+      } catch (nb::python_error &e) {
+        // Park the exception in Python's error indicator; finish() turns it
+        // back into a C++ exception outside the C-API call.
+        e.restore();
+        self->failed = true;
+      }
+    }
+
+    /// Return the collected list, or re-raise the recorded Python error.
+    nanobind::list finish() {
+      if (failed)
+        throw nb::python_error();
+      return items;
+    }
+  };
+
+  MlirTransformState state;
+};
+
+//===----------------------------------------------------------------------===//
+// TransformOpInterface
+//===----------------------------------------------------------------------===//
+class PyTransformOpInterface
+    : public PyConcreteOpInterface<PyTransformOpInterface> {
+public:
+  using PyConcreteOpInterface<PyTransformOpInterface>::PyConcreteOpInterface;
+
+  constexpr static const char *pyClassName = "TransformOpInterface";
+  constexpr static GetTypeIDFunctionTy getInterfaceID =
+      &mlirTransformOpInterfaceTypeID;
+
+  /// Attach a new TransformOpInterface FallbackModel to the named operation.
+  /// The FallbackModel acts as a trampoline for callbacks on the Python class.
+  ///
+  /// \param pyClass Python class providing the interface methods `apply` and
+  ///        `allow_repeated_handle_operands`, both invoked as staticmethods.
+  ///        They are looked up by attribute name on every callback.
+  /// \param opName Name of the operation to attach the model to.
+  /// \param ctx Context in which to attach (resolved through
+  ///        DefaultingPyMlirContext, presumably the current context when
+  ///        omitted from Python).
+  static void attach(nb::object &pyClass, const std::string &opName,
+                     DefaultingPyMlirContext ctx) {
+    // Prepare the callbacks that will be used by the FallbackModel.
+    MlirTransformOpInterfaceCallbacks callbacks;
+    // Make the pointer to the Python class available to the callbacks.
+    callbacks.userData = pyClass.ptr();
+    // Keep the class object alive while the FallbackModel stores a raw
+    // pointer to it; the `destruct` callback below balances this incref.
+    nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
+
+    // The above ref bump is all we need as initialization, no need to run the
+    // construct callback.
+    callbacks.construct = nullptr;
+    // Upon the FallbackModel's destruction, drop the ref to the Python class.
+    callbacks.destruct = [](void *userData) {
+      nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+    };
+    // The apply callback which calls into Python.
+    // NOTE(review): if the Python `apply` raises, or returns something that
+    // cannot be cast to MlirDiagnosedSilenceableFailure, a C++ exception is
+    // thrown from inside a callback reached through the C API. Confirm that
+    // unwinding through those frames is acceptable, or convert the error into
+    // a failure result instead.
+    // NOTE(review): no GIL acquisition is visible in these callbacks. This
+    // presumably relies on the transform interpreter being driven from Python
+    // with the GIL held; confirm.
+    callbacks.apply = [](MlirOperation op, MlirTransformRewriter rewriter,
+                         MlirTransformResults results, MlirTransformState state,
+                         void *userData) -> MlirDiagnosedSilenceableFailure {
+      nb::handle pyClass(static_cast<PyObject *>(userData));
+
+      auto pyApply = nb::cast<nb::callable>(nb::getattr(pyClass, "apply"));
+
+      // Wrap the C-API handles so they can be passed to Python.
+      // NOTE(review): these wrappers presumably refer to state that is only
+      // valid during this call, so Python code should not retain them;
+      // confirm.
+      auto pyRewriter = PyTransformRewriter(rewriter);
+      auto pyResults = PyTransformResults(results);
+      auto pyState = PyTransformState(state);
+
+      // Invoke `pyClass.apply(opview(op), rewriter, results, state)` as a
+      // staticmethod.
+      PyMlirContextRef context =
+          PyMlirContext::forContext(mlirOperationGetContext(op));
+      auto opview = PyOperation::forOperation(context, op)->createOpView();
+      nb::object res = pyApply(opview, pyRewriter, pyResults, pyState);
+
+      return nb::cast<MlirDiagnosedSilenceableFailure>(res);
+    };
+
+    // The allowsRepeatedHandleOperands callback, which calls into Python.
+    // Note that the Python-side method is named
+    // `allow_repeated_handle_operands` (without the trailing "s" on "allow").
+    callbacks.allowsRepeatedHandleOperands = [](MlirOperation op,
+                                                void *userData) -> bool {
+      nb::handle pyClass(static_cast<PyObject *>(userData));
+
+      auto pyAllowRepeatedHandleOperands = nb::cast<nb::callable>(
+          nb::getattr(pyClass, "allow_repeated_handle_operands"));
+
+      // Invoke `pyClass.allow_repeated_handle_operands(opview(op))` as a
+      // staticmethod.
+      PyMlirContextRef context =
+          PyMlirContext::forContext(mlirOperationGetContext(op));
+      auto opview = PyOperation::forOperation(context, op)->createOpView();
+      nb::object res = pyAllowRepeatedHandleOperands(opview);
+
+      return nb::cast<bool>(res);
+    };
+
+    // Attach a FallbackModel, which calls into Python, to the named operation.
+    mlirTransformOpInterfaceAttachFallbackModel(
+        ctx->get(), wrap(StringRef(opName.c_str())), callbacks);
+  }
+
+  /// Register `attach` on the Python class through the `classmethod` helper,
+  /// taking `cls`, `op_name` and an optional `ctx` (default `None`).
+  static void bindDerived(ClassTy &transformOpInterfaceClass) {
+    auto attachMethod =
+        classmethod(&PyTransformOpInterface::attach, nb::arg("cls"),
+                    nb::arg("op_name"), nb::arg("ctx") = nb::none());
+    transformOpInterfaceClass.attr("attach") = attachMethod;
+  }
----------------
PragmaTwice wrote:

I'm wondering what the difference is between using normal objects and class objects here, vs.:
```python
class XXXInterface:
  def allow_repeated_handle_operands(self, op):
    pass

XXXInterface(...).attach(XXXOp)
```

https://github.com/llvm/llvm-project/pull/176920


More information about the Mlir-commits mailing list