[Mlir-commits] [mlir] d39a784 - [MLIR][python bindings] Expose TypeIDs in python

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 22 11:21:37 PDT 2023


Author: max
Date: 2023-05-22T13:19:54-05:00
New Revision: d39a7844028bcdd28f72b0e69becc9c49b8fd283

URL: https://github.com/llvm/llvm-project/commit/d39a7844028bcdd28f72b0e69becc9c49b8fd283
DIFF: https://github.com/llvm/llvm-project/commit/d39a7844028bcdd28f72b0e69becc9c49b8fd283.diff

LOG: [MLIR][python bindings] Expose TypeIDs in python

This diff adds Python bindings for `MlirTypeID`. It paves the way for returning accurately typed `Type`s from Python APIs (see D150927) and, further along, for building type-"conscious" `Value` APIs (see D150413).

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D150839

Added: 
    

Modified: 
    mlir/include/mlir-c/Bindings/Python/Interop.h
    mlir/include/mlir-c/BuiltinTypes.h
    mlir/include/mlir/Bindings/Python/PybindAdaptors.h
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.h
    mlir/lib/Bindings/Python/IRTypes.cpp
    mlir/lib/CAPI/IR/BuiltinTypes.cpp
    mlir/python/mlir/dialects/python_test.py
    mlir/test/python/dialects/python_test.py
    mlir/test/python/ir/builtin_types.py
    mlir/test/python/lib/PythonTestModule.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
index b877f94aa48d5..6ebb458082d7c 100644
--- a/mlir/include/mlir-c/Bindings/Python/Interop.h
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -80,6 +80,8 @@
 #define MLIR_PYTHON_CAPSULE_PASS_MANAGER                                       \
   MAKE_MLIR_PYTHON_QUALNAME("passmanager.PassManager._CAPIPtr")
 #define MLIR_PYTHON_CAPSULE_VALUE MAKE_MLIR_PYTHON_QUALNAME("ir.Value._CAPIPtr")
