[Mlir-commits] [mlir] [mlir][python] value casting (PR #69644)
Jacques Pienaar
llvmlistbot at llvm.org
Fri Nov 3 10:51:11 PDT 2023
================
@@ -65,16 +67,38 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
encoding));
},
"cls"_a, "shape"_a, "width"_a, "context"_a = py::none());
- assert(py::hasattr(cls.get_class(), "static_typeid") &&
+
+ assert(py::hasattr(typeCls.get_class(), "static_typeid") &&
"TestIntegerRankedTensorType has no static_typeid");
- MlirTypeID mlirTypeID = mlirRankedTensorTypeGetTypeID();
+
+ MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID();
+
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
- mlirTypeID, pybind11::cpp_function([cls](const py::object &mlirType) {
- return cls.get_class()(mlirType);
+ mlirRankedTensorTypeID,
+ pybind11::cpp_function([typeCls](const py::object &mlirType) {
+ return typeCls.get_class()(mlirType);
}),
/*replace=*/true);
- mlir_value_subclass(m, "TestTensorValue",
- mlirTypeIsAPythonTestTestTensorValue)
- .def("is_null", [](MlirValue &self) { return mlirValueIsNull(self); });
+
+ auto valueCls = mlir_value_subclass(m, "TestTensorValue",
+ mlirTypeIsAPythonTestTestTensorValue)
+ .def("is_null", [](MlirValue &self) {
+ return mlirValueIsNull(self);
+ });
+
+ py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)(
+ mlirRankedTensorTypeID)(
+ pybind11::cpp_function([valueCls](const py::object &valueObj) {
+ py::object capsule = mlirApiObjectToCapsule(valueObj);
+ MlirValue v = mlirPythonCapsuleToValue(capsule.ptr());
+ MlirType t = mlirValueGetType(v);
+ if (mlirShapedTypeHasStaticShape(t) &&
----------------
jpienaar wrote:
Could you add a comment explaining this condition? 🙂
https://github.com/llvm/llvm-project/pull/69644
More information about the Mlir-commits mailing list