[Mlir-commits] [mlir] bd2083c - [mlir][Python] Python API cleanups and additions found during code audit.

Stella Laurenzo llvmlistbot at llvm.org
Sun Nov 29 18:11:56 PST 2020


Author: Stella Laurenzo
Date: 2020-11-29T18:09:07-08:00
New Revision: bd2083c2fa7bb8769ca997a0303da54432e08519

URL: https://github.com/llvm/llvm-project/commit/bd2083c2fa7bb8769ca997a0303da54432e08519
DIFF: https://github.com/llvm/llvm-project/commit/bd2083c2fa7bb8769ca997a0303da54432e08519.diff

LOG: [mlir][Python] Python API cleanups and additions found during code audit.

* Add capsule get/create for Attribute and Type, which already had capsule interop defined.
* Add capsule interop and get/create for Location.
* Add Location __eq__.
* Use get() and implicit conversions to go from PyAttribute, PyType, PyLocation to MlirAttribute, MlirType, MlirLocation, replacing direct access to the `.attr`/`.type`/`->loc` fields (bundled with this change because I didn't want to continue that pattern one more time).

Differential Revision: https://reviews.llvm.org/D92283

Added: 
    

Modified: 
    mlir/include/mlir-c/Bindings/Python/Interop.h
    mlir/include/mlir-c/IR.h
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/Bindings/Python/IRModules.h
    mlir/lib/Bindings/Python/PybindUtils.h
    mlir/lib/CAPI/IR/IR.cpp
    mlir/test/Bindings/Python/ir_attributes.py
    mlir/test/Bindings/Python/ir_location.py
    mlir/test/Bindings/Python/ir_types.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
index 05519d804e31..31265edfb550 100644
--- a/mlir/include/mlir-c/Bindings/Python/Interop.h
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -28,6 +28,7 @@
 
 #define MLIR_PYTHON_CAPSULE_ATTRIBUTE "mlir.ir.Attribute._CAPIPtr"
 #define MLIR_PYTHON_CAPSULE_CONTEXT "mlir.ir.Context._CAPIPtr"
+#define MLIR_PYTHON_CAPSULE_LOCATION "mlir.ir.Location._CAPIPtr"
 #define MLIR_PYTHON_CAPSULE_MODULE "mlir.ir.Module._CAPIPtr"
 #define MLIR_PYTHON_CAPSULE_OPERATION "mlir.ir.Operation._CAPIPtr"
 #define MLIR_PYTHON_CAPSULE_TYPE "mlir.ir.Type._CAPIPtr"
@@ -106,6 +107,24 @@ static inline MlirContext mlirPythonCapsuleToContext(PyObject *capsule) {
   return context;
 }
 
+/** Creates a capsule object encapsulating the raw C-API MlirLocation.
+ * The returned capsule does not extend or affect ownership of any Python
+ * objects that reference the location in any way. */
+static inline PyObject *mlirPythonLocationToCapsule(MlirLocation loc) {
+  return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(loc),
+                       MLIR_PYTHON_CAPSULE_LOCATION, NULL);
+}
+
+/** Extracts an MlirLocation from a capsule as produced from
+ * mlirPythonLocationToCapsule. If the capsule is not of the right type, then
+ * a null location is returned (as checked via mlirLocationIsNull). In such a
+ * case, the Python APIs will have already set an error. */
+static inline MlirLocation mlirPythonCapsuleToLocation(PyObject *capsule) {
+  void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_LOCATION);
+  MlirLocation loc = {ptr};
+  return loc;
+}
+
 /** Creates a capsule object encapsulating the raw C-API MlirModule.
  * The returned capsule does not extend or affect ownership of any Python
  * objects that reference the module in any way. */

diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 902b2b988622..e3bfe76560f1 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -153,6 +153,14 @@ MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context);
 /// Gets the context that a location was created with.
 MLIR_CAPI_EXPORTED MlirContext mlirLocationGetContext(MlirLocation location);
 