+#define MLIR_PYTHON_CAPSULE_TYPEID                                             \
+  MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID._CAPIPtr")
 
 /** Attribute on MLIR Python objects that expose their C-API pointer.
  * This will be a type-specific capsule created as per one of the helpers
@@ -268,6 +270,25 @@ static inline MlirOperation mlirPythonCapsuleToOperation(PyObject *capsule) {
   return op;
 }
 
+/** Creates a capsule object encapsulating the raw C-API MlirTypeID.
+ * The returned capsule does not extend or affect ownership of any Python
+ * objects that reference the type ID in any way.
+ */
+static inline PyObject *mlirPythonTypeIDToCapsule(MlirTypeID typeID) {
+  return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(typeID),
+                       MLIR_PYTHON_CAPSULE_TYPEID, NULL);
+}
+
+/** Extracts an MlirTypeID from a capsule as produced from
+ * mlirPythonTypeIDToCapsule. If the capsule is not of the right type, then
+ * a null type ID is returned (as checked via mlirTypeIDIsNull). In such a
+ * case, the Python APIs will have already set an error. */
+static inline MlirTypeID mlirPythonCapsuleToTypeID(PyObject *capsule) {
+  void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_TYPEID);
+  MlirTypeID typeID = {ptr};
+  return typeID;
+}
+
 /** Creates a capsule object encapsulating the raw C-API MlirType.
  * The returned capsule does not extend or affect ownership of any Python
  * objects that reference the type in any way.

diff  --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 2b7606f3d9caf..4348c5ba167f9 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -22,6 +22,9 @@ extern "C" {
 // Integer types.
 //===----------------------------------------------------------------------===//
 
+/// Returns the typeID of an Integer type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerTypeGetTypeID(void);
+
 /// Checks whether the given type is an integer type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAInteger(MlirType type);
 
@@ -56,6 +59,9 @@ MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsUnsigned(MlirType type);
 // Index type.
 //===----------------------------------------------------------------------===//
 
+/// Returns the typeID of an Index type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirIndexTypeGetTypeID(void);
+
 /// Checks whether the given type is an index type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAIndex(MlirType type);
 
@@ -67,6 +73,9 @@ MLIR_CAPI_EXPORTED MlirType mlirIndexTypeGet(MlirContext ctx);
 // Floating-point types.
 //===----------------------------------------------------------------------===//
 
+/// Returns the typeID of a Float8E5M2 type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2TypeGetTypeID(void);
+
 /// Checks whether the given type is an f8E5M2 type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type);
 
@@ -74,6 +83,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type);
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx);
 
+/// Returns the typeID of a Float8E4M3FN type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNTypeGetTypeID(void);
+
 /// Checks whether the given type is an f8E4M3FN type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type);
 
@@ -81,6 +93,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type);
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx);
 
+/// Returns the typeID of a Float8E5M2FNUZ type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID(void);
+
 /// Checks whether the given type is an f8E5M2FNUZ type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type);
 
@@ -88,6 +103,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type);
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx);
 
+/// Returns the typeID of a Float8E4M3FNUZ type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID(void);
+
 /// Checks whether the given type is an f8E4M3FNUZ type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type);
 
@@ -95,6 +113,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type);
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx);
 
+/// Returns the typeID of a Float8E4M3B11FNUZ type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID(void);
+
 /// Checks whether the given type is an f8E4M3B11FNUZ type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type);
 
@@ -102,6 +123,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type);
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx);
 
+/// Returns the typeID of a BFloat16 type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirBFloat16TypeGetTypeID(void);
+
 /// Checks whether the given type is a bf16 type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type);
 
@@ -109,6 +133,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type);
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirBF16TypeGet(MlirContext ctx);
 
+/// Returns the typeID of a Float16 type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat16TypeGetTypeID(void);
+
 /// Checks whether the given type is an f16 type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAF16(MlirType type);
 
@@ -116,6 +143,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAF16(MlirType type);
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirF16TypeGet(MlirContext ctx);
 
+/// Returns the typeID of a Float32 type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat32TypeGetTypeID(void);
+
 /// Checks whether the given type is an f32 type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAF32(MlirType type);
 
@@ -123,6 +153,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAF32(MlirType type);
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirF32TypeGet(MlirContext ctx);
 
+/// Returns the typeID of a Float64 type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat64TypeGetTypeID(void);
+
 /// Checks whether the given type is an f64 type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAF64(MlirType type);
 
@@ -134,6 +167,9 @@ MLIR_CAPI_EXPORTED MlirType mlirF64TypeGet(MlirContext ctx);
 // None type.
 //===----------------------------------------------------------------------===//
 
+/// Returns the typeID of a None type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirNoneTypeGetTypeID(void);
+
 /// Checks whether the given type is a None type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsANone(MlirType type);
 
@@ -145,6 +181,9 @@ MLIR_CAPI_EXPORTED MlirType mlirNoneTypeGet(MlirContext ctx);
 // Complex type.
 //===----------------------------------------------------------------------===//
 
+/// Returns the typeID of a Complex type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirComplexTypeGetTypeID(void);
+
 /// Checks whether the given type is a Complex type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAComplex(MlirType type);
 
@@ -159,6 +198,9 @@ MLIR_CAPI_EXPORTED MlirType mlirComplexTypeGetElementType(MlirType type);
 // Shaped type.
 //===----------------------------------------------------------------------===//
 
+/// Returns the typeID of a Shaped type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirShapedTypeGetTypeID(void);
+
 /// Checks whether the given type is a Shaped type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAShaped(MlirType type);
 
@@ -202,6 +244,9 @@ MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset(void);
 // Vector type.
 //===----------------------------------------------------------------------===//
 
+/// Returns the typeID of a Vector type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirVectorTypeGetTypeID(void);
+
 /// Checks whether the given type is a Vector type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAVector(MlirType type);
 
@@ -226,9 +271,15 @@ MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc,
 /// Checks whether the given type is a Tensor type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsATensor(MlirType type);
 
+/// Returns the typeID of a RankedTensor type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirRankedTensorTypeGetTypeID(void);
+
 /// Checks whether the given type is a ranked tensor type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsARankedTensor(MlirType type);
 
+/// Returns the typeID of an UnrankedTensor type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedTensorTypeGetTypeID(void);
+
 /// Checks whether the given type is an unranked tensor type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedTensor(MlirType type);
 
@@ -264,9 +315,15 @@ mlirUnrankedTensorTypeGetChecked(MlirLocation loc, MlirType elementType);
 // Ranked / Unranked MemRef type.
 //===----------------------------------------------------------------------===//
 
+/// Returns the typeID of a MemRef type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirMemRefTypeGetTypeID(void);
+
 /// Checks whether the given type is a MemRef type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAMemRef(MlirType type);
 
+/// Returns the typeID of an UnrankedMemRef type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedMemRefTypeGetTypeID(void);
+
 /// Checks whether the given type is an UnrankedMemRef type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedMemRef(MlirType type);
 
@@ -326,6 +383,9 @@ mlirUnrankedMemrefGetMemorySpace(MlirType type);
 // Tuple type.
 //===----------------------------------------------------------------------===//
 
+/// Returns the typeID of a Tuple type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirTupleTypeGetTypeID(void);
+
 /// Checks whether the given type is a tuple type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsATuple(MlirType type);
 
@@ -345,6 +405,9 @@ MLIR_CAPI_EXPORTED MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos);
 // Function type.
 //===----------------------------------------------------------------------===//
 
+/// Returns the typeID of a Function type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFunctionTypeGetTypeID(void);
+
 /// Checks whether the given type is a function type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAFunction(MlirType type);
 
@@ -373,6 +436,9 @@ MLIR_CAPI_EXPORTED MlirType mlirFunctionTypeGetResult(MlirType type,
 // Opaque type.
 //===----------------------------------------------------------------------===//
 
+/// Returns the typeID of an Opaque type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirOpaqueTypeGetTypeID(void);
+
 /// Checks whether the given type is an opaque type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAOpaque(MlirType type);
 

diff  --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index bec3fc76e39d2..ccca3aa0172e6 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -236,6 +236,27 @@ struct type_caster<MlirPassManager> {
   }
 };
 
+/// Casts object <-> MlirTypeID.
+template <>
+struct type_caster<MlirTypeID> {
+  PYBIND11_TYPE_CASTER(MlirTypeID, _("MlirTypeID"));
+  bool load(handle src, bool) {
+    py::object capsule = mlirApiObjectToCapsule(src);
+    value = mlirPythonCapsuleToTypeID(capsule.ptr());
+    return !mlirTypeIDIsNull(value);
+  }
+  static handle cast(MlirTypeID v, return_value_policy, handle) {
+    if (v.ptr == nullptr)
+      return py::none();
+    py::object capsule =
+        py::reinterpret_steal<py::object>(mlirPythonTypeIDToCapsule(v));
+    return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+        .attr("TypeID")
+        .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+        .release();
+  };
+};
+
 /// Casts object <-> MlirType.
 template <>
 struct type_caster<MlirType> {

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 7ffa464009fc8..db8390abee925 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -17,6 +17,7 @@
 #include "mlir-c/Diagnostics.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/PybindAdaptors.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
 
@@ -1807,6 +1808,24 @@ PyType PyType::createFromCapsule(py::object capsule) {
                 rawType);
 }
 
+//------------------------------------------------------------------------------
+// PyTypeID.
+//------------------------------------------------------------------------------
+
+py::object PyTypeID::getCapsule() {
+  return py::reinterpret_steal<py::object>(mlirPythonTypeIDToCapsule(*this));
+}
+
+PyTypeID PyTypeID::createFromCapsule(py::object capsule) {
+  MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr());
+  if (mlirTypeIDIsNull(mlirTypeID))
+    throw py::error_already_set();
+  return PyTypeID(mlirTypeID);
+}
+bool PyTypeID::operator==(const PyTypeID &other) const {
+  return mlirTypeIDEqual(typeID, other.typeID);
+}
+
 //------------------------------------------------------------------------------
 // PyValue and subclases.
 //------------------------------------------------------------------------------
@@ -3268,16 +3287,47 @@ void mlir::python::populateIRCore(py::module &m) {
             return printAccum.join();
           },
           "Returns the assembly form of the type.")
-      .def("__repr__", [](PyType &self) {
-        // Generally, assembly formats are not printed for __repr__ because
-        // this can cause exceptionally long debug output and exceptions.
-        // However, types are an exception as they typically have compact
-        // assembly forms and printing them is useful.
-        PyPrintAccumulator printAccum;
-        printAccum.parts.append("Type(");
-        mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
-        printAccum.parts.append(")");
-        return printAccum.join();
+      .def("__repr__",
+           [](PyType &self) {
+             // Generally, assembly formats are not printed for __repr__ because
+             // this can cause exceptionally long debug output and exceptions.
+             // However, types are an exception as they typically have compact
+             // assembly forms and printing them is useful.
+             PyPrintAccumulator printAccum;
+             printAccum.parts.append("Type(");
+             mlirTypePrint(self, printAccum.getCallback(),
+                           printAccum.getUserData());
+             printAccum.parts.append(")");
+             return printAccum.join();
+           })
+      .def_property_readonly("typeid", [](PyType &self) -> MlirTypeID {
+        MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
+        if (!mlirTypeIDIsNull(mlirTypeID))
+          return mlirTypeID;
+        auto origRepr =
+            pybind11::repr(pybind11::cast(self)).cast<std::string>();
+        throw py::value_error(
+            (origRepr + llvm::Twine(" has no typeid.")).str());
+      });
+
+  //----------------------------------------------------------------------------
+  // Mapping of PyTypeID.
+  //----------------------------------------------------------------------------
+  py::class_<PyTypeID>(m, "TypeID", py::module_local())
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule)
+      // Note, this tests whether the underlying TypeIDs are the same,
+      // not whether the wrapper MlirTypeIDs are the same, nor whether
+      // the Python objects are the same (i.e., PyTypeID is a value type).
+      .def("__eq__",
+           [](PyTypeID &self, PyTypeID &other) { return self == other; })
+      .def("__eq__",
+           [](PyTypeID &self, const py::object &other) { return false; })
+      // Note, this gives the hash value of the underlying TypeID, not the
+      // hash value of the Python object, nor the hash value of the
+      // MlirTypeID wrapper.
+      .def("__hash__", [](PyTypeID &self) {
+        return static_cast<size_t>(mlirTypeIDHashValue(self));
       });
 
   //----------------------------------------------------------------------------

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index ade790ba0ed13..fa529c43444d3 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -20,6 +20,7 @@
 #include "mlir-c/Diagnostics.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/IntegerSet.h"
+#include "mlir/Bindings/Python/PybindAdaptors.h"
 #include "llvm/ADT/DenseMap.h"
 
 namespace mlir {
@@ -826,6 +827,29 @@ class PyType : public BaseContextObject {
   MlirType type;
 };
 
+/// A TypeID provides an efficient and unique identifier for a specific C++
+/// type. This allows for a C++ type to be compared, hashed, and stored in an
+/// opaque context. This class wraps around the generic MlirTypeID.
+class PyTypeID {
+public:
+  PyTypeID(MlirTypeID typeID) : typeID(typeID) {}
+  // Note, this tests whether the underlying TypeIDs are the same,
+  // not whether the wrapper MlirTypeIDs are the same, nor whether
+  // the PyTypeID objects are the same (i.e., PyTypeID is a value type).
+  bool operator==(const PyTypeID &other) const;
+  operator MlirTypeID() const { return typeID; }
+  MlirTypeID get() { return typeID; }
+
+  /// Gets a capsule wrapping the void* within the MlirTypeID.
+  pybind11::object getCapsule();
+
+  /// Creates a PyTypeID from the MlirTypeID wrapped by a capsule.
+  static PyTypeID createFromCapsule(pybind11::object capsule);
+
+private:
+  MlirTypeID typeID;
+};
+
 /// CRTP base classes for Python types that subclass Type and should be
 /// castable from it (i.e. via something like IntegerType(t)).
 /// By default, type class hierarchies are one level deep (i.e. a
@@ -839,10 +863,14 @@ class PyConcreteType : public BaseTy {
   //   const char *pyClassName
   using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
   using IsAFunctionTy = bool (*)(MlirType);
+  using GetTypeIDFunctionTy = MlirTypeID (*)();
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr;
 
   PyConcreteType() = default;
   PyConcreteType(PyMlirContextRef contextRef, MlirType t)
-      : BaseTy(std::move(contextRef), t) {}
+      : BaseTy(std::move(contextRef), t) {
+    pybind11::implicitly_convertible<PyType, DerivedTy>();
+  }
   PyConcreteType(PyType &orig)
       : PyConcreteType(orig.getContext(), castFrom(orig)) {}
 
@@ -866,6 +894,26 @@ class PyConcreteType : public BaseTy {
           return DerivedTy::isaFunction(otherType);
         },
         pybind11::arg("other"));
+    cls.def_property_readonly_static(
+        "static_typeid", [](py::object & /*class*/) -> MlirTypeID {
+          if (DerivedTy::getTypeIdFunction)
+            return DerivedTy::getTypeIdFunction();
+          throw SetPyError(PyExc_AttributeError,
+                           DerivedTy::pyClassName +
+                               llvm::Twine(" has no typeid."));
+        });
+    cls.def_property_readonly("typeid", [](PyType &self) {
+      return py::cast(self).attr("typeid").cast<MlirTypeID>();
+    });
+    cls.def("__repr__", [](DerivedTy &self) {
+      PyPrintAccumulator printAccum;
+      printAccum.parts.append(DerivedTy::pyClassName);
+      printAccum.parts.append("(");
+      mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
+      printAccum.parts.append(")");
+      return printAccum.join();
+    });
+
     DerivedTy::bindDerived(cls);
   }
 

