[Mlir-commits] [mlir] bd2083c - [mlir][Python] Python API cleanups and additions found during code audit.
Stella Laurenzo
llvmlistbot at llvm.org
Sun Nov 29 18:11:56 PST 2020
Author: Stella Laurenzo
Date: 2020-11-29T18:09:07-08:00
New Revision: bd2083c2fa7bb8769ca997a0303da54432e08519
URL: https://github.com/llvm/llvm-project/commit/bd2083c2fa7bb8769ca997a0303da54432e08519
DIFF: https://github.com/llvm/llvm-project/commit/bd2083c2fa7bb8769ca997a0303da54432e08519.diff
LOG: [mlir][Python] Python API cleanups and additions found during code audit.
* Add capsule get/create for Attribute and Type, which already had capsule interop defined.
* Add capsule interop and get/create for Location.
* Add Location __eq__.
* Use get() and implicit cast to go from PyAttribute, PyType, PyLocation to MlirAttribute, MlirType, MlirLocation (bundled with this change because I didn't want to continue the pattern one more time).
Differential Revision: https://reviews.llvm.org/D92283
Added:
Modified:
mlir/include/mlir-c/Bindings/Python/Interop.h
mlir/include/mlir-c/IR.h
mlir/lib/Bindings/Python/IRModules.cpp
mlir/lib/Bindings/Python/IRModules.h
mlir/lib/Bindings/Python/PybindUtils.h
mlir/lib/CAPI/IR/IR.cpp
mlir/test/Bindings/Python/ir_attributes.py
mlir/test/Bindings/Python/ir_location.py
mlir/test/Bindings/Python/ir_types.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
index 05519d804e31..31265edfb550 100644
--- a/mlir/include/mlir-c/Bindings/Python/Interop.h
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -28,6 +28,7 @@
#define MLIR_PYTHON_CAPSULE_ATTRIBUTE "mlir.ir.Attribute._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_CONTEXT "mlir.ir.Context._CAPIPtr"
+#define MLIR_PYTHON_CAPSULE_LOCATION "mlir.ir.Location._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_MODULE "mlir.ir.Module._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_OPERATION "mlir.ir.Operation._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_TYPE "mlir.ir.Type._CAPIPtr"
@@ -106,6 +107,24 @@ static inline MlirContext mlirPythonCapsuleToContext(PyObject *capsule) {
return context;
}
+/** Creates a capsule object encapsulating the raw C-API MlirLocation.
+ * The returned capsule does not extend or affect ownership of any Python
+ * objects that reference the location in any way. */
+static inline PyObject *mlirPythonLocationToCapsule(MlirLocation loc) {
+ return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(loc),
+ MLIR_PYTHON_CAPSULE_LOCATION, NULL);
+}
+
+/** Extracts an MlirLocation from a capsule as produced from
+ * mlirPythonLocationToCapsule. If the capsule is not of the right type, then
+ * a null location is returned (as checked via mlirLocationIsNull). In such a
+ * case, the Python APIs will have already set an error. */
+static inline MlirLocation mlirPythonCapsuleToLocation(PyObject *capsule) {
+ void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_LOCATION);
+ MlirLocation loc = {ptr};
+ return loc;
+}
+
/** Creates a capsule object encapsulating the raw C-API MlirModule.
* The returned capsule does not extend or affect ownership of any Python
* objects that reference the module in any way. */
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 902b2b988622..e3bfe76560f1 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -153,6 +153,14 @@ MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context);
/// Gets the context that a location was created with.
MLIR_CAPI_EXPORTED MlirContext mlirLocationGetContext(MlirLocation location);
+/// Checks if the location is null.
+static inline int mlirLocationIsNull(MlirLocation location) {
+ return !location.ptr;
+}
+
+/// Checks if two locations are equal.
+MLIR_CAPI_EXPORTED int mlirLocationEqual(MlirLocation l1, MlirLocation l2);
+
/** Prints a location by sending chunks of the string representation and
* forwarding `userData to `callback`. Note that the callback may be called
* several times with consecutive chunks of the string. */
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index e145a58d0d27..d34fe998583f 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -289,7 +289,7 @@ class PyBlockList {
llvm::SmallVector<MlirType, 4> argTypes;
argTypes.reserve(pyArgTypes.size());
for (auto &pyArg : pyArgTypes) {
- argTypes.push_back(pyArg.cast<PyType &>().type);
+ argTypes.push_back(pyArg.cast<PyType &>());
}
MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
@@ -640,6 +640,18 @@ MlirDialect PyDialects::getDialectForKey(const std::string &key,
// PyLocation
//------------------------------------------------------------------------------
+py::object PyLocation::getCapsule() {
+ return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
+}
+
+PyLocation PyLocation::createFromCapsule(py::object capsule) {
+ MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
+ if (mlirLocationIsNull(rawLoc))
+ throw py::error_already_set();
+ return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
+ rawLoc);
+}
+
py::object PyLocation::contextEnter() {
return PyThreadContextEntry::pushLocation(*this);
}
@@ -879,7 +891,7 @@ py::object PyOperation::create(
// TODO: Verify result type originate from the same context.
if (!result)
throw SetPyError(PyExc_ValueError, "result type cannot be None");
- mlirResults.push_back(result->type);
+ mlirResults.push_back(*result);
}
}
// Unpack/validate attributes.
@@ -890,7 +902,7 @@ py::object PyOperation::create(
auto name = it.first.cast<std::string>();
auto &attribute = it.second.cast<PyAttribute &>();
// TODO: Verify attribute originates from the same context.
- mlirAttributes.emplace_back(std::move(name), attribute.attr);
+ mlirAttributes.emplace_back(std::move(name), attribute);
}
}
// Unpack/validate successors.
@@ -908,7 +920,7 @@ py::object PyOperation::create(
// Apply unpacked/validated to the operation state. Beyond this
// point, exceptions cannot be thrown or else the state will leak.
MlirOperationState state =
- mlirOperationStateGet(toMlirStringRef(name), location->loc);
+ mlirOperationStateGet(toMlirStringRef(name), location);
if (!mlirOperands.empty())
mlirOperationStateAddOperands(&state, mlirOperands.size(),
mlirOperands.data());
@@ -1076,6 +1088,18 @@ bool PyAttribute::operator==(const PyAttribute &other) {
return mlirAttributeEqual(attr, other.attr);
}
+py::object PyAttribute::getCapsule() {
+ return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
+}
+
+PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
+ MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
+ if (mlirAttributeIsNull(rawAttr))
+ throw py::error_already_set();
+ return PyAttribute(
+ PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
+}
+
//------------------------------------------------------------------------------
// PyNamedAttribute.
//------------------------------------------------------------------------------
@@ -1093,6 +1117,18 @@ bool PyType::operator==(const PyType &other) {
return mlirTypeEqual(type, other.type);
}
+py::object PyType::getCapsule() {
+ return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
+}
+
+PyType PyType::createFromCapsule(py::object capsule) {
+ MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
+ if (mlirTypeIsNull(rawType))
+ throw py::error_already_set();
+ return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
+ rawType);
+}
+
//------------------------------------------------------------------------------
// PyValue and subclases.
//------------------------------------------------------------------------------
@@ -1315,7 +1351,7 @@ class PyOpAttributeMap {
void dunderSetItem(const std::string &name, PyAttribute attr) {
mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
- attr.attr);
+ attr);
}
void dunderDelItem(const std::string &name) {
@@ -1378,13 +1414,13 @@ class PyConcreteAttribute : public BaseTy {
: PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
static MlirAttribute castFrom(PyAttribute &orig) {
- if (!DerivedTy::isaFunction(orig.attr)) {
+ if (!DerivedTy::isaFunction(orig)) {
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
throw SetPyError(PyExc_ValueError,
llvm::Twine("Cannot cast attribute to ") +
DerivedTy::pyClassName + " (from " + origRepr + ")");
}
- return orig.attr;
+ return orig;
}
static void bind(py::module &m) {
@@ -1408,8 +1444,7 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
c.def_static(
"get",
[](PyType &type, double value, DefaultingPyLocation loc) {
- MlirAttribute attr =
- mlirFloatAttrDoubleGetChecked(type.type, value, loc->loc);
+ MlirAttribute attr = mlirFloatAttrDoubleGetChecked(type, value, loc);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirAttributeIsNull(attr)) {
@@ -1443,7 +1478,7 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
c.def_property_readonly(
"value",
[](PyFloatAttribute &self) {
- return mlirFloatAttrGetValueDouble(self.attr);
+ return mlirFloatAttrGetValueDouble(self);
},
"Returns the value of the float point attribute");
}
@@ -1460,7 +1495,7 @@ class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
c.def_static(
"get",
[](PyType &type, int64_t value) {
- MlirAttribute attr = mlirIntegerAttrGet(type.type, value);
+ MlirAttribute attr = mlirIntegerAttrGet(type, value);
return PyIntegerAttribute(type.getContext(), attr);
},
py::arg("type"), py::arg("value"),
@@ -1468,7 +1503,7 @@ class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
c.def_property_readonly(
"value",
[](PyIntegerAttribute &self) {
- return mlirIntegerAttrGetValueInt(self.attr);
+ return mlirIntegerAttrGetValueInt(self);
},
"Returns the value of the integer attribute");
}
@@ -1492,7 +1527,7 @@ class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
"Gets an uniqued bool attribute");
c.def_property_readonly(
"value",
- [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self.attr); },
+ [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
"Returns the value of the bool attribute");
}
};
@@ -1517,7 +1552,7 @@ class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
"get_typed",
[](PyType &type, std::string value) {
MlirAttribute attr =
- mlirStringAttrTypedGet(type.type, value.size(), &value[0]);
+ mlirStringAttrTypedGet(type, value.size(), &value[0]);
return PyStringAttribute(type.getContext(), attr);
},
@@ -1525,7 +1560,7 @@ class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
c.def_property_readonly(
"value",
[](PyStringAttribute &self) {
- MlirStringRef stringRef = mlirStringAttrGetValue(self.attr);
+ MlirStringRef stringRef = mlirStringAttrGetValue(self);
return py::str(stringRef.data, stringRef.length);
},
"Returns the value of the string attribute");
@@ -1621,8 +1656,8 @@ class PyDenseElementsAttribute
PyAttribute &elementAttr) {
auto contextWrapper =
PyMlirContext::forContext(mlirTypeGetContext(shapedType));
- if (!mlirAttributeIsAInteger(elementAttr.attr) &&
- !mlirAttributeIsAFloat(elementAttr.attr)) {
+ if (!mlirAttributeIsAInteger(elementAttr) &&
+ !mlirAttributeIsAFloat(elementAttr)) {
std::string message = "Illegal element type for DenseElementsAttr: ";
message.append(py::repr(py::cast(elementAttr)));
throw SetPyError(PyExc_ValueError, message);
@@ -1634,8 +1669,8 @@ class PyDenseElementsAttribute
message.append(py::repr(py::cast(shapedType)));
throw SetPyError(PyExc_ValueError, message);
}
- MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType.type);
- MlirType attrType = mlirAttributeGetType(elementAttr.attr);
+ MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
+ MlirType attrType = mlirAttributeGetType(elementAttr);
if (!mlirTypeEqual(shapedElementType, attrType)) {
std::string message =
"Shaped element type and attribute type must be equal: shaped=";
@@ -1646,14 +1681,14 @@ class PyDenseElementsAttribute
}
MlirAttribute elements =
- mlirDenseElementsAttrSplatGet(shapedType.type, elementAttr.attr);
+ mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
}
- intptr_t dunderLen() { return mlirElementsAttrGetNumElements(attr); }
+ intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
py::buffer_info accessBuffer() {
- MlirType shapedType = mlirAttributeGetType(this->attr);
+ MlirType shapedType = mlirAttributeGetType(*this);
MlirType elementType = mlirShapedTypeGetElementType(shapedType);
if (mlirTypeIsAF32(elementType)) {
@@ -1699,7 +1734,7 @@ class PyDenseElementsAttribute
"Gets a DenseElementsAttr where all values are the same")
.def_property_readonly("is_splat",
[](PyDenseElementsAttribute &self) -> bool {
- return mlirDenseElementsAttrIsSplat(self.attr);
+ return mlirDenseElementsAttrIsSplat(self);
})
.def_buffer(&PyDenseElementsAttribute::accessBuffer);
}
@@ -1742,7 +1777,7 @@ class PyDenseElementsAttribute
// Prepare the data for the buffer_info.
// Buffer is configured for read-only access below.
Type *data = static_cast<Type *>(
- const_cast<void *>(mlirDenseElementsAttrGetRawData(this->attr)));
+ const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
// Prepare the shape for the buffer_info.
SmallVector<intptr_t, 4> shape;
for (intptr_t i = 0; i < rank; ++i)
@@ -1782,7 +1817,7 @@ class PyDenseIntElementsAttribute
"attempt to access out of bounds element");
}
- MlirType type = mlirAttributeGetType(attr);
+ MlirType type = mlirAttributeGetType(*this);
type = mlirShapedTypeGetElementType(type);
assert(mlirTypeIsAInteger(type) &&
"expected integer element type in dense int elements attribute");
@@ -1795,23 +1830,23 @@ class PyDenseIntElementsAttribute
bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
if (isUnsigned) {
if (width == 1) {
- return mlirDenseElementsAttrGetBoolValue(attr, pos);
+ return mlirDenseElementsAttrGetBoolValue(*this, pos);
}
if (width == 32) {
- return mlirDenseElementsAttrGetUInt32Value(attr, pos);
+ return mlirDenseElementsAttrGetUInt32Value(*this, pos);
}
if (width == 64) {
- return mlirDenseElementsAttrGetUInt64Value(attr, pos);
+ return mlirDenseElementsAttrGetUInt64Value(*this, pos);
}
} else {
if (width == 1) {
- return mlirDenseElementsAttrGetBoolValue(attr, pos);
+ return mlirDenseElementsAttrGetBoolValue(*this, pos);
}
if (width == 32) {
- return mlirDenseElementsAttrGetInt32Value(attr, pos);
+ return mlirDenseElementsAttrGetInt32Value(*this, pos);
}
if (width == 64) {
- return mlirDenseElementsAttrGetInt64Value(attr, pos);
+ return mlirDenseElementsAttrGetInt64Value(*this, pos);
}
}
throw SetPyError(PyExc_TypeError, "Unsupported integer type");
@@ -1838,7 +1873,7 @@ class PyDenseFPElementsAttribute
"attempt to access out of bounds element");
}
- MlirType type = mlirAttributeGetType(attr);
+ MlirType type = mlirAttributeGetType(*this);
type = mlirShapedTypeGetElementType(type);
// Dispatch element extraction to an appropriate C function based on the
// elemental type of the attribute. py::float_ is implicitly constructible
@@ -1846,10 +1881,10 @@ class PyDenseFPElementsAttribute
// TODO: consider caching the type properties in the constructor to avoid
// querying them on each element access.
if (mlirTypeIsAF32(type)) {
- return mlirDenseElementsAttrGetFloatValue(attr, pos);
+ return mlirDenseElementsAttrGetFloatValue(*this, pos);
}
if (mlirTypeIsAF64(type)) {
- return mlirDenseElementsAttrGetDoubleValue(attr, pos);
+ return mlirDenseElementsAttrGetDoubleValue(*this, pos);
}
throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
}
@@ -1906,13 +1941,13 @@ class PyConcreteType : public BaseTy {
: PyConcreteType(orig.getContext(), castFrom(orig)) {}
static MlirType castFrom(PyType &orig) {
- if (!DerivedTy::isaFunction(orig.type)) {
+ if (!DerivedTy::isaFunction(orig)) {
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
DerivedTy::pyClassName +
" (from " + origRepr + ")");
}
- return orig.type;
+ return orig;
}
static void bind(py::module &m) {
@@ -1958,24 +1993,24 @@ class PyIntegerType : public PyConcreteType<PyIntegerType> {
"Create an unsigned integer type");
c.def_property_readonly(
"width",
- [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self.type); },
+ [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
"Returns the width of the integer type");
c.def_property_readonly(
"is_signless",
[](PyIntegerType &self) -> bool {
- return mlirIntegerTypeIsSignless(self.type);
+ return mlirIntegerTypeIsSignless(self);
},
"Returns whether this is a signless integer");
c.def_property_readonly(
"is_signed",
[](PyIntegerType &self) -> bool {
- return mlirIntegerTypeIsSigned(self.type);
+ return mlirIntegerTypeIsSigned(self);
},
"Returns whether this is a signed integer");
c.def_property_readonly(
"is_unsigned",
[](PyIntegerType &self) -> bool {
- return mlirIntegerTypeIsUnsigned(self.type);
+ return mlirIntegerTypeIsUnsigned(self);
},
"Returns whether this is an unsigned integer");
}
@@ -2101,8 +2136,8 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
"get",
[](PyType &elementType) {
// The element must be a floating point or integer scalar type.
- if (mlirTypeIsAIntegerOrFloat(elementType.type)) {
- MlirType t = mlirComplexTypeGet(elementType.type);
+ if (mlirTypeIsAIntegerOrFloat(elementType)) {
+ MlirType t = mlirComplexTypeGet(elementType);
return PyComplexType(elementType.getContext(), t);
}
throw SetPyError(
@@ -2115,7 +2150,7 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
c.def_property_readonly(
"element_type",
[](PyComplexType &self) -> PyType {
- MlirType t = mlirComplexTypeGetElementType(self.type);
+ MlirType t = mlirComplexTypeGetElementType(self);
return PyType(self.getContext(), t);
},
"Returns element type.");
@@ -2132,34 +2167,32 @@ class PyShapedType : public PyConcreteType<PyShapedType> {
c.def_property_readonly(
"element_type",
[](PyShapedType &self) {
- MlirType t = mlirShapedTypeGetElementType(self.type);
+ MlirType t = mlirShapedTypeGetElementType(self);
return PyType(self.getContext(), t);
},
"Returns the element type of the shaped type.");
c.def_property_readonly(
"has_rank",
- [](PyShapedType &self) -> bool {
- return mlirShapedTypeHasRank(self.type);
- },
+ [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
"Returns whether the given shaped type is ranked.");
c.def_property_readonly(
"rank",
[](PyShapedType &self) {
self.requireHasRank();
- return mlirShapedTypeGetRank(self.type);
+ return mlirShapedTypeGetRank(self);
},
"Returns the rank of the given ranked shaped type.");
c.def_property_readonly(
"has_static_shape",
[](PyShapedType &self) -> bool {
- return mlirShapedTypeHasStaticShape(self.type);
+ return mlirShapedTypeHasStaticShape(self);
},
"Returns whether the given shaped type has a static shape.");
c.def(
"is_dynamic_dim",
[](PyShapedType &self, intptr_t dim) -> bool {
self.requireHasRank();
- return mlirShapedTypeIsDynamicDim(self.type, dim);
+ return mlirShapedTypeIsDynamicDim(self, dim);
},
"Returns whether the dim-th dimension of the given shaped type is "
"dynamic.");
@@ -2167,7 +2200,7 @@ class PyShapedType : public PyConcreteType<PyShapedType> {
"get_dim_size",
[](PyShapedType &self, intptr_t dim) {
self.requireHasRank();
- return mlirShapedTypeGetDimSize(self.type, dim);
+ return mlirShapedTypeGetDimSize(self, dim);
},
"Returns the dim-th dimension of the given ranked shaped type.");
c.def_static(
@@ -2187,7 +2220,7 @@ class PyShapedType : public PyConcreteType<PyShapedType> {
private:
void requireHasRank() {
- if (!mlirShapedTypeHasRank(type)) {
+ if (!mlirShapedTypeHasRank(*this)) {
throw SetPyError(
PyExc_ValueError,
"calling this method requires that the type has a rank.");
@@ -2208,7 +2241,7 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
[](std::vector<int64_t> shape, PyType &elementType,
DefaultingPyLocation loc) {
MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(),
- elementType.type, loc->loc);
+ elementType, loc);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
@@ -2239,7 +2272,7 @@ class PyRankedTensorType
[](std::vector<int64_t> shape, PyType &elementType,
DefaultingPyLocation loc) {
MlirType t = mlirRankedTensorTypeGetChecked(
- shape.size(), shape.data(), elementType.type, loc->loc);
+ shape.size(), shape.data(), elementType, loc);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
@@ -2270,8 +2303,7 @@ class PyUnrankedTensorType
c.def_static(
"get",
[](PyType &elementType, DefaultingPyLocation loc) {
- MlirType t =
- mlirUnrankedTensorTypeGetChecked(elementType.type, loc->loc);
+ MlirType t = mlirUnrankedTensorTypeGetChecked(elementType, loc);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
@@ -2306,8 +2338,7 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
[](PyType &elementType, std::vector<int64_t> shape,
unsigned memorySpace, DefaultingPyLocation loc) {
MlirType t = mlirMemRefTypeContiguousGetChecked(
- elementType.type, shape.size(), shape.data(), memorySpace,
- loc->loc);
+ elementType, shape.size(), shape.data(), memorySpace, loc);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
@@ -2326,14 +2357,14 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
.def_property_readonly(
"num_affine_maps",
[](PyMemRefType &self) -> intptr_t {
- return mlirMemRefTypeGetNumAffineMaps(self.type);
+ return mlirMemRefTypeGetNumAffineMaps(self);
},
"Returns the number of affine layout maps in the given MemRef "
"type.")
.def_property_readonly(
"memory_space",
[](PyMemRefType &self) -> unsigned {
- return mlirMemRefTypeGetMemorySpace(self.type);
+ return mlirMemRefTypeGetMemorySpace(self);
},
"Returns the memory space of the given MemRef type.");
}
@@ -2352,8 +2383,8 @@ class PyUnrankedMemRefType
"get",
[](PyType &elementType, unsigned memorySpace,
DefaultingPyLocation loc) {
- MlirType t = mlirUnrankedMemRefTypeGetChecked(elementType.type,
- memorySpace, loc->loc);
+ MlirType t =
+ mlirUnrankedMemRefTypeGetChecked(elementType, memorySpace, loc);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
@@ -2372,7 +2403,7 @@ class PyUnrankedMemRefType
.def_property_readonly(
"memory_space",
[](PyUnrankedMemRefType &self) -> unsigned {
- return mlirUnrankedMemrefGetMemorySpace(self.type);
+ return mlirUnrankedMemrefGetMemorySpace(self);
},
"Returns the memory space of the given Unranked MemRef type.");
}
@@ -2393,7 +2424,7 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
// Mapping py::list to SmallVector.
SmallVector<MlirType, 4> elements;
for (auto element : elementList)
- elements.push_back(element.cast<PyType>().type);
+ elements.push_back(element.cast<PyType>());
MlirType t = mlirTupleTypeGet(context->get(), num, elements.data());
return PyTupleType(context->getRef(), t);
},
@@ -2402,14 +2433,14 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
c.def(
"get_type",
[](PyTupleType &self, intptr_t pos) -> PyType {
- MlirType t = mlirTupleTypeGetType(self.type, pos);
+ MlirType t = mlirTupleTypeGetType(self, pos);
return PyType(self.getContext(), t);
},
"Returns the pos-th type in the tuple type.");
c.def_property_readonly(
"num_types",
[](PyTupleType &self) -> intptr_t {
- return mlirTupleTypeGetNumTypes(self.type);
+ return mlirTupleTypeGetNumTypes(self);
},
"Returns the number of types contained in a tuple.");
}
@@ -2439,11 +2470,11 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
c.def_property_readonly(
"inputs",
[](PyFunctionType &self) {
- MlirType t = self.type;
+ MlirType t = self;
auto contextRef = self.getContext();
py::list types;
- for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self.type);
- i < e; ++i) {
+ for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
+ ++i) {
types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i)));
}
return types;
@@ -2452,12 +2483,12 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
c.def_property_readonly(
"results",
[](PyFunctionType &self) {
- MlirType t = self.type;
auto contextRef = self.getContext();
py::list types;
- for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self.type);
- i < e; ++i) {
- types.append(PyType(contextRef, mlirFunctionTypeGetResult(t, i)));
+ for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
+ ++i) {
+ types.append(
+ PyType(contextRef, mlirFunctionTypeGetResult(self, i)));
}
return types;
},
@@ -2584,8 +2615,15 @@ void mlir::python::populateIRSubmodule(py::module &m) {
// Mapping of Location
//----------------------------------------------------------------------------
py::class_<PyLocation>(m, "Location")
+ .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
+ .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
.def("__enter__", &PyLocation::contextEnter)
.def("__exit__", &PyLocation::contextExit)
+ .def("__eq__",
+ [](PyLocation &self, PyLocation &other) -> bool {
+ return mlirLocationEqual(self, other);
+ })
+ .def("__eq__", [](PyLocation &self, py::object other) { return false; })
.def_property_readonly_static(
"current",
[](py::object & /*class*/) {
@@ -2620,7 +2658,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
"Context that owns the Location")
.def("__repr__", [](PyLocation &self) {
PyPrintAccumulator printAccum;
- mlirLocationPrint(self.loc, printAccum.getCallback(),
+ mlirLocationPrint(self, printAccum.getCallback(),
printAccum.getUserData());
return printAccum.join();
});
@@ -2650,7 +2688,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
.def_static(
"create",
[](DefaultingPyLocation loc) {
- MlirModule module = mlirModuleCreateEmpty(loc->loc);
+ MlirModule module = mlirModuleCreateEmpty(loc);
return PyModule::forModule(module).releaseObject();
},
py::arg("loc") = py::none(), "Creates an empty module")
@@ -2881,6 +2919,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
// Mapping of PyAttribute.
//----------------------------------------------------------------------------
py::class_<PyAttribute>(m, "Attribute")
+ .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+ &PyAttribute::getCapsule)
+ .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
.def_static(
"parse",
[](std::string attrSpec, DefaultingPyMlirContext context) {
@@ -2904,25 +2945,25 @@ void mlir::python::populateIRSubmodule(py::module &m) {
.def_property_readonly("type",
[](PyAttribute &self) {
return PyType(self.getContext()->getRef(),
- mlirAttributeGetType(self.attr));
+ mlirAttributeGetType(self));
})
.def(
"get_named",
[](PyAttribute &self, std::string name) {
- return PyNamedAttribute(self.attr, std::move(name));
+ return PyNamedAttribute(self, std::move(name));
},
py::keep_alive<0, 1>(), "Binds a name to the attribute")
.def("__eq__",
[](PyAttribute &self, PyAttribute &other) { return self == other; })
.def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
.def(
- "dump", [](PyAttribute &self) { mlirAttributeDump(self.attr); },
+ "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
kDumpDocstring)
.def(
"__str__",
[](PyAttribute &self) {
PyPrintAccumulator printAccum;
- mlirAttributePrint(self.attr, printAccum.getCallback(),
+ mlirAttributePrint(self, printAccum.getCallback(),
printAccum.getUserData());
return printAccum.join();
},
@@ -2935,7 +2976,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
// being excessive.
PyPrintAccumulator printAccum;
printAccum.parts.append("Attribute(");
- mlirAttributePrint(self.attr, printAccum.getCallback(),
+ mlirAttributePrint(self, printAccum.getCallback(),
printAccum.getUserData());
printAccum.parts.append(")");
return printAccum.join();
@@ -2990,6 +3031,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
// Mapping of PyType.
//----------------------------------------------------------------------------
py::class_<PyType>(m, "Type")
+ .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
+ .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
.def_static(
"parse",
[](std::string typeSpec, DefaultingPyMlirContext context) {
@@ -3012,12 +3055,12 @@ void mlir::python::populateIRSubmodule(py::module &m) {
.def("__eq__", [](PyType &self, PyType &other) { return self == other; })
.def("__eq__", [](PyType &self, py::object &other) { return false; })
.def(
- "dump", [](PyType &self) { mlirTypeDump(self.type); }, kDumpDocstring)
+ "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
.def(
"__str__",
[](PyType &self) {
PyPrintAccumulator printAccum;
- mlirTypePrint(self.type, printAccum.getCallback(),
+ mlirTypePrint(self, printAccum.getCallback(),
printAccum.getUserData());
return printAccum.join();
},
@@ -3029,8 +3072,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
// assembly forms and printing them is useful.
PyPrintAccumulator printAccum;
printAccum.parts.append("Type(");
- mlirTypePrint(self.type, printAccum.getCallback(),
- printAccum.getUserData());
+ mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
printAccum.parts.append(")");
return printAccum.join();
});
diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index 5236187c5b1f..d24607fb02c2 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -307,11 +307,24 @@ class PyLocation : public BaseContextObject {
PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
: BaseContextObject(std::move(contextRef)), loc(loc) {}
+ operator MlirLocation() const { return loc; }
+ MlirLocation get() const { return loc; }
+
/// Enter and exit the context manager.
pybind11::object contextEnter();
void contextExit(pybind11::object excType, pybind11::object excVal,
pybind11::object excTb);
+ /// Gets a capsule wrapping the void* within the MlirLocation.
+ pybind11::object getCapsule();
+
+ /// Creates a PyLocation from the MlirLocation wrapped by a capsule.
+ /// The owning context is resolved via PyMlirContext::forContext, which
+ /// may return a pre-existing context wrapper. The capsule itself does not
+ /// transfer or extend ownership of the underlying location.
+ static PyLocation createFromCapsule(pybind11::object capsule);
+
+private:
MlirLocation loc;
};
@@ -324,6 +337,8 @@ class DefaultingPyLocation
static constexpr const char kTypeDescription[] =
"[ThreadContextAware] mlir.ir.Location";
static PyLocation &resolve();
+
+ operator MlirLocation() const { return *get(); }
};
/// Wrapper around MlirModule.
@@ -568,7 +583,19 @@ class PyAttribute : public BaseContextObject {
PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
: BaseContextObject(std::move(contextRef)), attr(attr) {}
bool operator==(const PyAttribute &other);
+ operator MlirAttribute() const { return attr; }
+ MlirAttribute get() const { return attr; }
+ /// Gets a capsule wrapping the void* within the MlirAttribute.
+ pybind11::object getCapsule();
+
+ /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
+ /// The owning context is resolved via PyMlirContext::forContext, which
+ /// may return a pre-existing context wrapper. The capsule itself does not
+ /// transfer or extend ownership of the underlying attribute.
+ static PyAttribute createFromCapsule(pybind11::object capsule);
+
+private:
MlirAttribute attr;
};
@@ -603,7 +630,18 @@ class PyType : public BaseContextObject {
: BaseContextObject(std::move(contextRef)), type(type) {}
bool operator==(const PyType &other);
operator MlirType() const { return type; }
+ MlirType get() const { return type; }
+
+ /// Gets a capsule wrapping the void* within the MlirType.
+ pybind11::object getCapsule();
+ /// Creates a PyType from the MlirType wrapped by a capsule.
+ /// The owning context is resolved via PyMlirContext::forContext, which
+ /// may return a pre-existing context wrapper. The capsule itself does not
+ /// transfer or extend ownership of the underlying type.
+ static PyType createFromCapsule(pybind11::object capsule);
+
+private:
MlirType type;
};
diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
index 4116e9f30b6b..0cea24482dfe 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -50,7 +50,7 @@ class Defaulting {
Defaulting() = default;
Defaulting(ReferrentTy &referrent) : referrent(&referrent) {}
- ReferrentTy *get() { return referrent; }
+ ReferrentTy *get() const { return referrent; }
ReferrentTy *operator->() { return referrent; }
private:
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index c5eec20d85c4..475a062cb3f5 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -119,6 +119,10 @@ MlirLocation mlirLocationUnknownGet(MlirContext context) {
return wrap(UnknownLoc::get(unwrap(context)));
}
+int mlirLocationEqual(MlirLocation l1, MlirLocation l2) {
+ return unwrap(l1) == unwrap(l2);
+}
+
MlirContext mlirLocationGetContext(MlirLocation location) {
return wrap(unwrap(location).getContext());
}
diff --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py
index 11ad735f054b..6773c23cf8b3 100644
--- a/mlir/test/Bindings/Python/ir_attributes.py
+++ b/mlir/test/Bindings/Python/ir_attributes.py
@@ -74,6 +74,20 @@ def testAttrEqDoesNotRaise():
run(testAttrEqDoesNotRaise)
+# CHECK-LABEL: TEST: testAttrCapsule
+def testAttrCapsule():
+ with Context() as ctx:
+ a1 = Attribute.parse('"attr1"')
+ # CHECK: mlir.ir.Attribute._CAPIPtr
+ attr_capsule = a1._CAPIPtr
+ print(attr_capsule)
+ a2 = Attribute._CAPICreate(attr_capsule)
+ assert a2 == a1
+ assert a2.context is ctx
+
+run(testAttrCapsule)
+
+
# CHECK-LABEL: TEST: testStandardAttrCasts
def testStandardAttrCasts():
with Context():
diff --git a/mlir/test/Bindings/Python/ir_location.py b/mlir/test/Bindings/Python/ir_location.py
index e0d1bf299f5b..42a96a1ba604 100644
--- a/mlir/test/Bindings/Python/ir_location.py
+++ b/mlir/test/Bindings/Python/ir_location.py
@@ -38,3 +38,16 @@ def testFileLineCol():
run(testFileLineCol)
+
+# CHECK-LABEL: TEST: testLocationCapsule
+def testLocationCapsule():
+ with Context() as ctx:
+ loc1 = Location.file("foo.txt", 123, 56)
+ # CHECK: mlir.ir.Location._CAPIPtr
+ loc_capsule = loc1._CAPIPtr
+ print(loc_capsule)
+ loc2 = Location._CAPICreate(loc_capsule)
+ assert loc2 == loc1
+ assert loc2.context is ctx
+
+run(testLocationCapsule)
diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
index ff058cb3bf93..2f37d08aee32 100644
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -74,6 +74,20 @@ def testTypeEqDoesNotRaise():
run(testTypeEqDoesNotRaise)
+# CHECK-LABEL: TEST: testTypeCapsule
+def testTypeCapsule():
+ with Context() as ctx:
+ t1 = Type.parse("i32", ctx)
+ # CHECK: mlir.ir.Type._CAPIPtr
+ type_capsule = t1._CAPIPtr
+ print(type_capsule)
+ t2 = Type._CAPICreate(type_capsule)
+ assert t2 == t1
+ assert t2.context is ctx
+
+run(testTypeCapsule)
+
+
# CHECK-LABEL: TEST: testStandardTypeCasts
def testStandardTypeCasts():
ctx = Context()
More information about the Mlir-commits
mailing list