+/// Checks if the location is null.
+static inline int mlirLocationIsNull(MlirLocation location) {
+  return !location.ptr;
+}
+
+/// Checks if two locations are equal.
+MLIR_CAPI_EXPORTED int mlirLocationEqual(MlirLocation l1, MlirLocation l2);
+
 /** Prints a location by sending chunks of the string representation and
  * forwarding `userData to `callback`. Note that the callback may be called
  * several times with consecutive chunks of the string. */

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index e145a58d0d27..d34fe998583f 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -289,7 +289,7 @@ class PyBlockList {
     llvm::SmallVector<MlirType, 4> argTypes;
     argTypes.reserve(pyArgTypes.size());
     for (auto &pyArg : pyArgTypes) {
-      argTypes.push_back(pyArg.cast<PyType &>().type);
+      argTypes.push_back(pyArg.cast<PyType &>());
     }
 
     MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
@@ -640,6 +640,18 @@ MlirDialect PyDialects::getDialectForKey(const std::string &key,
 // PyLocation
 //------------------------------------------------------------------------------
 
+py::object PyLocation::getCapsule() {
+  return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
+}
+
+PyLocation PyLocation::createFromCapsule(py::object capsule) {
+  MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
+  if (mlirLocationIsNull(rawLoc))
+    throw py::error_already_set();
+  return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
+                    rawLoc);
+}
+
 py::object PyLocation::contextEnter() {
   return PyThreadContextEntry::pushLocation(*this);
 }
@@ -879,7 +891,7 @@ py::object PyOperation::create(
       // TODO: Verify result type originate from the same context.
       if (!result)
         throw SetPyError(PyExc_ValueError, "result type cannot be None");
-      mlirResults.push_back(result->type);
+      mlirResults.push_back(*result);
     }
   }
   // Unpack/validate attributes.
@@ -890,7 +902,7 @@ py::object PyOperation::create(
       auto name = it.first.cast<std::string>();
       auto &attribute = it.second.cast<PyAttribute &>();
       // TODO: Verify attribute originates from the same context.
-      mlirAttributes.emplace_back(std::move(name), attribute.attr);
+      mlirAttributes.emplace_back(std::move(name), attribute);
     }
   }
   // Unpack/validate successors.
@@ -908,7 +920,7 @@ py::object PyOperation::create(
   // Apply unpacked/validated to the operation state. Beyond this
   // point, exceptions cannot be thrown or else the state will leak.
   MlirOperationState state =
-      mlirOperationStateGet(toMlirStringRef(name), location->loc);
+      mlirOperationStateGet(toMlirStringRef(name), location);
   if (!mlirOperands.empty())
     mlirOperationStateAddOperands(&state, mlirOperands.size(),
                                   mlirOperands.data());
@@ -1076,6 +1088,18 @@ bool PyAttribute::operator==(const PyAttribute &other) {
   return mlirAttributeEqual(attr, other.attr);
 }
 
+py::object PyAttribute::getCapsule() {
+  return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
+}
+
+PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
+  MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
+  if (mlirAttributeIsNull(rawAttr))
+    throw py::error_already_set();
+  return PyAttribute(
+      PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
+}
+
 //------------------------------------------------------------------------------
 // PyNamedAttribute.
 //------------------------------------------------------------------------------
@@ -1093,6 +1117,18 @@ bool PyType::operator==(const PyType &other) {
   return mlirTypeEqual(type, other.type);
 }
 
+py::object PyType::getCapsule() {
+  return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
+}
+
+PyType PyType::createFromCapsule(py::object capsule) {
+  MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
+  if (mlirTypeIsNull(rawType))
+    throw py::error_already_set();
+  return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
+                rawType);
+}
+
 //------------------------------------------------------------------------------
 // PyValue and subclases.
 //------------------------------------------------------------------------------
