[Mlir-commits] [mlir] [mlir:python] Compute get_op_result_or_value in PyOpView's constructor. (PR #123953)

Peter Hawkins llvmlistbot at llvm.org
Thu Jan 23 13:17:43 PST 2025


================
@@ -1629,6 +1617,143 @@ void PyOperation::erase() {
   mlirOperationDestroy(operation);
 }
 
+namespace {
+/// CRTP base for Python-visible MLIR value classes that derive from Value and
+/// can be constructed (cast) from a generic PyValue. The hierarchy is a single
+/// level deep and is not meant to grow further unless core MLIR changes.
+///
+/// A derived class is expected to provide the following statics:
+///   - `IsAFunctionTy isaFunction`: predicate accepting values of its kind;
+///   - `const char *pyClassName`: name under which the class is bound;
+/// and may shadow `bindDerived` to register additional members.
+template <typename DerivedTy>
+class PyConcreteValue : public PyValue {
+public:
+  using IsAFunctionTy = bool (*)(MlirValue);
+  using ClassTy = nb::class_<DerivedTy, PyValue>;
+
+  PyConcreteValue() = default;
+  PyConcreteValue(PyOperationRef operationRef, MlirValue value)
+      : PyValue(operationRef, value) {}
+  /// Converting constructor: keeps the parent operation of `orig` and checks
+  /// the value kind through `castFrom`, which throws on a mismatch.
+  PyConcreteValue(PyValue &orig)
+      : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
+
+  /// Returns the underlying MlirValue of `orig` when `isaFunction` accepts
+  /// it. Otherwise throws nb::value_error with a message naming the target
+  /// class and the Python repr of `orig`.
+  static MlirValue castFrom(PyValue &orig) {
+    if (DerivedTy::isaFunction(orig.get()))
+      return orig.get();
+
+    std::string message = "Cannot cast value to ";
+    message += DerivedTy::pyClassName;
+    message += " (from ";
+    message += nb::cast<std::string>(nb::repr(nb::cast(orig)));
+    message += ")";
+    throw nb::value_error(message.c_str());
+  }
+
+  /// Registers the Python class for `DerivedTy` on module `m`: a constructor
+  /// taking a PyValue, a static `isinstance` check, the maybe-downcast hook,
+  /// and finally whatever `DerivedTy::bindDerived` adds.
+  static void bind(nb::module_ &m) {
+    ClassTy cls(m, DerivedTy::pyClassName);
+    cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
+    cls.def_static(
+        "isinstance",
+        [](PyValue &candidate) -> bool {
+          return DerivedTy::isaFunction(candidate);
+        },
+        nb::arg("other_value"));
+    cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+            [](DerivedTy &self) { return self.maybeDownCast(); });
+    DerivedTy::bindDerived(cls);
+  }
+
+  /// Default hook that registers nothing; derived classes shadow it to add
+  /// their own members to the Python class.
+  static void bindDerived(ClassTy &m) {}
+};
+
+} // namespace
+
+/// Python-facing wrapper for an MLIR value that is an operation result; bound
+/// under the class name "OpResult".
+class PyOpResult : public PyConcreteValue<PyOpResult> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
+  static constexpr const char *pyClassName = "OpResult";
+  using PyConcreteValue::PyConcreteValue;
+
+  /// Adds the read-only `owner` and `result_number` properties to the class.
+  static void bindDerived(ClassTy &c) {
+    // `owner`: the Python object of the parent operation. In assert-enabled
+    // builds, first checks that it matches the owner recorded in the IR.
+    auto getOwner = [](PyOpResult &self) {
+      auto &&parent = self.getParentOperation();
+      assert(
+          mlirOperationEqual(parent->get(),
+                             mlirOpResultGetOwner(self.get())) &&
+          "expected the owner of the value in Python to match that in the IR");
+      return parent.getObject();
+    };
+    // `result_number`: forwarded from mlirOpResultGetResultNumber.
+    auto getResultNumber = [](PyOpResult &self) {
+      return mlirOpResultGetResultNumber(self.get());
+    };
+
+    c.def_prop_ro("owner", getOwner);
+    c.def_prop_ro("result_number", getResultNumber);
+  }
+};
+
+/// Collects the MLIR type of every value in `container`, in index order.
+/// `Container` must provide `size()` and `getElement(i)`, and the returned
+/// element must expose its underlying value through `get()`. The `context`
+/// argument is not referenced by the body.
+template <typename Container>
+static std::vector<MlirType> getValueTypes(Container &container,
+                                           PyMlirContextRef &context) {
+  std::vector<MlirType> types;
+  types.reserve(container.size());
+  const int numValues = container.size();
+  for (int index = 0; index < numValues; ++index) {
+    auto &&element = container.getElement(index);
+    types.push_back(mlirValueGetType(element.get()));
+  }
+  return types;
+}
+
+/// A list of operation results. Internally, these are stored as consecutive
+/// elements, so random access is cheap. The (returned) result list is
+/// associated with the operation whose results these are, and thus extends the
+/// lifetime of this operation.
+class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
+public:
+  static constexpr const char *pyClassName = "OpResultList";
+  using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
+
+  PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
+                 intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirOperationGetNumResults(operation->get())
+                               : length,
+                  step),
+        operation(std::move(operation)) {}
+
+  static void bindDerived(ClassTy &c) {
+    c.def_prop_ro("types", [](PyOpResultList &self) {
+      return getValueTypes(self, self.operation->getContext());
+    });
+    c.def_prop_ro("owner", [](PyOpResultList &self) {
+      return self.operation->createOpView();
+    });
+  }
+
+  PyOperationRef &getOperation() { return operation; }
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyOpResultList, PyOpResult>;
+
+  intptr_t getRawNumElements() {
+    operation->checkValid();
----------------
hawkinsp wrote:

PyOpResultList is now used by getOpResultOrValue.


https://github.com/llvm/llvm-project/pull/123953


More information about the Mlir-commits mailing list