[Mlir-commits] [mlir] [mlir:python] Compute get_op_result_or_value in PyOpView's constructor. (PR #123953)
Peter Hawkins
llvmlistbot at llvm.org
Thu Jan 23 13:17:43 PST 2025
================
@@ -1629,6 +1617,143 @@ void PyOperation::erase() {
mlirOperationDestroy(operation);
}
+namespace {
+/// CRTP base for Python wrappers of a concrete kind of MLIR value (for
+/// example OpResult). Each derived wrapper is bound as a Python subclass of
+/// Value and can be constructed from a generic Value through a checked cast.
+/// The hierarchy is deliberately a single level deep; deeper nesting is not
+/// supported unless core MLIR changes.
+template <typename DerivedTy>
+class PyConcreteValue : public PyValue {
+public:
+  // Every derived class has to provide these statics:
+  //   IsAFunctionTy isaFunction
+  //   const char *pyClassName
+  // and shadow bindDerived to register its own Python members.
+  using ClassTy = nb::class_<DerivedTy, PyValue>;
+  using IsAFunctionTy = bool (*)(MlirValue);
+
+  PyConcreteValue() = default;
+  PyConcreteValue(PyOperationRef operationRef, MlirValue value)
+      : PyValue(operationRef, value) {}
+  /// Converting constructor: validates `orig` with castFrom (which throws on
+  /// a kind mismatch) and reuses its parent operation.
+  PyConcreteValue(PyValue &orig)
+      : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
+
+  /// Returns the MlirValue held by `orig` when DerivedTy::isaFunction accepts
+  /// it. Otherwise throws nb::value_error with a message naming the target
+  /// Python class and the repr() of the source value.
+  static MlirValue castFrom(PyValue &orig) {
+    MlirValue raw = orig.get();
+    if (DerivedTy::isaFunction(raw))
+      return raw;
+
+    // Slow path: only build the repr when the cast is rejected.
+    std::string sourceRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
+    std::string message = (Twine("Cannot cast value to ") +
+                           DerivedTy::pyClassName + " (from " + sourceRepr +
+                           ")")
+                              .str();
+    throw nb::value_error(message.c_str());
+  }
+
+  /// Registers the Python class for DerivedTy in module `m`.
+  ///
+  /// The class gets a constructor taking a Value (`value`), a static
+  /// `isinstance(other_value)` check and the maybe-downcast hook. DerivedTy
+  /// then adds its own members via bindDerived.
+  static void bind(nb::module_ &m) {
+    ClassTy cls(m, DerivedTy::pyClassName);
+    // NOTE(review): nanobind numbers the return value 0 and `self` 1; confirm
+    // that keep_alive<0, 1> on an __init__ binding has the intended effect.
+    cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
+
+    auto isInstance = [](PyValue &otherValue) -> bool {
+      return DerivedTy::isaFunction(otherValue);
+    };
+    cls.def_static("isinstance", isInstance, nb::arg("other_value"));
+
+    auto maybeDownCast = [](DerivedTy &self) { return self.maybeDownCast(); };
+    cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, maybeDownCast);
+
+    DerivedTy::bindDerived(cls);
+  }
+
+  /// Hook that derived classes shadow to add Python members; the base version
+  /// registers nothing.
+  static void bindDerived(ClassTy &m) {}
+};
+
+} // namespace
+
+/// Python wrapper for MlirOpResult, i.e. a Value that is a result of an
+/// operation. Exposed to Python as `OpResult`.
+class PyOpResult : public PyConcreteValue<PyOpResult> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
+  static constexpr const char *pyClassName = "OpResult";
+  using PyConcreteValue::PyConcreteValue;
+
+  /// Adds the read-only `owner` and `result_number` properties.
+  static void bindDerived(ClassTy &c) {
+    // `owner` hands back the Python object of the parent operation. Debug
+    // builds also assert that it agrees with the IR-level owner.
+    auto getOwner = [](PyOpResult &self) {
+      assert(
+          mlirOperationEqual(self.getParentOperation()->get(),
+                             mlirOpResultGetOwner(self.get())) &&
+          "expected the owner of the value in Python to match that in the IR");
+      return self.getParentOperation().getObject();
+    };
+    // `result_number` is the position of this result in its owner's result
+    // list.
+    auto getResultNumber = [](PyOpResult &self) {
+      return mlirOpResultGetResultNumber(self.get());
+    };
+
+    c.def_prop_ro("owner", getOwner);
+    c.def_prop_ro("result_number", getResultNumber);
+  }
+};
+
+/// Returns the MLIR types of the values held by `container`, in element order.
+///
+/// `Container` must provide `size()` and `getElement(i)`, and each element
+/// must provide `get()` yielding an MlirValue.
+/// `context` is currently unused; it is kept so existing callers still compile.
+template <typename Container>
+static std::vector<MlirType> getValueTypes(Container &container,
+                                           PyMlirContextRef &context) {
+  (void)context;
+  // Index with the container's own size type (e.g. intptr_t) rather than
+  // `int`, so large sizes are not silently narrowed.
+  using IndexTy = decltype(container.size());
+  const IndexTy numElements = container.size();
+
+  std::vector<MlirType> result;
+  result.reserve(numElements);
+  for (IndexTy i = 0; i < numElements; ++i)
+    result.push_back(mlirValueGetType(container.getElement(i).get()));
+  return result;
+}
+
+/// A list of operation results. Internally, these are stored as consecutive
+/// elements, random access is cheap. The (returned) result list is associated
+/// with the operation whose results these are, and thus extends the lifetime of
+/// this operation.
+class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
+public:
+ static constexpr const char *pyClassName = "OpResultList";
+ using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
+
+ PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
+ intptr_t length = -1, intptr_t step = 1)
+ : Sliceable(startIndex,
+ length == -1 ? mlirOperationGetNumResults(operation->get())
+ : length,
+ step),
+ operation(std::move(operation)) {}
+
+ static void bindDerived(ClassTy &c) {
+ c.def_prop_ro("types", [](PyOpResultList &self) {
+ return getValueTypes(self, self.operation->getContext());
+ });
+ c.def_prop_ro("owner", [](PyOpResultList &self) {
+ return self.operation->createOpView();
+ });
+ }
+
+ PyOperationRef &getOperation() { return operation; }
+
+private:
+ /// Give the parent CRTP class access to hook implementations below.
+ friend class Sliceable<PyOpResultList, PyOpResult>;
+
+ intptr_t getRawNumElements() {
+ operation->checkValid();
----------------
hawkinsp wrote:
PyOpResultList is now used by getOpResultOrValue.
https://github.com/llvm/llvm-project/pull/123953
More information about the Mlir-commits
mailing list