@@ -1315,7 +1351,7 @@ class PyOpAttributeMap {
 
   void dunderSetItem(const std::string &name, PyAttribute attr) {
     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
-                                    attr.attr);
+                                    attr);
   }
 
   void dunderDelItem(const std::string &name) {
@@ -1378,13 +1414,13 @@ class PyConcreteAttribute : public BaseTy {
       : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
 
   static MlirAttribute castFrom(PyAttribute &orig) {
-    if (!DerivedTy::isaFunction(orig.attr)) {
+    if (!DerivedTy::isaFunction(orig)) {
       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
       throw SetPyError(PyExc_ValueError,
                        llvm::Twine("Cannot cast attribute to ") +
                            DerivedTy::pyClassName + " (from " + origRepr + ")");
     }
-    return orig.attr;
+    return orig;
   }
 
   static void bind(py::module &m) {
@@ -1408,8 +1444,7 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
     c.def_static(
         "get",
         [](PyType &type, double value, DefaultingPyLocation loc) {
-          MlirAttribute attr =
-              mlirFloatAttrDoubleGetChecked(type.type, value, loc->loc);
+          MlirAttribute attr = mlirFloatAttrDoubleGetChecked(type, value, loc);
           // TODO: Rework error reporting once diagnostic engine is exposed
           // in C API.
           if (mlirAttributeIsNull(attr)) {
@@ -1443,7 +1478,7 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
     c.def_property_readonly(
         "value",
         [](PyFloatAttribute &self) {
-          return mlirFloatAttrGetValueDouble(self.attr);
+          return mlirFloatAttrGetValueDouble(self);
         },
         "Returns the value of the float point attribute");
   }
@@ -1460,7 +1495,7 @@ class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
     c.def_static(
         "get",
         [](PyType &type, int64_t value) {
-          MlirAttribute attr = mlirIntegerAttrGet(type.type, value);
+          MlirAttribute attr = mlirIntegerAttrGet(type, value);
           return PyIntegerAttribute(type.getContext(), attr);
         },
         py::arg("type"), py::arg("value"),
@@ -1468,7 +1503,7 @@ class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
     c.def_property_readonly(
         "value",
         [](PyIntegerAttribute &self) {
-          return mlirIntegerAttrGetValueInt(self.attr);
+          return mlirIntegerAttrGetValueInt(self);
         },
         "Returns the value of the integer attribute");
   }
@@ -1492,7 +1527,7 @@ class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
         "Gets an uniqued bool attribute");
     c.def_property_readonly(
         "value",
-        [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self.attr); },
+        [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
         "Returns the value of the bool attribute");
   }
 };
@@ -1517,7 +1552,7 @@ class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
         "get_typed",
         [](PyType &type, std::string value) {
           MlirAttribute attr =
-              mlirStringAttrTypedGet(type.type, value.size(), &value[0]);
+              mlirStringAttrTypedGet(type, value.size(), &value[0]);
           return PyStringAttribute(type.getContext(), attr);
         },
 
@@ -1525,7 +1560,7 @@ class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
     c.def_property_readonly(
         "value",
         [](PyStringAttribute &self) {
-          MlirStringRef stringRef = mlirStringAttrGetValue(self.attr);
+          MlirStringRef stringRef = mlirStringAttrGetValue(self);
           return py::str(stringRef.data, stringRef.length);
         },
         "Returns the value of the string attribute");
@@ -1621,8 +1656,8 @@ class PyDenseElementsAttribute
                                            PyAttribute &elementAttr) {
     auto contextWrapper =
         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
-    if (!mlirAttributeIsAInteger(elementAttr.attr) &&
-        !mlirAttributeIsAFloat(elementAttr.attr)) {
+    if (!mlirAttributeIsAInteger(elementAttr) &&
+        !mlirAttributeIsAFloat(elementAttr)) {
       std::string message = "Illegal element type for DenseElementsAttr: ";
       message.append(py::repr(py::cast(elementAttr)));
       throw SetPyError(PyExc_ValueError, message);
@@ -1634,8 +1669,8 @@ class PyDenseElementsAttribute
       message.append(py::repr(py::cast(shapedType)));
       throw SetPyError(PyExc_ValueError, message);
     }
-    MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType.type);
-    MlirType attrType = mlirAttributeGetType(elementAttr.attr);
+    MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
+    MlirType attrType = mlirAttributeGetType(elementAttr);
     if (!mlirTypeEqual(shapedElementType, attrType)) {
       std::string message =
           "Shaped element type and attribute type must be equal: shaped=";
@@ -1646,14 +1681,14 @@ class PyDenseElementsAttribute
     }
 
     MlirAttribute elements =
-        mlirDenseElementsAttrSplatGet(shapedType.type, elementAttr.attr);
+        mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
   }
 
-  intptr_t dunderLen() { return mlirElementsAttrGetNumElements(attr); }
+  intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
 
   py::buffer_info accessBuffer() {
-    MlirType shapedType = mlirAttributeGetType(this->attr);
+    MlirType shapedType = mlirAttributeGetType(*this);
     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
 
     if (mlirTypeIsAF32(elementType)) {
@@ -1699,7 +1734,7 @@ class PyDenseElementsAttribute
                     "Gets a DenseElementsAttr where all values are the same")
         .def_property_readonly("is_splat",
                                [](PyDenseElementsAttribute &self) -> bool {
-                                 return mlirDenseElementsAttrIsSplat(self.attr);
+                                 return mlirDenseElementsAttrIsSplat(self);
                                })
         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
   }
@@ -1742,7 +1777,7 @@ class PyDenseElementsAttribute
     // Prepare the data for the buffer_info.
     // Buffer is configured for read-only access below.
     Type *data = static_cast<Type *>(
-        const_cast<void *>(mlirDenseElementsAttrGetRawData(this->attr)));
+        const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
     // Prepare the shape for the buffer_info.
     SmallVector<intptr_t, 4> shape;
     for (intptr_t i = 0; i < rank; ++i)
@@ -1782,7 +1817,7 @@ class PyDenseIntElementsAttribute
                        "attempt to access out of bounds element");
     }
 
-    MlirType type = mlirAttributeGetType(attr);
+    MlirType type = mlirAttributeGetType(*this);
     type = mlirShapedTypeGetElementType(type);
     assert(mlirTypeIsAInteger(type) &&
            "expected integer element type in dense int elements attribute");
