[Mlir-commits] [mlir] bfb1ba7 - [MLIR][python bindings] Add TypeCaster for returning refined types from python APIs

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 26 09:02:13 PDT 2023


Author: max
Date: 2023-05-26T11:02:05-05:00
New Revision: bfb1ba752655bf09b35c486f6cc9817dbedfb1bb

URL: https://github.com/llvm/llvm-project/commit/bfb1ba752655bf09b35c486f6cc9817dbedfb1bb
DIFF: https://github.com/llvm/llvm-project/commit/bfb1ba752655bf09b35c486f6cc9817dbedfb1bb.diff

LOG: [MLIR][python bindings] Add TypeCaster for returning refined types from python APIs

depends on D150839

This diff uses `MlirTypeID` to register type casters (i.e., `[](PyType pyType) -> DerivedTy { return pyType; }`) for all concrete types (i.e., `PyConcreteType<...>`). These casters are looked up by `MlirTypeID` and invoked (through the new `maybe_downcast` method on `ir.Type`) from `struct type_caster<MlirType>::cast`. The result is that anywhere an `MlirType mlirType` is returned from a Python binding, that `mlirType` is automatically cast to the correct concrete type. For example:

```
      c0 = arith.ConstantOp(f32, 0.0)
      # CHECK: F32Type(f32)
      print(repr(c0.result.type))

      unranked_tensor_type = UnrankedTensorType.get(f32)
      unranked_tensor = tensor.FromElementsOp(unranked_tensor_type, [c0]).result

      # CHECK: UnrankedTensorType
      print(type(unranked_tensor.type).__name__)
      # CHECK: UnrankedTensorType(tensor<*xf32>)
      print(repr(unranked_tensor.type))
```

This functionality immediately extends to typed attributes (i.e., `attr.type`).

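The diff also implements similar functionality for `mlir_type_subclass`es, but in a slightly different way. Such types have no corresponding C++ `class` or `struct`. The user must therefore provide a type caster in one of two ways:
- in Python, via `register_type_caster` (similar to how `AttrBuilder` works);
- in C++, as a `py::cpp_function`. `mlir_type_subclass` registers this caster automatically when it is given a `GetTypeID` function.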

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D150927

Added: 
    

Modified: 
    mlir/include/mlir-c/Bindings/Python/Interop.h
    mlir/include/mlir-c/Dialect/Transform.h
    mlir/include/mlir-c/IR.h
    mlir/include/mlir/Bindings/Python/PybindAdaptors.h
    mlir/include/mlir/CAPI/Support.h
    mlir/lib/Bindings/Python/DialectTransform.cpp
    mlir/lib/Bindings/Python/Globals.h
    mlir/lib/Bindings/Python/IRAttributes.cpp
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRInterfaces.cpp
    mlir/lib/Bindings/Python/IRModule.cpp
    mlir/lib/Bindings/Python/IRModule.h
    mlir/lib/Bindings/Python/IRTypes.cpp
    mlir/lib/Bindings/Python/MainModule.cpp
    mlir/lib/CAPI/Dialect/Transform.cpp
    mlir/lib/CAPI/IR/BuiltinTypes.cpp
    mlir/lib/CAPI/IR/IR.cpp
    mlir/lib/CAPI/IR/Support.cpp
    mlir/python/mlir/dialects/python_test.py
    mlir/python/mlir/ir.py
    mlir/test/python/dialects/python_test.py
    mlir/test/python/ir/attributes.py
    mlir/test/python/ir/builtin_types.py
    mlir/test/python/lib/PythonTestCAPI.cpp
    mlir/test/python/lib/PythonTestCAPI.h
    mlir/test/python/lib/PythonTestModule.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
index 6ebb458082d7c..33332d6a34bb8 100644
--- a/mlir/include/mlir-c/Bindings/Python/Interop.h
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -107,6 +107,23 @@
  * delineated). */
 #define MLIR_PYTHON_CAPI_FACTORY_ATTR "_CAPICreate"
 
+/** Attribute on MLIR Python objects that expose a function for downcasting the
+ * corresponding Python object to a subclass if the object is in fact a subclass
+ * (Concrete or mlir_type_subclass) of ir.Type. The signature of the function
+ * is: def maybe_downcast(self) -> object where the resulting object will
+ * (possibly) be an instance of the subclass.
+ */
+#define MLIR_PYTHON_MAYBE_DOWNCAST_ATTR "maybe_downcast"
+
+/** Attribute on main C extension module (_mlir) that corresponds to the
+ * type caster registration binding. The signature of the function is:
+ *   def register_type_caster(MlirTypeID mlirTypeID, py::function typeCaster,
+ *                              bool replace)
+ * where replace indicates the typeCaster should replace any existing registered
+ * type casters (such as those for upstream ConcreteTypes).
+ */
+#define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR "register_type_caster"
+
 /// Gets a void* from a wrapped struct. Needed because const cast is different
 /// between C/C++.
 #ifdef __cplusplus

diff  --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h
index 864dffa3f7824..0409890b21406 100644
--- a/mlir/include/mlir-c/Dialect/Transform.h
+++ b/mlir/include/mlir-c/Dialect/Transform.h
@@ -33,6 +33,8 @@ MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);
 
 MLIR_CAPI_EXPORTED bool mlirTypeIsATransformOperationType(MlirType type);
 
+MLIR_CAPI_EXPORTED MlirTypeID mlirTransformOperationTypeGetTypeID(void);
+
 MLIR_CAPI_EXPORTED MlirType
 mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName);
 

diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 13a3cb0130cea..8253981b3cda2 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -825,6 +825,9 @@ MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type);
 /// Gets the type ID of the type.
 MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type);
 
+/// Gets the dialect a type belongs to.
+MLIR_CAPI_EXPORTED MlirDialect mlirTypeGetDialect(MlirType type);
+
 /// Checks whether a type is null.
 static inline bool mlirTypeIsNull(MlirType type) { return !type.ptr; }
 

diff  --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index ccca3aa0172e6..272067a261edb 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -28,6 +28,7 @@
 #include "llvm/ADT/Twine.h"
 
 namespace py = pybind11;
+using namespace py::literals;
 
 // Raw CAPI type casters need to be declared before use, so always include them
 // first.
@@ -272,6 +273,7 @@ struct type_caster<MlirType> {
     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
         .attr("Type")
         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+        .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
         .release();
   }
 };
@@ -424,20 +426,24 @@ class mlir_attribute_subclass : public pure_subclass {
 class mlir_type_subclass : public pure_subclass {
 public:
   using IsAFunctionTy = bool (*)(MlirType);
+  using GetTypeIDFunctionTy = MlirTypeID (*)();
 
   /// Subclasses by looking up the super-class dynamically.
   mlir_type_subclass(py::handle scope, const char *typeClassName,
-                     IsAFunctionTy isaFunction)
+                     IsAFunctionTy isaFunction,
+                     GetTypeIDFunctionTy getTypeIDFunction = nullptr)
       : mlir_type_subclass(
             scope, typeClassName, isaFunction,
-            py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Type")) {}
+            py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Type"),
+            getTypeIDFunction) {}
 
   /// Subclasses with a provided mlir.ir.Type super-class. This must
   /// be used if the subclass is being defined in the same extension module
   /// as the mlir.ir class (otherwise, it will trigger a recursive
   /// initialization).
   mlir_type_subclass(py::handle scope, const char *typeClassName,
-                     IsAFunctionTy isaFunction, const py::object &superCls)
+                     IsAFunctionTy isaFunction, const py::object &superCls,
+                     GetTypeIDFunctionTy getTypeIDFunction = nullptr)
       : pure_subclass(scope, typeClassName, superCls) {
     // Casting constructor. Note that it hard, if not impossible, to properly
     // call chain to parent `__init__` in pybind11 due to its special handling
@@ -471,6 +477,19 @@ class mlir_type_subclass : public pure_subclass {
         "isinstance",
         [isaFunction](MlirType other) { return isaFunction(other); },
         py::arg("other_type"));
+    def("__repr__", [superCls, captureTypeName](py::object self) {
+      return py::repr(superCls(self))
+          .attr("replace")(superCls.attr("__name__"), captureTypeName);
+    });
+    if (getTypeIDFunction) {
+      py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+          .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
+              getTypeIDFunction(),
+              pybind11::cpp_function(
+                  [thisClass = thisClass](const py::object &mlirType) {
+                    return thisClass(mlirType);
+                  }));
+    }
   }
 };
 

diff  --git a/mlir/include/mlir/CAPI/Support.h b/mlir/include/mlir/CAPI/Support.h
index f3e8a67e0ac36..e42413dbe6d28 100644
--- a/mlir/include/mlir/CAPI/Support.h
+++ b/mlir/include/mlir/CAPI/Support.h
@@ -44,4 +44,25 @@ inline mlir::LogicalResult unwrap(MlirLogicalResult res) {
 DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID)
 DEFINE_C_API_PTR_METHODS(MlirTypeIDAllocator, mlir::TypeIDAllocator)
 
+namespace llvm {
+
+template <>
+struct DenseMapInfo<MlirTypeID> {
+  static inline MlirTypeID getEmptyKey() {
+    auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+    return mlirTypeIDCreate(pointer);
+  }
+  static inline MlirTypeID getTombstoneKey() {
+    auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+    return mlirTypeIDCreate(pointer);
+  }
+  static inline unsigned getHashValue(const MlirTypeID &val) {
+    return mlirTypeIDHashValue(val);
+  }
+  static inline bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs) {
+    return mlirTypeIDEqual(lhs, rhs);
+  }
+};
+} // namespace llvm
+
 #endif // MLIR_CAPI_SUPPORT_H

diff  --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index a9db2428c6c3f..e4b8cee73ef16 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -36,7 +36,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
   //===-------------------------------------------------------------------===//
 
   auto operationType =
-      mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType);
+      mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType,
+                         mlirTransformOperationTypeGetTypeID);
   operationType.def_classmethod(
       "get",
       [](py::object cls, const std::string &operationName, MlirContext ctx) {

diff  --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 45d03689642ee..0fc7614ccad52 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -9,12 +9,15 @@
 #ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
 #define MLIR_BINDINGS_PYTHON_GLOBALS_H
 
+#include <optional>
 #include <string>
 #include <vector>
-#include <optional>
 
 #include "PybindUtils.h"
 
+#include "mlir-c/IR.h"
+#include "mlir/CAPI/Support.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSet.h"
 
@@ -54,16 +57,18 @@ class PyGlobals {
   /// entities.
   void loadDialectModule(llvm::StringRef dialectNamespace);
 
-  /// Decorator for registering a custom Dialect class. The class object must
-  /// have a DIALECT_NAMESPACE attribute.
-  pybind11::object registerDialectDecorator(pybind11::object pyClass);
-
   /// Adds a user-friendly Attribute builder.
   /// Raises an exception if the mapping already exists.
   /// This is intended to be called by implementation code.
   void registerAttributeBuilder(const std::string &attributeKind,
                                 pybind11::function pyFunc);
 
+  /// Adds a user-friendly type caster. Raises an exception if the mapping
+  /// already exists and replace == false. This is intended to be called by
+  /// implementation code.
+  void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster,
+                          bool replace = false);
+
   /// Adds a concrete implementation dialect class.
   /// Raises an exception if the mapping already exists.
   /// This is intended to be called by implementation code.
@@ -80,6 +85,10 @@ class PyGlobals {
   std::optional<pybind11::function>
   lookupAttributeBuilder(const std::string &attributeKind);
 
+  /// Returns the custom type caster for MlirTypeID mlirTypeID.
+  std::optional<pybind11::function> lookupTypeCaster(MlirTypeID mlirTypeID,
+                                                     MlirDialect dialect);
+
   /// Looks up a registered dialect class by namespace. Note that this may
   /// trigger loading of the defining module and can arbitrarily re-enter.
   std::optional<pybind11::object>
@@ -101,6 +110,10 @@ class PyGlobals {
   llvm::StringMap<pybind11::object> operationClassMap;
   /// Map of attribute ODS name to custom builder.
   llvm::StringMap<pybind11::object> attributeBuilderMap;
+  /// Map of MlirTypeID to custom type caster.
+  llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
+  /// Cache for map of MlirTypeID to custom type caster.
+  llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMapCache;
 
   /// Set of dialect namespaces that we have attempted to import implementation
   /// modules for.

diff  --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 0ab47cc241d91..3c7926e784dbb 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -15,6 +15,7 @@
 
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
+#include "mlir/Bindings/Python/PybindAdaptors.h"
 
 namespace py = pybind11;
 using namespace mlir;
@@ -1023,8 +1024,7 @@ class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
         py::arg("value"), py::arg("context") = py::none(),
         "Gets a uniqued Type attribute");
     c.def_property_readonly("value", [](PyTypeAttribute &self) {
-      return PyType(self.getContext()->getRef(),
-                    mlirTypeAttrGetValue(self.get()));
+      return mlirTypeAttrGetValue(self.get());
     });
   }
 };

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index a6bd4d849b97c..ec9066aa10cee 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -25,6 +25,7 @@
 #include <utility>
 
 namespace py = pybind11;
+using namespace py::literals;
 using namespace mlir;
 using namespace mlir::python;
 
@@ -2121,13 +2122,12 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
 
 /// Returns the list of types of the values held by container.
 template <typename Container>
-static std::vector<PyType> getValueTypes(Container &container,
-                                         PyMlirContextRef &context) {
-  std::vector<PyType> result;
+static std::vector<MlirType> getValueTypes(Container &container,
+                                           PyMlirContextRef &context) {
+  std::vector<MlirType> result;
   result.reserve(container.size());
   for (int i = 0, e = container.size(); i < e; ++i) {
-    result.push_back(
-        PyType(context, mlirValueGetType(container.getElement(i).get())));
+    result.push_back(mlirValueGetType(container.getElement(i).get()));
   }
   return result;
 }
@@ -3148,11 +3148,8 @@ void mlir::python::populateIRCore(py::module &m) {
           "context",
           [](PyAttribute &self) { return self.getContext().getObject(); },
           "Context that owns the Attribute")
-      .def_property_readonly("type",
-                             [](PyAttribute &self) {
-                               return PyType(self.getContext()->getRef(),
-                                             mlirAttributeGetType(self));
-                             })
+      .def_property_readonly(
+          "type", [](PyAttribute &self) { return mlirAttributeGetType(self); })
       .def(
           "get_named",
           [](PyAttribute &self, std::string name) {
@@ -3247,7 +3244,7 @@ void mlir::python::populateIRCore(py::module &m) {
                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
             if (mlirTypeIsNull(type))
               throw MLIRError("Unable to parse type", errors.take());
-            return PyType(context->getRef(), type);
+            return type;
           },
           py::arg("asm"), py::arg("context") = py::none(),
           kContextParseTypeDocstring)
@@ -3284,6 +3281,18 @@ void mlir::python::populateIRCore(py::module &m) {
              printAccum.parts.append(")");
              return printAccum.join();
            })
+      .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+           [](PyType &self) {
+             MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
+             assert(!mlirTypeIDIsNull(mlirTypeID) &&
+                    "mlirTypeID was expected to be non-null.");
+             std::optional<pybind11::function> typeCaster =
+                 PyGlobals::get().lookupTypeCaster(mlirTypeID,
+                                                   mlirTypeGetDialect(self));
+             if (!typeCaster)
+               return py::cast(self);
+             return typeCaster.value()(self);
+           })
       .def_property_readonly("typeid", [](PyType &self) -> MlirTypeID {
         MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
         if (!mlirTypeIDIsNull(mlirTypeID))
@@ -3387,12 +3396,8 @@ void mlir::python::populateIRCore(py::module &m) {
             return printAccum.join();
           },
           py::arg("use_local_scope") = false, kGetNameAsOperand)
-      .def_property_readonly("type",
-                             [](PyValue &self) {
-                               return PyType(
-                                   self.getParentOperation()->getContext(),
-                                   mlirValueGetType(self.get()));
-                             })
+      .def_property_readonly(
+          "type", [](PyValue &self) { return mlirValueGetType(self.get()); })
       .def(
           "replace_all_uses_with",
           [](PyValue &self, PyValue &with) {

diff  --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index 0a7a25c0005fe..25fcaccd236db 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -321,11 +321,7 @@ class PyShapedTypeComponents {
                                        py::module_local())
         .def_property_readonly(
             "element_type",
-            [](PyShapedTypeComponents &self) {
-              return PyType(PyMlirContext::forContext(
-                                mlirTypeGetContext(self.elementType)),
-                            self.elementType);
-            },
+            [](PyShapedTypeComponents &self) { return self.elementType; },
             "Returns the element type of the shaped type components.")
         .def_static(
             "get",

diff  --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index 7c49f20f1a9ec..d9a66bce0fecb 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -14,6 +14,7 @@
 #include <vector>
 
 #include "mlir-c/Bindings/Python/Interop.h"
+#include "mlir-c/Support.h"
 
 namespace py = pybind11;
 using namespace mlir;
@@ -72,6 +73,15 @@ void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
   found = std::move(pyFunc);
 }
 
+void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
+                                   pybind11::function typeCaster,
+                                   bool replace) {
+  pybind11::object &found = typeCasterMap[mlirTypeID];
+  if (found && !found.is_none() && !replace)
+    throw std::runtime_error("Type caster is already registered");
+  found = std::move(typeCaster);
+}
+
 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
                                     py::object pyClass) {
   py::object &found = dialectClassMap[dialectNamespace];
@@ -110,6 +120,39 @@ PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
   return std::nullopt;
 }
 
+std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
+                                                        MlirDialect dialect) {
+  {
+    // Fast match against the class map first (common case).
+    const auto foundIt = typeCasterMapCache.find(mlirTypeID);
+    if (foundIt != typeCasterMapCache.end()) {
+      if (foundIt->second.is_none())
+        return std::nullopt;
+      assert(foundIt->second && "py::function is defined");
+      return foundIt->second;
+    }
+  }
+
+  // Not found. Load the dialect namespace.
+  loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
+
+  // Attempt to find from the canonical map and cache.
+  {
+    const auto foundIt = typeCasterMap.find(mlirTypeID);
+    if (foundIt != typeCasterMap.end()) {
+      if (foundIt->second.is_none())
+        return std::nullopt;
+      assert(foundIt->second && "py::object is defined");
+      // Positive cache.
+      typeCasterMapCache[mlirTypeID] = foundIt->second;
+      return foundIt->second;
+    }
+    // Negative cache.
+    typeCasterMap[mlirTypeID] = py::none();
+    return std::nullopt;
+  }
+}
+
 std::optional<py::object>
 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
   loadDialectModule(dialectNamespace);
@@ -164,4 +207,5 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
 void PyGlobals::clearImportCache() {
   loadedDialectModulesCache.clear();
   operationClassMapCache.clear();
+  typeCasterMapCache.clear();
 }

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index cfa3737cf30d2..013bb7b9256f4 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -13,6 +13,7 @@
 #include <utility>
 #include <vector>
 
+#include "Globals.h"
 #include "PybindUtils.h"
 
 #include "mlir-c/AffineExpr.h"
@@ -868,9 +869,7 @@ class PyConcreteType : public BaseTy {
 
   PyConcreteType() = default;
   PyConcreteType(PyMlirContextRef contextRef, MlirType t)
-      : BaseTy(std::move(contextRef), t) {
-    pybind11::implicitly_convertible<PyType, DerivedTy>();
-  }
+      : BaseTy(std::move(contextRef), t) {}
   PyConcreteType(PyType &orig)
       : PyConcreteType(orig.getContext(), castFrom(orig)) {}
 
@@ -914,6 +913,13 @@ class PyConcreteType : public BaseTy {
       return printAccum.join();
     });
 
+    if (DerivedTy::getTypeIdFunction) {
+      PyGlobals::get().registerTypeCaster(
+          DerivedTy::getTypeIdFunction(),
+          pybind11::cpp_function(
+              [](PyType pyType) -> DerivedTy { return pyType; }));
+    }
+
     DerivedTy::bindDerived(cls);
   }
 
@@ -1009,9 +1015,8 @@ class PyConcreteAttribute : public BaseTy {
           return DerivedTy::isaFunction(otherAttr);
         },
         pybind11::arg("other"));
-    cls.def_property_readonly("type", [](PyAttribute &attr) {
-      return PyType(attr.getContext(), mlirAttributeGetType(attr));
-    });
+    cls.def_property_readonly(
+        "type", [](PyAttribute &attr) { return mlirAttributeGetType(attr); });
     DerivedTy::bindDerived(cls);
   }
 

diff  --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 5c089b2f2c506..25307262bddbd 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -334,10 +334,7 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
         "Create a complex type");
     c.def_property_readonly(
         "element_type",
-        [](PyComplexType &self) -> PyType {
-          MlirType t = mlirComplexTypeGetElementType(self);
-          return PyType(self.getContext(), t);
-        },
+        [](PyComplexType &self) { return mlirComplexTypeGetElementType(self); },
         "Returns element type.");
   }
 };
@@ -351,10 +348,7 @@ class PyShapedType : public PyConcreteType<PyShapedType> {
   static void bindDerived(ClassTy &c) {
     c.def_property_readonly(
         "element_type",
-        [](PyShapedType &self) {
-          MlirType t = mlirShapedTypeGetElementType(self);
-          return PyType(self.getContext(), t);
-        },
+        [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); },
         "Returns the element type of the shaped type.");
     c.def_property_readonly(
         "has_rank",
@@ -641,9 +635,8 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
         "Create a tuple type");
     c.def(
         "get_type",
-        [](PyTupleType &self, intptr_t pos) -> PyType {
-          MlirType t = mlirTupleTypeGetType(self, pos);
-          return PyType(self.getContext(), t);
+        [](PyTupleType &self, intptr_t pos) {
+          return mlirTupleTypeGetType(self, pos);
         },
         py::arg("pos"), "Returns the pos-th type in the tuple type.");
     c.def_property_readonly(
@@ -686,7 +679,7 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
           py::list types;
           for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
                ++i) {
-            types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i)));
+            types.append(mlirFunctionTypeGetInput(t, i));
           }
           return types;
         },
@@ -698,8 +691,7 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
           py::list types;
           for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
                ++i) {
-            types.append(
-                PyType(contextRef, mlirFunctionTypeGetResult(self, i)));
+            types.append(mlirFunctionTypeGetResult(self, i));
           }
           return types;
         },

diff  --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index b32b4186fcb9f..cdddfbe50606d 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -16,6 +16,7 @@
 
 namespace py = pybind11;
 using namespace mlir;
+using namespace py::literals;
 using namespace mlir::python;
 
 // -----------------------------------------------------------------------------
@@ -35,12 +36,12 @@ PYBIND11_MODULE(_mlir, m) {
             self.getDialectSearchPrefixes().push_back(std::move(moduleName));
             self.clearImportCache();
           },
-          py::arg("module_name"))
+          "module_name"_a)
       .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
-           py::arg("dialect_namespace"), py::arg("dialect_class"),
+           "dialect_namespace"_a, "dialect_class"_a,
            "Testing hook for directly registering a dialect")
       .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
-           py::arg("operation_name"), py::arg("operation_class"),
+           "operation_name"_a, "operation_class"_a,
            "Testing hook for directly registering an operation");
 
   // Aside from making the globals accessible to python, having python manage
@@ -58,11 +59,11 @@ PYBIND11_MODULE(_mlir, m) {
         PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
         return pyClass;
       },
-      py::arg("dialect_class"),
+      "dialect_class"_a,
       "Class decorator for registering a custom Dialect wrapper");
   m.def(
       "register_operation",
-      [](py::object dialectClass) -> py::cpp_function {
+      [](const py::object &dialectClass) -> py::cpp_function {
         return py::cpp_function(
             [dialectClass](py::object opClass) -> py::object {
               std::string operationName =
@@ -75,9 +76,17 @@ PYBIND11_MODULE(_mlir, m) {
               return opClass;
             });
       },
-      py::arg("dialect_class"),
+      "dialect_class"_a,
       "Produce a class decorator for registering an Operation class as part of "
       "a dialect");
+  m.def(
+      MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
+      [](MlirTypeID mlirTypeID, py::function typeCaster, bool replace) {
+        PyGlobals::get().registerTypeCaster(mlirTypeID, std::move(typeCaster),
+                                            replace);
+      },
+      "typeid"_a, "type_caster"_a, "replace"_a = false,
+      "Register a type caster for casting MLIR types to custom user types.");
 
   // Define and populate IR submodule.
   auto irModule = m.def_submodule("ir", "MLIR IR Bindings");

diff  --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp
index 90594b67aacfb..d3cd4e3d0bb28 100644
--- a/mlir/lib/CAPI/Dialect/Transform.cpp
+++ b/mlir/lib/CAPI/Dialect/Transform.cpp
@@ -37,6 +37,10 @@ bool mlirTypeIsATransformOperationType(MlirType type) {
   return isa<transform::OperationType>(unwrap(type));
 }
 
+MlirTypeID mlirTransformOperationTypeGetTypeID(void) {
+  return wrap(transform::OperationType::getTypeID());
+}
+
 MlirType mlirTransformOperationTypeGet(MlirContext ctx,
                                        MlirStringRef operationName) {
   return wrap(

diff  --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 1925478c66d41..82c5b5a6147a4 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -324,6 +324,10 @@ MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc,
   return wrap(UnrankedTensorType::getChecked(unwrap(loc), unwrap(elementType)));
 }
 
+MlirType mlirUnrankedTensorTypeGetElementType(MlirType type) {
+  return wrap(llvm::cast<UnrankedTensorType>(unwrap(type)).getElementType());
+}
+
 //===----------------------------------------------------------------------===//
 // Ranked / Unranked MemRef type.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index c0cf5977736f4..373e01a1362c7 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -832,6 +832,10 @@ MlirTypeID mlirTypeGetTypeID(MlirType type) {
   return wrap(unwrap(type).getTypeID());
 }
 
+MlirDialect mlirTypeGetDialect(MlirType type) {
+  return wrap(&unwrap(type).getDialect());
+}
+
 bool mlirTypeEqual(MlirType t1, MlirType t2) {
   return unwrap(t1) == unwrap(t2);
 }

diff  --git a/mlir/lib/CAPI/IR/Support.cpp b/mlir/lib/CAPI/IR/Support.cpp
index cbfbb54769aa9..ea081b2e99b59 100644
--- a/mlir/lib/CAPI/IR/Support.cpp
+++ b/mlir/lib/CAPI/IR/Support.cpp
@@ -23,7 +23,6 @@ bool mlirStringRefEqual(MlirStringRef string, MlirStringRef other) {
 //===----------------------------------------------------------------------===//
 // TypeID API.
 //===----------------------------------------------------------------------===//
-
 MlirTypeID mlirTypeIDCreate(const void *ptr) {
   assert(reinterpret_cast<uintptr_t>(ptr) % 8 == 0 &&
          "ptr must be 8 byte aligned");

diff  --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
index 980f237b19391..8465af048a280 100644
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -3,7 +3,7 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._python_test_ops_gen import *
-from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestTensorType
+from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestIntegerRankedTensorType
 
 
 def register_python_test_dialect(context, load=True):

diff  --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 99c21ff9aaef2..10a0f5bd2c6b9 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -4,6 +4,7 @@
 
 from ._mlir_libs._mlir.ir import *
 from ._mlir_libs._mlir.ir import _GlobalDebug
+from ._mlir_libs._mlir import register_type_caster
 
 
 # Convenience decorator for registering user-friendly Attribute builders.

diff  --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 72a765c75e52c..5346955c906fc 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -369,9 +369,9 @@ def __str__(self):
 
             # Classes of custom types that inherit from concrete types should have
             # static_typeid
-            assert isinstance(test.TestTensorType.static_typeid, TypeID)
+            assert isinstance(test.TestIntegerRankedTensorType.static_typeid, TypeID)
             # And it should be equal to the in-tree concrete type
-            assert test.TestTensorType.static_typeid == t.type.typeid
+            assert test.TestIntegerRankedTensorType.static_typeid == t.type.typeid
 
 
 # CHECK-LABEL: TEST: inferReturnTypeComponents
@@ -424,3 +424,46 @@ def inferReturnTypeComponents():
         print("rank:", shaped_type_components.rank)
         print("element type:", shaped_type_components.element_type)
         print("shape:", shaped_type_components.shape)
+
+
+# CHECK-LABEL: TEST: testCustomTypeTypeCaster
+ at run
+def testCustomTypeTypeCaster():
+    with Context() as ctx, Location.unknown():
+        test.register_python_test_dialect(ctx)
+
+        a = test.TestType.get()
+        assert a.typeid is not None
+
+        b = Type.parse("!python_test.test_type")
+        # CHECK: !python_test.test_type
+        print(b)
+        # CHECK: TestType(!python_test.test_type)
+        print(repr(b))
+
+        c = test.TestIntegerRankedTensorType.get([10, 10], 5)
+        # CHECK: tensor<10x10xi5>
+        print(c)
+        # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
+        print(repr(c))
+
+        # CHECK: Type caster is already registered
+        try:
+
+            def type_caster(pytype):
+                return test.TestIntegerRankedTensorType(pytype)
+
+            register_type_caster(c.typeid, type_caster)
+        except RuntimeError as e:
+            print(e)
+
+        def type_caster(pytype):
+            return test.TestIntegerRankedTensorType(pytype)
+
+        register_type_caster(c.typeid, type_caster, replace=True)
+
+        d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
+        # CHECK: tensor<10x10xi5>
+        print(d.type)
+        # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
+        print(repr(d.type))

diff  --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 29074052796b4..3ee0691d606f8 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -553,3 +553,42 @@ def testStridedLayoutAttr():
         print(f"rank: {len(attr.strides)}")
         # CHECK: strides are dynamic: [True, True, True]
         print(f"strides are dynamic: {[s == dynamic for s in attr.strides]}")
+
+
+# CHECK-LABEL: TEST: testConcreteTypesRoundTrip
+ at run
+def testConcreteTypesRoundTrip():
+    with Context(), Location.unknown():
+
+        def print_item(attr):
+            print(repr(attr.type))
+
+        # CHECK: F32Type(f32)
+        print_item(Attribute.parse("42.0 : f32"))
+        # CHECK: F32Type(f32)
+        print_item(FloatAttr.get_f32(42.0))
+        # CHECK: IntegerType(i64)
+        print_item(IntegerAttr.get(IntegerType.get_signless(64), 42))
+
+        def print_container_item(attr_asm):
+            attr = DenseElementsAttr(Attribute.parse(attr_asm))
+            print(repr(attr.type))
+            print(repr(attr.type.element_type))
+
+        # CHECK: RankedTensorType(tensor<i16>)
+        # CHECK: IntegerType(i16)
+        print_container_item("dense<123> : tensor<i16>")
+
+        # CHECK: RankedTensorType(tensor<f64>)
+        # CHECK: F64Type(f64)
+        print_container_item("dense<1.0> : tensor<f64>")
+
+        raw = Attribute.parse("vector<4xf32>")
+        # CHECK: attr: vector<4xf32>
+        print("attr:", raw)
+        type_attr = TypeAttr(raw)
+
+        # CHECK: VectorType(vector<4xf32>)
+        print(repr(type_attr.value))
+        # CHECK: F32Type(f32)
+        print(repr(type_attr.value.element_type))

diff  --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index fc484a5050839..99273bab0b495 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -2,6 +2,7 @@
 
 import gc
 from mlir.ir import *
+from mlir.dialects import arith, tensor, func, memref
 
 
 def run(f):
@@ -382,15 +383,15 @@ def testMemRefType():
         f32 = F32Type.get()
         shape = [2, 3]
         loc = Location.unknown()
-        memref = MemRefType.get(shape, f32, memory_space=Attribute.parse("2"))
+        memref_f32 = MemRefType.get(shape, f32, memory_space=Attribute.parse("2"))
         # CHECK: memref type: memref<2x3xf32, 2>
-        print("memref type:", memref)
+        print("memref type:", memref_f32)
         # CHECK: memref layout: affine_map<(d0, d1) -> (d0, d1)>
-        print("memref layout:", memref.layout)
+        print("memref layout:", memref_f32.layout)
         # CHECK: memref affine map: (d0, d1) -> (d0, d1)
-        print("memref affine map:", memref.affine_map)
+        print("memref affine map:", memref_f32.affine_map)
         # CHECK: memory space: 2
-        print("memory space:", memref.memory_space)
+        print("memory space:", memref_f32.memory_space)
 
         layout = AffineMapAttr.get(AffineMap.get_permutation([1, 0]))
         memref_layout = MemRefType.get(shape, f32, layout=layout)
@@ -413,7 +414,7 @@ def testMemRefType():
         else:
             print("Exception not produced")
 
-        assert memref.shape == shape
+    assert memref_f32.shape == shape
 
 
 # CHECK-LABEL: TEST: testUnrankedMemRefType
@@ -482,9 +483,9 @@ def testFunctionType():
         input_types = [IntegerType.get_signless(32), IntegerType.get_signless(16)]
         result_types = [IndexType.get()]
         func = FunctionType.get(input_types, result_types)
-        # CHECK: INPUTS: [Type(i32), Type(i16)]
+        # CHECK: INPUTS: [IntegerType(i32), IntegerType(i16)]
         print("INPUTS:", func.inputs)
-        # CHECK: RESULTS: [Type(index)]
+        # CHECK: RESULTS: [IndexType(index)]
         print("RESULTS:", func.results)
 
 
@@ -599,3 +600,130 @@ def testTypeIDs():
         vector_type = Type.parse("vector<2x3xf32>")
         # CHECK: True
         print(ShapedType(vector_type).typeid == vector_type.typeid)
+
+
+# CHECK-LABEL: TEST: testConcreteTypesRoundTrip
+ at run
+def testConcreteTypesRoundTrip():
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+
+        def print_downcasted(typ):
+            downcasted = Type(typ).maybe_downcast()
+            print(type(downcasted).__name__)
+            print(repr(downcasted))
+
+        # CHECK: F16Type
+        # CHECK: F16Type(f16)
+        print_downcasted(F16Type.get())
+        # CHECK: F32Type
+        # CHECK: F32Type(f32)
+        print_downcasted(F32Type.get())
+        # CHECK: F64Type
+        # CHECK: F64Type(f64)
+        print_downcasted(F64Type.get())
+        # CHECK: Float8E4M3B11FNUZType
+        # CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ)
+        print_downcasted(Float8E4M3B11FNUZType.get())
+        # CHECK: Float8E4M3FNType
+        # CHECK: Float8E4M3FNType(f8E4M3FN)
+        print_downcasted(Float8E4M3FNType.get())
+        # CHECK: Float8E4M3FNUZType
+        # CHECK: Float8E4M3FNUZType(f8E4M3FNUZ)
+        print_downcasted(Float8E4M3FNUZType.get())
+        # CHECK: Float8E5M2Type
+        # CHECK: Float8E5M2Type(f8E5M2)
+        print_downcasted(Float8E5M2Type.get())
+        # CHECK: Float8E5M2FNUZType
+        # CHECK: Float8E5M2FNUZType(f8E5M2FNUZ)
+        print_downcasted(Float8E5M2FNUZType.get())
+        # CHECK: BF16Type
+        # CHECK: BF16Type(bf16)
+        print_downcasted(BF16Type.get())
+        # CHECK: IndexType
+        # CHECK: IndexType(index)
+        print_downcasted(IndexType.get())
+        # CHECK: IntegerType
+        # CHECK: IntegerType(i32)
+        print_downcasted(IntegerType.get_signless(32))
+
+        f32 = F32Type.get()
+        ranked_tensor = tensor.EmptyOp([10, 10], f32).result
+        # CHECK: RankedTensorType
+        print(type(ranked_tensor.type).__name__)
+        # CHECK: RankedTensorType(tensor<10x10xf32>)
+        print(repr(ranked_tensor.type))
+
+        cf32 = ComplexType.get(f32)
+        # CHECK: ComplexType
+        print(type(cf32).__name__)
+        # CHECK: ComplexType(complex<f32>)
+        print(repr(cf32))
+
+        ranked_tensor = tensor.EmptyOp([10, 10], f32).result
+        # CHECK: RankedTensorType
+        print(type(ranked_tensor.type).__name__)
+        # CHECK: RankedTensorType(tensor<10x10xf32>)
+        print(repr(ranked_tensor.type))
+
+        vector = VectorType.get([10, 10], f32)
+        tuple_type = TupleType.get_tuple([f32, vector])
+        # CHECK: TupleType
+        print(type(tuple_type).__name__)
+        # CHECK: TupleType(tuple<f32, vector<10x10xf32>>)
+        print(repr(tuple_type))
+        # CHECK: F32Type(f32)
+        print(repr(tuple_type.get_type(0)))
+        # CHECK: VectorType(vector<10x10xf32>)
+        print(repr(tuple_type.get_type(1)))
+
+        index_type = IndexType.get()
+
+        @func.FuncOp.from_py_func()
+        def default_builder():
+            c0 = arith.ConstantOp(f32, 0.0)
+            unranked_tensor_type = UnrankedTensorType.get(f32)
+            unranked_tensor = tensor.FromElementsOp(unranked_tensor_type, [c0]).result
+            # CHECK: UnrankedTensorType
+            print(type(unranked_tensor.type).__name__)
+            # CHECK: UnrankedTensorType(tensor<*xf32>)
+            print(repr(unranked_tensor.type))
+
+            c10 = arith.ConstantOp(index_type, 10)
+            memref_f32_t = MemRefType.get([10, 10], f32)
+            memref_f32 = memref.AllocOp(memref_f32_t, [c10, c10], []).result
+            # CHECK: MemRefType
+            print(type(memref_f32.type).__name__)
+            # CHECK: MemRefType(memref<10x10xf32>)
+            print(repr(memref_f32.type))
+
+            unranked_memref_t = UnrankedMemRefType.get(f32, Attribute.parse("2"))
+            memref_f32 = memref.AllocOp(unranked_memref_t, [c10, c10], []).result
+            # CHECK: UnrankedMemRefType
+            print(type(memref_f32.type).__name__)
+            # CHECK: UnrankedMemRefType(memref<*xf32, 2>)
+            print(repr(memref_f32.type))
+
+            tuple_type = Operation.parse(
+                f'"test.make_tuple"() : () -> tuple<i32, f32>'
+            ).result
+            # CHECK: TupleType
+            print(type(tuple_type.type).__name__)
+            # CHECK: TupleType(tuple<i32, f32>)
+            print(repr(tuple_type.type))
+
+            return c0, c10
+
+
+# CHECK-LABEL: TEST: testCustomTypeTypeCaster
+# This tests being able to materialize a type from a dialect *and* have
+# the implemented type caster called without explicitly importing the dialect.
+# I.e., we get a transform.OperationType without explicitly importing the transform dialect.
+ at run
+def testCustomTypeTypeCaster():
+    with Context() as ctx, Location.unknown():
+        t = Type.parse('!transform.op<"foo.bar">', Context())
+        # CHECK: !transform.op<"foo.bar">
+        print(t)
+        # CHECK: OperationType(!transform.op<"foo.bar">)
+        print(repr(t))

diff  --git a/mlir/test/python/lib/PythonTestCAPI.cpp b/mlir/test/python/lib/PythonTestCAPI.cpp
index 7b443554440bc..71778a97d83a4 100644
--- a/mlir/test/python/lib/PythonTestCAPI.cpp
+++ b/mlir/test/python/lib/PythonTestCAPI.cpp
@@ -31,6 +31,10 @@ MlirType mlirPythonTestTestTypeGet(MlirContext context) {
   return wrap(python_test::TestTypeType::get(unwrap(context)));
 }
 
+MlirTypeID mlirPythonTestTestTypeGetTypeID(void) {
+  return wrap(python_test::TestTypeType::getTypeID()); // TestTypeType's TypeID
+}
+
 bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value) {
   return mlirTypeIsATensor(wrap(unwrap(value).getType()));
 }

diff  --git a/mlir/test/python/lib/PythonTestCAPI.h b/mlir/test/python/lib/PythonTestCAPI.h
index 90c5d4a383f95..5f1ed3a5b2ad6 100644
--- a/mlir/test/python/lib/PythonTestCAPI.h
+++ b/mlir/test/python/lib/PythonTestCAPI.h
@@ -27,6 +27,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestType(MlirType type);
 
 MLIR_CAPI_EXPORTED MlirType mlirPythonTestTestTypeGet(MlirContext context);
 
+MLIR_CAPI_EXPORTED MlirTypeID mlirPythonTestTestTypeGetTypeID(void);
+
 MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value);
 
 #ifdef __cplusplus

diff  --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
index 7edeaac86e45c..f533082a0a147 100644
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -7,11 +7,19 @@
 //===----------------------------------------------------------------------===//
 
 #include "PythonTestCAPI.h"
+#include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
+#include "mlir-c/IR.h"
 #include "mlir/Bindings/Python/PybindAdaptors.h"
 
 namespace py = pybind11;
 using namespace mlir::python::adaptors;
+using namespace pybind11::literals;
+
+static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
+  bool isRanked = mlirTypeIsARankedTensor(t); // checked before element type
+  return isRanked && mlirTypeIsAInteger(mlirShapedTypeGetElementType(t));
+}
 
 PYBIND11_MODULE(_mlirPythonTest, m) {
   m.def(
@@ -34,16 +42,38 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
             return cls(mlirPythonTestTestAttributeGet(ctx));
           },
           py::arg("cls"), py::arg("context") = py::none());
-  mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType)
+  mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType,
+                     mlirPythonTestTestTypeGetTypeID)
       .def_classmethod(
           "get",
           [](py::object cls, MlirContext ctx) {
             return cls(mlirPythonTestTestTypeGet(ctx));
           },
           py::arg("cls"), py::arg("context") = py::none());
-  mlir_type_subclass(m, "TestTensorType", mlirTypeIsARankedTensor,
-                     py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
-                         .attr("RankedTensorType"));
+  auto cls =
+      mlir_type_subclass(m, "TestIntegerRankedTensorType",
+                         mlirTypeIsARankedIntegerTensor,
+                         py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+                             .attr("RankedTensorType"))
+          .def_classmethod(
+              "get",
+              [](const py::object &cls, std::vector<int64_t> shape,
+                 unsigned width, MlirContext ctx) {
+                MlirAttribute encoding = mlirAttributeGetNull();
+                return cls(mlirRankedTensorTypeGet(
+                    shape.size(), shape.data(), mlirIntegerTypeGet(ctx, width),
+                    encoding));
+              },
+              "cls"_a, "shape"_a, "width"_a, "context"_a = py::none());
+  assert(py::hasattr(cls.get_class(), "static_typeid") &&
+         "TestIntegerRankedTensorType has no static_typeid");
+  MlirTypeID mlirTypeID = mlirRankedTensorTypeGetTypeID();
+  py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+      .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
+          mlirTypeID, pybind11::cpp_function([cls](const py::object &mlirType) {
+            return cls.get_class()(mlirType);
+          }),
+          /*replace=*/true);
   mlir_value_subclass(m, "TestTensorValue",
                       mlirTypeIsAPythonTestTestTensorValue)
       .def("is_null", [](MlirValue &self) { return mlirValueIsNull(self); });


        


More information about the Mlir-commits mailing list