diff  --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index cb62a402dc671..f45b30250c174 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -32,6 +32,8 @@ static int mlirTypeIsAIntegerOrFloat(MlirType type) {
 class PyIntegerType : public PyConcreteType<PyIntegerType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirIntegerTypeGetTypeID;
   static constexpr const char *pyClassName = "IntegerType";
   using PyConcreteType::PyConcreteType;
 
@@ -89,6 +91,8 @@ class PyIntegerType : public PyConcreteType<PyIntegerType> {
 class PyIndexType : public PyConcreteType<PyIndexType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirIndexTypeGetTypeID;
   static constexpr const char *pyClassName = "IndexType";
   using PyConcreteType::PyConcreteType;
 
@@ -107,6 +111,8 @@ class PyIndexType : public PyConcreteType<PyIndexType> {
 class PyFloat8E4M3FNType : public PyConcreteType<PyFloat8E4M3FNType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E4M3FNTypeGetTypeID;
   static constexpr const char *pyClassName = "Float8E4M3FNType";
   using PyConcreteType::PyConcreteType;
 
@@ -125,6 +131,8 @@ class PyFloat8E4M3FNType : public PyConcreteType<PyFloat8E4M3FNType> {
 class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E5M2TypeGetTypeID;
   static constexpr const char *pyClassName = "Float8E5M2Type";
   using PyConcreteType::PyConcreteType;
 
@@ -143,6 +151,8 @@ class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type> {
 class PyFloat8E4M3FNUZType : public PyConcreteType<PyFloat8E4M3FNUZType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E4M3FNUZTypeGetTypeID;
   static constexpr const char *pyClassName = "Float8E4M3FNUZType";
   using PyConcreteType::PyConcreteType;
 
@@ -161,6 +171,8 @@ class PyFloat8E4M3FNUZType : public PyConcreteType<PyFloat8E4M3FNUZType> {
 class PyFloat8E4M3B11FNUZType : public PyConcreteType<PyFloat8E4M3B11FNUZType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E4M3B11FNUZTypeGetTypeID;
   static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
   using PyConcreteType::PyConcreteType;
 
@@ -179,6 +191,8 @@ class PyFloat8E4M3B11FNUZType : public PyConcreteType<PyFloat8E4M3B11FNUZType> {
 class PyFloat8E5M2FNUZType : public PyConcreteType<PyFloat8E5M2FNUZType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E5M2FNUZTypeGetTypeID;
   static constexpr const char *pyClassName = "Float8E5M2FNUZType";
   using PyConcreteType::PyConcreteType;
 
@@ -197,6 +211,8 @@ class PyFloat8E5M2FNUZType : public PyConcreteType<PyFloat8E5M2FNUZType> {
 class PyBF16Type : public PyConcreteType<PyBF16Type> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirBFloat16TypeGetTypeID;
   static constexpr const char *pyClassName = "BF16Type";
   using PyConcreteType::PyConcreteType;
 
@@ -215,6 +231,8 @@ class PyBF16Type : public PyConcreteType<PyBF16Type> {
 class PyF16Type : public PyConcreteType<PyF16Type> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat16TypeGetTypeID;
   static constexpr const char *pyClassName = "F16Type";
   using PyConcreteType::PyConcreteType;
 
@@ -233,6 +251,8 @@ class PyF16Type : public PyConcreteType<PyF16Type> {
 class PyF32Type : public PyConcreteType<PyF32Type> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat32TypeGetTypeID;
   static constexpr const char *pyClassName = "F32Type";
   using PyConcreteType::PyConcreteType;
 
@@ -251,6 +271,8 @@ class PyF32Type : public PyConcreteType<PyF32Type> {
 class PyF64Type : public PyConcreteType<PyF64Type> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat64TypeGetTypeID;
   static constexpr const char *pyClassName = "F64Type";
   using PyConcreteType::PyConcreteType;
 
@@ -269,6 +291,8 @@ class PyF64Type : public PyConcreteType<PyF64Type> {
 class PyNoneType : public PyConcreteType<PyNoneType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirNoneTypeGetTypeID;
   static constexpr const char *pyClassName = "NoneType";
   using PyConcreteType::PyConcreteType;
 
@@ -287,6 +311,8 @@ class PyNoneType : public PyConcreteType<PyNoneType> {
 class PyComplexType : public PyConcreteType<PyComplexType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirComplexTypeGetTypeID;
   static constexpr const char *pyClassName = "ComplexType";
   using PyConcreteType::PyConcreteType;
 
@@ -417,6 +443,8 @@ class PyShapedType : public PyConcreteType<PyShapedType> {
 class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirVectorTypeGetTypeID;
   static constexpr const char *pyClassName = "VectorType";
   using PyConcreteType::PyConcreteType;
 
@@ -442,6 +470,8 @@ class PyRankedTensorType
     : public PyConcreteType<PyRankedTensorType, PyShapedType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirRankedTensorTypeGetTypeID;
   static constexpr const char *pyClassName = "RankedTensorType";
   using PyConcreteType::PyConcreteType;
 
@@ -476,6 +506,8 @@ class PyUnrankedTensorType
     : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirUnrankedTensorTypeGetTypeID;
   static constexpr const char *pyClassName = "UnrankedTensorType";
   using PyConcreteType::PyConcreteType;
 
@@ -498,6 +530,8 @@ class PyUnrankedTensorType
 class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirMemRefTypeGetTypeID;
   static constexpr const char *pyClassName = "MemRefType";
   using PyConcreteType::PyConcreteType;
 
@@ -550,6 +584,8 @@ class PyUnrankedMemRefType
     : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirUnrankedMemRefTypeGetTypeID;
   static constexpr const char *pyClassName = "UnrankedMemRefType";
   using PyConcreteType::PyConcreteType;
 
@@ -585,6 +621,8 @@ class PyUnrankedMemRefType
 class PyTupleType : public PyConcreteType<PyTupleType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirTupleTypeGetTypeID;
   static constexpr const char *pyClassName = "TupleType";
   using PyConcreteType::PyConcreteType;
 
@@ -622,6 +660,8 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
 class PyFunctionType : public PyConcreteType<PyFunctionType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFunctionTypeGetTypeID;
   static constexpr const char *pyClassName = "FunctionType";
   using PyConcreteType::PyConcreteType;
 
@@ -676,6 +716,8 @@ static MlirStringRef toMlirStringRef(const std::string &s) {
 class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirOpaqueTypeGetTypeID;
   static constexpr const char *pyClassName = "OpaqueType";
   using PyConcreteType::PyConcreteType;
 

diff  --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 90ab847606ee0..1925478c66d41 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -22,6 +22,8 @@ using namespace mlir;
 // Integer types.
 //===----------------------------------------------------------------------===//
 
+MlirTypeID mlirIntegerTypeGetTypeID() { return wrap(IntegerType::getTypeID()); }
+
 bool mlirTypeIsAInteger(MlirType type) {
   return llvm::isa<IntegerType>(unwrap(type));
 }
@@ -58,6 +60,8 @@ bool mlirIntegerTypeIsUnsigned(MlirType type) {
 // Index type.
 //===----------------------------------------------------------------------===//
 
+MlirTypeID mlirIndexTypeGetTypeID() { return wrap(IndexType::getTypeID()); }
+
 bool mlirTypeIsAIndex(MlirType type) {
   return llvm::isa<IndexType>(unwrap(type));
 }
@@ -70,6 +74,10 @@ MlirType mlirIndexTypeGet(MlirContext ctx) {
 // Floating-point types.
 //===----------------------------------------------------------------------===//
 
+MlirTypeID mlirFloat8E5M2TypeGetTypeID() {
+  return wrap(Float8E5M2Type::getTypeID());
+}
+
 bool mlirTypeIsAFloat8E5M2(MlirType type) {
   return unwrap(type).isFloat8E5M2();
 }
@@ -78,6 +86,10 @@ MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) {
   return wrap(FloatType::getFloat8E5M2(unwrap(ctx)));
 }
 
+MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() {
+  return wrap(Float8E4M3FNType::getTypeID());
+}
+
 bool mlirTypeIsAFloat8E4M3FN(MlirType type) {
   return unwrap(type).isFloat8E4M3FN();
 }
@@ -86,6 +98,10 @@ MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) {
   return wrap(FloatType::getFloat8E4M3FN(unwrap(ctx)));
 }
 
+MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID() {
+  return wrap(Float8E5M2FNUZType::getTypeID());
+}
+
 bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) {
   return unwrap(type).isFloat8E5M2FNUZ();
 }
@@ -94,6 +110,10 @@ MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) {
   return wrap(FloatType::getFloat8E5M2FNUZ(unwrap(ctx)));
 }
 
+MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID() {
+  return wrap(Float8E4M3FNUZType::getTypeID());
+}
+
 bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) {
   return unwrap(type).isFloat8E4M3FNUZ();
 }
@@ -102,6 +122,10 @@ MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) {
   return wrap(FloatType::getFloat8E4M3FNUZ(unwrap(ctx)));
 }
 
+MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID() {
+  return wrap(Float8E4M3B11FNUZType::getTypeID());
+}
+
 bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) {
   return unwrap(type).isFloat8E4M3B11FNUZ();
 }
@@ -110,24 +134,34 @@ MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) {
   return wrap(FloatType::getFloat8E4M3B11FNUZ(unwrap(ctx)));
 }
 
+MlirTypeID mlirBFloat16TypeGetTypeID() {
+  return wrap(BFloat16Type::getTypeID());
+}
+
 bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
 
 MlirType mlirBF16TypeGet(MlirContext ctx) {
   return wrap(FloatType::getBF16(unwrap(ctx)));
 }
 
+MlirTypeID mlirFloat16TypeGetTypeID() { return wrap(Float16Type::getTypeID()); }
+
 bool mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); }
 
 MlirType mlirF16TypeGet(MlirContext ctx) {
   return wrap(FloatType::getF16(unwrap(ctx)));
 }
 
+MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); }
+
 bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); }
 
 MlirType mlirF32TypeGet(MlirContext ctx) {
   return wrap(FloatType::getF32(unwrap(ctx)));
 }
 
+MlirTypeID mlirFloat64TypeGetTypeID() { return wrap(Float64Type::getTypeID()); }
+
 bool mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); }
 
 MlirType mlirF64TypeGet(MlirContext ctx) {
@@ -138,6 +172,8 @@ MlirType mlirF64TypeGet(MlirContext ctx) {
 // None type.
 //===----------------------------------------------------------------------===//
 
+MlirTypeID mlirNoneTypeGetTypeID() { return wrap(NoneType::getTypeID()); }
+
 bool mlirTypeIsANone(MlirType type) {
   return llvm::isa<NoneType>(unwrap(type));
 }
@@ -150,6 +186,8 @@ MlirType mlirNoneTypeGet(MlirContext ctx) {
 // Complex type.
 //===----------------------------------------------------------------------===//
 
+MlirTypeID mlirComplexTypeGetTypeID() { return wrap(ComplexType::getTypeID()); }
+
 bool mlirTypeIsAComplex(MlirType type) {
   return llvm::isa<ComplexType>(unwrap(type));
 }
@@ -214,6 +252,8 @@ int64_t mlirShapedTypeGetDynamicStrideOrOffset() {
 // Vector type.
 //===----------------------------------------------------------------------===//
 
+MlirTypeID mlirVectorTypeGetTypeID() { return wrap(VectorType::getTypeID()); }
+
 bool mlirTypeIsAVector(MlirType type) {
   return llvm::isa<VectorType>(unwrap(type));
 }
@@ -239,10 +279,18 @@ bool mlirTypeIsATensor(MlirType type) {
   return llvm::isa<TensorType>(unwrap(type));
 }
 
+MlirTypeID mlirRankedTensorTypeGetTypeID() {
+  return wrap(RankedTensorType::getTypeID());
+}
+
 bool mlirTypeIsARankedTensor(MlirType type) {
   return llvm::isa<RankedTensorType>(unwrap(type));
 }
 
+MlirTypeID mlirUnrankedTensorTypeGetTypeID() {
+  return wrap(UnrankedTensorType::getTypeID());
+}
+
 bool mlirTypeIsAUnrankedTensor(MlirType type) {
   return llvm::isa<UnrankedTensorType>(unwrap(type));
 }
@@ -280,6 +328,8 @@ MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc,
 // Ranked / Unranked MemRef type.
 //===----------------------------------------------------------------------===//
 
+MlirTypeID mlirMemRefTypeGetTypeID() { return wrap(MemRefType::getTypeID()); }
+
 bool mlirTypeIsAMemRef(MlirType type) {
   return llvm::isa<MemRefType>(unwrap(type));
 }
@@ -337,6 +387,10 @@ MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
   return wrap(llvm::cast<MemRefType>(unwrap(type)).getMemorySpace());
 }
 
+MlirTypeID mlirUnrankedMemRefTypeGetTypeID() {
+  return wrap(UnrankedMemRefType::getTypeID());
+}
+
 bool mlirTypeIsAUnrankedMemRef(MlirType type) {
   return llvm::isa<UnrankedMemRefType>(unwrap(type));
 }
@@ -362,6 +416,8 @@ MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type) {
 // Tuple type.
 //===----------------------------------------------------------------------===//
 
+MlirTypeID mlirTupleTypeGetTypeID() { return wrap(TupleType::getTypeID()); }
+
 bool mlirTypeIsATuple(MlirType type) {
   return llvm::isa<TupleType>(unwrap(type));
 }
@@ -386,6 +442,10 @@ MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) {
 // Function type.
 //===----------------------------------------------------------------------===//
 
+MlirTypeID mlirFunctionTypeGetTypeID() {
+  return wrap(FunctionType::getTypeID());
+}
+
 bool mlirTypeIsAFunction(MlirType type) {
   return llvm::isa<FunctionType>(unwrap(type));
 }
@@ -424,6 +484,8 @@ MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos) {
 // Opaque type.
 //===----------------------------------------------------------------------===//
 
+MlirTypeID mlirOpaqueTypeGetTypeID() { return wrap(OpaqueType::getTypeID()); }
+
 bool mlirTypeIsAOpaque(MlirType type) {
   return llvm::isa<OpaqueType>(unwrap(type));
 }

diff  --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
index 5d42ddc47a242..ca0d479f1f5fc 100644
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -3,7 +3,7 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._python_test_ops_gen import *
-from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue
+from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestTensorType
 
 def register_python_test_dialect(context, load=True):
   from .._mlir_libs import _mlirPythonTest

diff  --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 6cde96e1da10d..2ca79b29f567c 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -299,6 +299,14 @@ def testCustomType():
 
     # The following cast must not assert.
     b = test.TestType(a)
+    # Instance custom types should have typeids
+    assert isinstance(b.typeid, TypeID)
+    # Subclasses of ir.Type should not have a static_typeid
+    # CHECK: 'TestType' object has no attribute 'static_typeid'
+    try:
+      b.static_typeid
+    except AttributeError as e:
+      print(e)
 
     i8 = IntegerType.get_signless(8)
     try:
@@ -353,6 +361,12 @@ def __str__(self):
       # CHECK: False
       print(tt.is_null())
 
+      # Classes of custom types that inherit from concrete types should have
+      # static_typeid
+      assert isinstance(test.TestTensorType.static_typeid, TypeID)
+      # And it should be equal to the in-tree concrete type
+      assert test.TestTensorType.static_typeid == t.type.typeid
+
 
 # CHECK-LABEL: TEST: inferReturnTypeComponents
 @run

diff  --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index e383a78f40b8a..19e21fff8dba0 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -3,6 +3,7 @@
 import gc
 from mlir.ir import *
 
+
 def run(f):
   print("\nTEST:", f.__name__)
   f()
@@ -76,6 +77,7 @@ def testTypeHash():
   # CHECK: len(s): 2
   print("len(s): ", len(s))
 
+
 # CHECK-LABEL: TEST: testTypeCast
 @run
 def testTypeCast():
@@ -182,6 +184,7 @@ def testIntegerType():
     # CHECK: unsigned: ui64
     print("unsigned:", IntegerType.get_unsigned(64))
 
+
 # CHECK-LABEL: TEST: testIndexType
 @run
 def testIndexType():
@@ -259,7 +262,8 @@ def testConcreteShapedType():
     # CHECK: rank: 2
     print("rank:", vector.rank)
     # CHECK: whether the shaped type has a static shape: True
-    print("whether the shaped type has a static shape:", vector.has_static_shape)
+    print("whether the shaped type has a static shape:",
+          vector.has_static_shape)
     # CHECK: whether the dim-th dimension is dynamic: False
     print("whether the dim-th dimension is dynamic:", vector.is_dynamic_dim(0))
     # CHECK: dim size: 3
@@ -311,8 +315,7 @@ def testRankedTensorType():
     shape = [2, 3]
     loc = Location.unknown()
     # CHECK: ranked tensor type: tensor<2x3xf32>
-    print("ranked tensor type:",
-          RankedTensorType.get(shape, f32))
+    print("ranked tensor type:", RankedTensorType.get(shape, f32))
 
     none = NoneType.get()
     try:
@@ -477,8 +480,7 @@ def testTupleType():
 @run
 def testFunctionType():
   with Context() as ctx:
-    input_types = [IntegerType.get_signless(32),
-                  IntegerType.get_signless(16)]
+    input_types = [IntegerType.get_signless(32), IntegerType.get_signless(16)]
     result_types = [IndexType.get()]
     func = FunctionType.get(input_types, result_types)
     # CHECK: INPUTS: [Type(i32), Type(i16)]
@@ -509,3 +511,91 @@ def testShapedTypeConstants():
   print(type(ShapedType.get_dynamic_size()))
   # CHECK: <class 'int'>
   print(type(ShapedType.get_dynamic_stride_or_offset()))
+
+
+# CHECK-LABEL: TEST: testTypeIDs
+@run
+def testTypeIDs():
+  with Context(), Location.unknown():
+    f32 = F32Type.get()
+
+    types = [
+      (IntegerType, IntegerType.get_signless(16)),
+      (IndexType, IndexType.get()),
+      (Float8E4M3FNType, Float8E4M3FNType.get()),
+      (Float8E5M2Type, Float8E5M2Type.get()),
+      (Float8E4M3FNUZType, Float8E4M3FNUZType.get()),
+      (Float8E4M3B11FNUZType, Float8E4M3B11FNUZType.get()),
+      (Float8E5M2FNUZType, Float8E5M2FNUZType.get()),
+      (BF16Type, BF16Type.get()),
+      (F16Type, F16Type.get()),
+      (F32Type, F32Type.get()),
+      (F64Type, F64Type.get()),
+      (NoneType, NoneType.get()),
+      (ComplexType, ComplexType.get(f32)),
+      (VectorType, VectorType.get([2, 3], f32)),
+      (RankedTensorType, RankedTensorType.get([2, 3], f32)),
+      (UnrankedTensorType, UnrankedTensorType.get(f32)),
+      (MemRefType, MemRefType.get([2, 3], f32)),
+      (UnrankedMemRefType, UnrankedMemRefType.get(f32, Attribute.parse("2"))),
+      (TupleType, TupleType.get_tuple([f32])),
+      (FunctionType, FunctionType.get([], [])),
+      (OpaqueType, OpaqueType.get("tensor", "bob")),
+    ]
+
+    # CHECK: IntegerType(i16)
+    # CHECK: IndexType(index)
+    # CHECK: Float8E4M3FNType(f8E4M3FN)
+    # CHECK: Float8E5M2Type(f8E5M2)
+    # CHECK: Float8E4M3FNUZType(f8E4M3FNUZ)
+    # CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ)
+    # CHECK: Float8E5M2FNUZType(f8E5M2FNUZ)
+    # CHECK: BF16Type(bf16)
+    # CHECK: F16Type(f16)
+    # CHECK: F32Type(f32)
+    # CHECK: F64Type(f64)
+    # CHECK: NoneType(none)
+    # CHECK: ComplexType(complex<f32>)
+    # CHECK: VectorType(vector<2x3xf32>)
+    # CHECK: RankedTensorType(tensor<2x3xf32>)
+    # CHECK: UnrankedTensorType(tensor<*xf32>)
+    # CHECK: MemRefType(memref<2x3xf32>)
+    # CHECK: UnrankedMemRefType(memref<*xf32, 2>)
+    # CHECK: TupleType(tuple<f32>)
+    # CHECK: FunctionType(() -> ())
+    # CHECK: OpaqueType(!tensor.bob)
+    for _, t in types:
+      print(repr(t))
+
+    # Test getTypeIdFunction agrees with
+    # mlirTypeGetTypeID(self) for an instance.
+    # CHECK: all equal
+    for t1, t2 in types:
+      tid1, tid2 = t1.static_typeid, Type(t2).typeid
+      assert tid1 == tid2 and hash(tid1) == hash(
+          tid2), f"expected hash and value equality {t1} {t2}"
+    else:
+      print("all equal")
+
+    # Test that storing PyTypeID in python dicts
+    # works as expected.
+    typeid_dict = dict(types)
+    assert len(typeid_dict)
+
+    # CHECK: all equal
+    for t1, t2 in typeid_dict.items():
+      assert t1.static_typeid == t2.typeid and hash(
+          t1.static_typeid) == hash(
+              t2.typeid), f"expected hash and value equality {t1} {t2}"
+    else:
+      print("all equal")
+
+    # CHECK: ShapedType has no typeid.
+    try:
+      print(ShapedType.static_typeid)
+    except AttributeError as e:
+      print(e)
+
+    vector_type = Type.parse("vector<2x3xf32>")
+    # CHECK: True
+    print(ShapedType(vector_type).typeid == vector_type.typeid)

diff  --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
index f17f0821599c5..7edeaac86e45c 100644
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "PythonTestCAPI.h"
+#include "mlir-c/BuiltinTypes.h"
 #include "mlir/Bindings/Python/PybindAdaptors.h"
 
 namespace py = pybind11;
@@ -40,6 +41,9 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
             return cls(mlirPythonTestTestTypeGet(ctx));
           },
           py::arg("cls"), py::arg("context") = py::none());
+  mlir_type_subclass(m, "TestTensorType", mlirTypeIsARankedTensor,
+                     py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+                         .attr("RankedTensorType"));
   mlir_value_subclass(m, "TestTensorValue",
                       mlirTypeIsAPythonTestTestTensorValue)
       .def("is_null", [](MlirValue &self) { return mlirValueIsNull(self); });


        


More information about the Mlir-commits mailing list