@@ -1795,23 +1830,23 @@ class PyDenseIntElementsAttribute
     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
     if (isUnsigned) {
       if (width == 1) {
-        return mlirDenseElementsAttrGetBoolValue(attr, pos);
+        return mlirDenseElementsAttrGetBoolValue(*this, pos);
       }
       if (width == 32) {
-        return mlirDenseElementsAttrGetUInt32Value(attr, pos);
+        return mlirDenseElementsAttrGetUInt32Value(*this, pos);
       }
       if (width == 64) {
-        return mlirDenseElementsAttrGetUInt64Value(attr, pos);
+        return mlirDenseElementsAttrGetUInt64Value(*this, pos);
       }
     } else {
       if (width == 1) {
-        return mlirDenseElementsAttrGetBoolValue(attr, pos);
+        return mlirDenseElementsAttrGetBoolValue(*this, pos);
       }
       if (width == 32) {
-        return mlirDenseElementsAttrGetInt32Value(attr, pos);
+        return mlirDenseElementsAttrGetInt32Value(*this, pos);
       }
       if (width == 64) {
-        return mlirDenseElementsAttrGetInt64Value(attr, pos);
+        return mlirDenseElementsAttrGetInt64Value(*this, pos);
       }
     }
     throw SetPyError(PyExc_TypeError, "Unsupported integer type");
@@ -1838,7 +1873,7 @@ class PyDenseFPElementsAttribute
                        "attempt to access out of bounds element");
     }
 
-    MlirType type = mlirAttributeGetType(attr);
+    MlirType type = mlirAttributeGetType(*this);
     type = mlirShapedTypeGetElementType(type);
     // Dispatch element extraction to an appropriate C function based on the
     // elemental type of the attribute. py::float_ is implicitly constructible
@@ -1846,10 +1881,10 @@ class PyDenseFPElementsAttribute
     // TODO: consider caching the type properties in the constructor to avoid
     // querying them on each element access.
     if (mlirTypeIsAF32(type)) {
-      return mlirDenseElementsAttrGetFloatValue(attr, pos);
+      return mlirDenseElementsAttrGetFloatValue(*this, pos);
     }
     if (mlirTypeIsAF64(type)) {
-      return mlirDenseElementsAttrGetDoubleValue(attr, pos);
+      return mlirDenseElementsAttrGetDoubleValue(*this, pos);
     }
     throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
   }
@@ -1906,13 +1941,13 @@ class PyConcreteType : public BaseTy {
       : PyConcreteType(orig.getContext(), castFrom(orig)) {}
 
   static MlirType castFrom(PyType &orig) {
-    if (!DerivedTy::isaFunction(orig.type)) {
+    if (!DerivedTy::isaFunction(orig)) {
       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
       throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
                                              DerivedTy::pyClassName +
                                              " (from " + origRepr + ")");
     }
-    return orig.type;
+    return orig;
   }
 
   static void bind(py::module &m) {
@@ -1958,24 +1993,24 @@ class PyIntegerType : public PyConcreteType<PyIntegerType> {
         "Create an unsigned integer type");
     c.def_property_readonly(
         "width",
-        [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self.type); },
+        [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
         "Returns the width of the integer type");
     c.def_property_readonly(
         "is_signless",
         [](PyIntegerType &self) -> bool {
-          return mlirIntegerTypeIsSignless(self.type);
+          return mlirIntegerTypeIsSignless(self);
         },
         "Returns whether this is a signless integer");
     c.def_property_readonly(
         "is_signed",
         [](PyIntegerType &self) -> bool {
-          return mlirIntegerTypeIsSigned(self.type);
+          return mlirIntegerTypeIsSigned(self);
         },
         "Returns whether this is a signed integer");
     c.def_property_readonly(
         "is_unsigned",
         [](PyIntegerType &self) -> bool {
-          return mlirIntegerTypeIsUnsigned(self.type);
+          return mlirIntegerTypeIsUnsigned(self);
         },
         "Returns whether this is an unsigned integer");
   }
@@ -2101,8 +2136,8 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
         "get",
         [](PyType &elementType) {
           // The element must be a floating point or integer scalar type.
-          if (mlirTypeIsAIntegerOrFloat(elementType.type)) {
-            MlirType t = mlirComplexTypeGet(elementType.type);
+          if (mlirTypeIsAIntegerOrFloat(elementType)) {
+            MlirType t = mlirComplexTypeGet(elementType);
             return PyComplexType(elementType.getContext(), t);
           }
           throw SetPyError(
@@ -2115,7 +2150,7 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
     c.def_property_readonly(
         "element_type",
         [](PyComplexType &self) -> PyType {
-          MlirType t = mlirComplexTypeGetElementType(self.type);
+          MlirType t = mlirComplexTypeGetElementType(self);
           return PyType(self.getContext(), t);
         },
         "Returns element type.");
@@ -2132,34 +2167,32 @@ class PyShapedType : public PyConcreteType<PyShapedType> {
     c.def_property_readonly(
         "element_type",
         [](PyShapedType &self) {
-          MlirType t = mlirShapedTypeGetElementType(self.type);
+          MlirType t = mlirShapedTypeGetElementType(self);
           return PyType(self.getContext(), t);
         },
         "Returns the element type of the shaped type.");
     c.def_property_readonly(
         "has_rank",
-        [](PyShapedType &self) -> bool {
-          return mlirShapedTypeHasRank(self.type);
-        },
+        [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
         "Returns whether the given shaped type is ranked.");
     c.def_property_readonly(
         "rank",
         [](PyShapedType &self) {
           self.requireHasRank();
-          return mlirShapedTypeGetRank(self.type);
+          return mlirShapedTypeGetRank(self);
         },
         "Returns the rank of the given ranked shaped type.");
     c.def_property_readonly(
         "has_static_shape",
         [](PyShapedType &self) -> bool {
-          return mlirShapedTypeHasStaticShape(self.type);
+          return mlirShapedTypeHasStaticShape(self);
         },
         "Returns whether the given shaped type has a static shape.");
     c.def(
         "is_dynamic_dim",
         [](PyShapedType &self, intptr_t dim) -> bool {
           self.requireHasRank();
-          return mlirShapedTypeIsDynamicDim(self.type, dim);
+          return mlirShapedTypeIsDynamicDim(self, dim);
         },
         "Returns whether the dim-th dimension of the given shaped type is "
         "dynamic.");
@@ -2167,7 +2200,7 @@ class PyShapedType : public PyConcreteType<PyShapedType> {
         "get_dim_size",
         [](PyShapedType &self, intptr_t dim) {
           self.requireHasRank();
-          return mlirShapedTypeGetDimSize(self.type, dim);
+          return mlirShapedTypeGetDimSize(self, dim);
         },
         "Returns the dim-th dimension of the given ranked shaped type.");
     c.def_static(
@@ -2187,7 +2220,7 @@ class PyShapedType : public PyConcreteType<PyShapedType> {
 
 private:
   void requireHasRank() {
-    if (!mlirShapedTypeHasRank(type)) {
+    if (!mlirShapedTypeHasRank(*this)) {
       throw SetPyError(
           PyExc_ValueError,
           "calling this method requires that the type has a rank.");
@@ -2208,7 +2241,7 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
         [](std::vector<int64_t> shape, PyType &elementType,
            DefaultingPyLocation loc) {
           MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(),
-                                                elementType.type, loc->loc);
+                                                elementType, loc);
           // TODO: Rework error reporting once diagnostic engine is exposed
           // in C API.
           if (mlirTypeIsNull(t)) {
@@ -2239,7 +2272,7 @@ class PyRankedTensorType
         [](std::vector<int64_t> shape, PyType &elementType,
            DefaultingPyLocation loc) {
           MlirType t = mlirRankedTensorTypeGetChecked(
-              shape.size(), shape.data(), elementType.type, loc->loc);
+              shape.size(), shape.data(), elementType, loc);
           // TODO: Rework error reporting once diagnostic engine is exposed
           // in C API.
           if (mlirTypeIsNull(t)) {
@@ -2270,8 +2303,7 @@ class PyUnrankedTensorType
     c.def_static(
         "get",
         [](PyType &elementType, DefaultingPyLocation loc) {
-          MlirType t =
-              mlirUnrankedTensorTypeGetChecked(elementType.type, loc->loc);
+          MlirType t = mlirUnrankedTensorTypeGetChecked(elementType, loc);
           // TODO: Rework error reporting once diagnostic engine is exposed
           // in C API.
           if (mlirTypeIsNull(t)) {
@@ -2306,8 +2338,7 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
          [](PyType &elementType, std::vector<int64_t> shape,
             unsigned memorySpace, DefaultingPyLocation loc) {
            MlirType t = mlirMemRefTypeContiguousGetChecked(
-               elementType.type, shape.size(), shape.data(), memorySpace,
-               loc->loc);
+               elementType, shape.size(), shape.data(), memorySpace, loc);
            // TODO: Rework error reporting once diagnostic engine is exposed
            // in C API.
            if (mlirTypeIsNull(t)) {
@@ -2326,14 +2357,14 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
         .def_property_readonly(
             "num_affine_maps",
             [](PyMemRefType &self) -> intptr_t {
-              return mlirMemRefTypeGetNumAffineMaps(self.type);
+              return mlirMemRefTypeGetNumAffineMaps(self);
             },
             "Returns the number of affine layout maps in the given MemRef "
             "type.")
         .def_property_readonly(
             "memory_space",
             [](PyMemRefType &self) -> unsigned {
-              return mlirMemRefTypeGetMemorySpace(self.type);
+              return mlirMemRefTypeGetMemorySpace(self);
             },
             "Returns the memory space of the given MemRef type.");
   }
@@ -2352,8 +2383,8 @@ class PyUnrankedMemRefType
          "get",
          [](PyType &elementType, unsigned memorySpace,
             DefaultingPyLocation loc) {
-           MlirType t = mlirUnrankedMemRefTypeGetChecked(elementType.type,
-                                                         memorySpace, loc->loc);
+           MlirType t =
+               mlirUnrankedMemRefTypeGetChecked(elementType, memorySpace, loc);
            // TODO: Rework error reporting once diagnostic engine is exposed
            // in C API.
            if (mlirTypeIsNull(t)) {
@@ -2372,7 +2403,7 @@ class PyUnrankedMemRefType
         .def_property_readonly(
             "memory_space",
             [](PyUnrankedMemRefType &self) -> unsigned {
-              return mlirUnrankedMemrefGetMemorySpace(self.type);
+              return mlirUnrankedMemrefGetMemorySpace(self);
             },
             "Returns the memory space of the given Unranked MemRef type.");
   }
@@ -2393,7 +2424,7 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
           // Mapping py::list to SmallVector.
           SmallVector<MlirType, 4> elements;
           for (auto element : elementList)
-            elements.push_back(element.cast<PyType>().type);
+            elements.push_back(element.cast<PyType>());
           MlirType t = mlirTupleTypeGet(context->get(), num, elements.data());
           return PyTupleType(context->getRef(), t);
         },
@@ -2402,14 +2433,14 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
     c.def(
         "get_type",
         [](PyTupleType &self, intptr_t pos) -> PyType {
-          MlirType t = mlirTupleTypeGetType(self.type, pos);
+          MlirType t = mlirTupleTypeGetType(self, pos);
           return PyType(self.getContext(), t);
         },
         "Returns the pos-th type in the tuple type.");
     c.def_property_readonly(
         "num_types",
         [](PyTupleType &self) -> intptr_t {
-          return mlirTupleTypeGetNumTypes(self.type);
+          return mlirTupleTypeGetNumTypes(self);
         },
         "Returns the number of types contained in a tuple.");
   }
@@ -2439,11 +2470,11 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
     c.def_property_readonly(
         "inputs",
         [](PyFunctionType &self) {
-          MlirType t = self.type;
+          MlirType t = self;
           auto contextRef = self.getContext();
           py::list types;
-          for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self.type);
-               i < e; ++i) {
+          for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
+               ++i) {
             types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i)));
           }
           return types;
@@ -2452,12 +2483,12 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
     c.def_property_readonly(
         "results",
         [](PyFunctionType &self) {
-          MlirType t = self.type;
           auto contextRef = self.getContext();
           py::list types;
-          for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self.type);
-               i < e; ++i) {
-            types.append(PyType(contextRef, mlirFunctionTypeGetResult(t, i)));
+          for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
+               ++i) {
+            types.append(
+                PyType(contextRef, mlirFunctionTypeGetResult(self, i)));
           }
           return types;
         },
@@ -2584,8 +2615,15 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   // Mapping of Location
   //----------------------------------------------------------------------------
   py::class_<PyLocation>(m, "Location")
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
       .def("__enter__", &PyLocation::contextEnter)
       .def("__exit__", &PyLocation::contextExit)
+      .def("__eq__",
+           [](PyLocation &self, PyLocation &other) -> bool {
+             return mlirLocationEqual(self, other);
+           })
+      .def("__eq__", [](PyLocation &self, py::object other) { return false; })
       .def_property_readonly_static(
           "current",
           [](py::object & /*class*/) {
@@ -2620,7 +2658,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
           "Context that owns the Location")
       .def("__repr__", [](PyLocation &self) {
         PyPrintAccumulator printAccum;
-        mlirLocationPrint(self.loc, printAccum.getCallback(),
+        mlirLocationPrint(self, printAccum.getCallback(),
                           printAccum.getUserData());
         return printAccum.join();
       });
@@ -2650,7 +2688,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
       .def_static(
           "create",
           [](DefaultingPyLocation loc) {
-            MlirModule module = mlirModuleCreateEmpty(loc->loc);
+            MlirModule module = mlirModuleCreateEmpty(loc);
             return PyModule::forModule(module).releaseObject();
           },
           py::arg("loc") = py::none(), "Creates an empty module")
@@ -2881,6 +2919,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   // Mapping of PyAttribute.
   //----------------------------------------------------------------------------
   py::class_<PyAttribute>(m, "Attribute")
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+                             &PyAttribute::getCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
       .def_static(
           "parse",
           [](std::string attrSpec, DefaultingPyMlirContext context) {
@@ -2904,25 +2945,25 @@ void mlir::python::populateIRSubmodule(py::module &m) {
       .def_property_readonly("type",
                              [](PyAttribute &self) {
                                return PyType(self.getContext()->getRef(),
-                                             mlirAttributeGetType(self.attr));
+                                             mlirAttributeGetType(self));
                              })
       .def(
           "get_named",
           [](PyAttribute &self, std::string name) {
-            return PyNamedAttribute(self.attr, std::move(name));
+            return PyNamedAttribute(self, std::move(name));
           },
           py::keep_alive<0, 1>(), "Binds a name to the attribute")
       .def("__eq__",
            [](PyAttribute &self, PyAttribute &other) { return self == other; })
       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
       .def(
-          "dump", [](PyAttribute &self) { mlirAttributeDump(self.attr); },
+          "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
           kDumpDocstring)
       .def(
           "__str__",
           [](PyAttribute &self) {
             PyPrintAccumulator printAccum;
-            mlirAttributePrint(self.attr, printAccum.getCallback(),
+            mlirAttributePrint(self, printAccum.getCallback(),
                                printAccum.getUserData());
             return printAccum.join();
           },
@@ -2935,7 +2976,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
         // being excessive.
         PyPrintAccumulator printAccum;
         printAccum.parts.append("Attribute(");
-        mlirAttributePrint(self.attr, printAccum.getCallback(),
+        mlirAttributePrint(self, printAccum.getCallback(),
                            printAccum.getUserData());
         printAccum.parts.append(")");
         return printAccum.join();
@@ -2990,6 +3031,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   // Mapping of PyType.
   //----------------------------------------------------------------------------
   py::class_<PyType>(m, "Type")
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
       .def_static(
           "parse",
           [](std::string typeSpec, DefaultingPyMlirContext context) {
@@ -3012,12 +3055,12 @@ void mlir::python::populateIRSubmodule(py::module &m) {
       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
       .def("__eq__", [](PyType &self, py::object &other) { return false; })
       .def(
-          "dump", [](PyType &self) { mlirTypeDump(self.type); }, kDumpDocstring)
+          "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
       .def(
           "__str__",
           [](PyType &self) {
             PyPrintAccumulator printAccum;
-            mlirTypePrint(self.type, printAccum.getCallback(),
+            mlirTypePrint(self, printAccum.getCallback(),
                           printAccum.getUserData());
             return printAccum.join();
           },
@@ -3029,8 +3072,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
         // assembly forms and printing them is useful.
         PyPrintAccumulator printAccum;
         printAccum.parts.append("Type(");
-        mlirTypePrint(self.type, printAccum.getCallback(),
-                      printAccum.getUserData());
+        mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
         printAccum.parts.append(")");
         return printAccum.join();
       });

diff  --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index 5236187c5b1f..d24607fb02c2 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -307,11 +307,24 @@ class PyLocation : public BaseContextObject {
   PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
       : BaseContextObject(std::move(contextRef)), loc(loc) {}
 
+  operator MlirLocation() const { return loc; }
+  MlirLocation get() const { return loc; }
+
   /// Enter and exit the context manager.
   pybind11::object contextEnter();
   void contextExit(pybind11::object excType, pybind11::object excVal,
                    pybind11::object excTb);
 
+  /// Gets a capsule wrapping the void* within the MlirLocation.
+  pybind11::object getCapsule();
+
+  /// Creates a PyLocation from the MlirLocation wrapped by a capsule.
+  /// Returned by value, so this is a fresh wrapper rather than a uniqued
+  /// object: compare locations with == (mlirLocationEqual), not identity.
+  /// NOTE(review): ownership/context lookup lives in the .cpp - confirm.
+  static PyLocation createFromCapsule(pybind11::object capsule);
+
+private:
   MlirLocation loc;
 };
 
@@ -324,6 +337,8 @@ class DefaultingPyLocation
   static constexpr const char kTypeDescription[] =
       "[ThreadContextAware] mlir.ir.Location";
   static PyLocation &resolve();
+
+  operator MlirLocation() const { return *get(); }
 };
 
 /// Wrapper around MlirModule.
@@ -568,7 +583,19 @@ class PyAttribute : public BaseContextObject {
   PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
       : BaseContextObject(std::move(contextRef)), attr(attr) {}
   bool operator==(const PyAttribute &other);
+  operator MlirAttribute() const { return attr; }
+  MlirAttribute get() const { return attr; }
 
+  /// Gets a capsule wrapping the void* within the MlirAttribute.
+  pybind11::object getCapsule();
+
+  /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
+  /// Returned by value, so this is a fresh wrapper rather than a uniqued
+  /// object: compare attributes with ==, not identity.
+  /// NOTE(review): ownership/context lookup lives in the .cpp - confirm.
+  static PyAttribute createFromCapsule(pybind11::object capsule);
+
+private:
   MlirAttribute attr;
 };
 
@@ -603,7 +630,18 @@ class PyType : public BaseContextObject {
       : BaseContextObject(std::move(contextRef)), type(type) {}
   bool operator==(const PyType &other);
   operator MlirType() const { return type; }
+  MlirType get() const { return type; }
+
+  /// Gets a capsule wrapping the void* within the MlirType.
+  pybind11::object getCapsule();
 
+  /// Creates a PyType from the MlirType wrapped by a capsule.
+  /// Returned by value, so this is a fresh wrapper rather than a uniqued
+  /// object: compare types with ==, not identity.
+  /// NOTE(review): ownership/context lookup lives in the .cpp - confirm.
+  static PyType createFromCapsule(pybind11::object capsule);
+
+private:
   MlirType type;
 };
 

diff  --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
index 4116e9f30b6b..0cea24482dfe 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -50,7 +50,7 @@ class Defaulting {
   Defaulting() = default;
   Defaulting(ReferrentTy &referrent) : referrent(&referrent) {}
 
-  ReferrentTy *get() { return referrent; }
+  ReferrentTy *get() const { return referrent; }
   ReferrentTy *operator->() { return referrent; }
 
 private:

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index c5eec20d85c4..475a062cb3f5 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -119,6 +119,10 @@ MlirLocation mlirLocationUnknownGet(MlirContext context) {
   return wrap(UnknownLoc::get(unwrap(context)));
 }
 
+int mlirLocationEqual(MlirLocation l1, MlirLocation l2) {
+  return unwrap(l1) == unwrap(l2);
+}
+
 MlirContext mlirLocationGetContext(MlirLocation location) {
   return wrap(unwrap(location).getContext());
 }

diff  --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py
index 11ad735f054b..6773c23cf8b3 100644
--- a/mlir/test/Bindings/Python/ir_attributes.py
+++ b/mlir/test/Bindings/Python/ir_attributes.py
@@ -74,6 +74,20 @@ def testAttrEqDoesNotRaise():
 run(testAttrEqDoesNotRaise)
 
 
+# CHECK-LABEL: TEST: testAttrCapsule
+def testAttrCapsule():
+  with Context() as ctx:
+    a1 = Attribute.parse('"attr1"')
+  # CHECK: mlir.ir.Attribute._CAPIPtr
+  attr_capsule = a1._CAPIPtr
+  print(attr_capsule)
+  a2 = Attribute._CAPICreate(attr_capsule)
+  assert a2 == a1
+  assert a2.context is ctx
+
+run(testAttrCapsule)
+
+
 # CHECK-LABEL: TEST: testStandardAttrCasts
 def testStandardAttrCasts():
   with Context():

diff  --git a/mlir/test/Bindings/Python/ir_location.py b/mlir/test/Bindings/Python/ir_location.py
index e0d1bf299f5b..42a96a1ba604 100644
--- a/mlir/test/Bindings/Python/ir_location.py
+++ b/mlir/test/Bindings/Python/ir_location.py
@@ -38,3 +38,16 @@ def testFileLineCol():
 
 run(testFileLineCol)
 
+
+# CHECK-LABEL: TEST: testLocationCapsule
+def testLocationCapsule():
+  with Context() as ctx:
+    loc1 = Location.file("foo.txt", 123, 56)
+  # CHECK: mlir.ir.Location._CAPIPtr
+  loc_capsule = loc1._CAPIPtr
+  print(loc_capsule)
+  loc2 = Location._CAPICreate(loc_capsule)
+  assert loc2 == loc1
+  assert loc2.context is ctx
+
+run(testLocationCapsule)

diff  --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
index ff058cb3bf93..2f37d08aee32 100644
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -74,6 +74,20 @@ def testTypeEqDoesNotRaise():
 run(testTypeEqDoesNotRaise)
 
 
+# CHECK-LABEL: TEST: testTypeCapsule
+def testTypeCapsule():
+  with Context() as ctx:
+    t1 = Type.parse("i32", ctx)
+  # CHECK: mlir.ir.Type._CAPIPtr
+  type_capsule = t1._CAPIPtr
+  print(type_capsule)
+  t2 = Type._CAPICreate(type_capsule)
+  assert t2 == t1
+  assert t2.context is ctx
+
+run(testTypeCapsule)
+
+
 # CHECK-LABEL: TEST: testStandardTypeCasts
 def testStandardTypeCasts():
   ctx = Context()


        


More information about the Mlir-commits mailing list