[Mlir-commits] [mlir] d39a784 - [MLIR][python bindings] Expose TypeIDs in python
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon May 22 11:21:37 PDT 2023
Author: max
Date: 2023-05-22T13:19:54-05:00
New Revision: d39a7844028bcdd28f72b0e69becc9c49b8fd283
URL: https://github.com/llvm/llvm-project/commit/d39a7844028bcdd28f72b0e69becc9c49b8fd283
DIFF: https://github.com/llvm/llvm-project/commit/d39a7844028bcdd28f72b0e69becc9c49b8fd283.diff
LOG: [MLIR][python bindings] Expose TypeIDs in python
This diff adds Python bindings for `MlirTypeID`. It paves the way for returning accurately typed `Type`s from Python APIs (see D150927) and, further along, for building type-"conscious" `Value` APIs (see D150413).
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D150839
Added:
Modified:
mlir/include/mlir-c/Bindings/Python/Interop.h
mlir/include/mlir-c/BuiltinTypes.h
mlir/include/mlir/Bindings/Python/PybindAdaptors.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRModule.h
mlir/lib/Bindings/Python/IRTypes.cpp
mlir/lib/CAPI/IR/BuiltinTypes.cpp
mlir/python/mlir/dialects/python_test.py
mlir/test/python/dialects/python_test.py
mlir/test/python/ir/builtin_types.py
mlir/test/python/lib/PythonTestModule.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
index b877f94aa48d5..6ebb458082d7c 100644
--- a/mlir/include/mlir-c/Bindings/Python/Interop.h
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -80,6 +80,8 @@
#define MLIR_PYTHON_CAPSULE_PASS_MANAGER \
MAKE_MLIR_PYTHON_QUALNAME("passmanager.PassManager._CAPIPtr")
#define MLIR_PYTHON_CAPSULE_VALUE MAKE_MLIR_PYTHON_QUALNAME("ir.Value._CAPIPtr")
+#define MLIR_PYTHON_CAPSULE_TYPEID \
+ MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID._CAPIPtr")
/** Attribute on MLIR Python objects that expose their C-API pointer.
* This will be a type-specific capsule created as per one of the helpers
@@ -268,6 +270,25 @@ static inline MlirOperation mlirPythonCapsuleToOperation(PyObject *capsule) {
return op;
}
+/** Creates a capsule object encapsulating the raw C-API MlirTypeID.
+ * The returned capsule does not extend or affect ownership of any Python
+ * objects that reference the TypeID in any way.
+ */
+static inline PyObject *mlirPythonTypeIDToCapsule(MlirTypeID typeID) {
+ return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(typeID),
+ MLIR_PYTHON_CAPSULE_TYPEID, NULL);
+}
+
+/** Extracts an MlirTypeID from a capsule as produced from
+ * mlirPythonTypeIDToCapsule. If the capsule is not of the right type, then
+ * a null TypeID is returned (as checked via mlirTypeIDIsNull). In such a
+ * case, the Python APIs will have already set an error. */
+static inline MlirTypeID mlirPythonCapsuleToTypeID(PyObject *capsule) {
+ void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_TYPEID);
+ MlirTypeID typeID = {ptr};
+ return typeID;
+}
+
/** Creates a capsule object encapsulating the raw C-API MlirType.
* The returned capsule does not extend or affect ownership of any Python
* objects that reference the type in any way.
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 2b7606f3d9caf..4348c5ba167f9 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -22,6 +22,9 @@ extern "C" {
// Integer types.
//===----------------------------------------------------------------------===//
+/// Returns the TypeID of the Integer type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerTypeGetTypeID(void);
+
/// Checks whether the given type is an integer type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAInteger(MlirType type);
@@ -56,6 +59,9 @@ MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsUnsigned(MlirType type);
// Index type.
//===----------------------------------------------------------------------===//
+/// Returns the TypeID of the Index type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirIndexTypeGetTypeID(void);
+
/// Checks whether the given type is an index type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAIndex(MlirType type);
@@ -67,6 +73,9 @@ MLIR_CAPI_EXPORTED MlirType mlirIndexTypeGet(MlirContext ctx);
// Floating-point types.
//===----------------------------------------------------------------------===//
+/// Returns the TypeID of the Float8E5M2 type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2TypeGetTypeID(void);
+
/// Checks whether the given type is an f8E5M2 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type);
@@ -74,6 +83,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx);
+/// Returns the TypeID of the Float8E4M3FN type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNTypeGetTypeID(void);
+
/// Checks whether the given type is an f8E4M3FN type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type);
@@ -81,6 +93,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx);
+/// Returns the TypeID of the Float8E5M2FNUZ type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID(void);
+
/// Checks whether the given type is an f8E5M2FNUZ type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type);
@@ -88,6 +103,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx);
+/// Returns the TypeID of the Float8E4M3FNUZ type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID(void);
+
/// Checks whether the given type is an f8E4M3FNUZ type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type);
@@ -95,6 +113,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx);
+/// Returns the TypeID of the Float8E4M3B11FNUZ type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID(void);
+
/// Checks whether the given type is an f8E4M3B11FNUZ type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type);
@@ -102,6 +123,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx);
+/// Returns the TypeID of the BFloat16 type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirBFloat16TypeGetTypeID(void);
+
/// Checks whether the given type is a bf16 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type);
@@ -109,6 +133,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirBF16TypeGet(MlirContext ctx);
+/// Returns the TypeID of the Float16 type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat16TypeGetTypeID(void);
+
/// Checks whether the given type is an f16 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF16(MlirType type);
@@ -116,6 +143,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAF16(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirF16TypeGet(MlirContext ctx);
+/// Returns the TypeID of the Float32 type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat32TypeGetTypeID(void);
+
/// Checks whether the given type is an f32 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF32(MlirType type);
@@ -123,6 +153,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAF32(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirF32TypeGet(MlirContext ctx);
+/// Returns the TypeID of the Float64 type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat64TypeGetTypeID(void);
+
/// Checks whether the given type is an f64 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF64(MlirType type);
@@ -134,6 +167,9 @@ MLIR_CAPI_EXPORTED MlirType mlirF64TypeGet(MlirContext ctx);
// None type.
//===----------------------------------------------------------------------===//
+/// Returns the TypeID of the None type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirNoneTypeGetTypeID(void);
+
/// Checks whether the given type is a None type.
MLIR_CAPI_EXPORTED bool mlirTypeIsANone(MlirType type);
@@ -145,6 +181,9 @@ MLIR_CAPI_EXPORTED MlirType mlirNoneTypeGet(MlirContext ctx);
// Complex type.
//===----------------------------------------------------------------------===//
+/// Returns the TypeID of the Complex type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirComplexTypeGetTypeID(void);
+
/// Checks whether the given type is a Complex type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAComplex(MlirType type);
@@ -202,6 +241,9 @@ MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset(void);
// Vector type.
//===----------------------------------------------------------------------===//
+/// Returns the TypeID of the Vector type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirVectorTypeGetTypeID(void);
+
/// Checks whether the given type is a Vector type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAVector(MlirType type);
@@ -226,9 +268,15 @@ MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc,
/// Checks whether the given type is a Tensor type.
MLIR_CAPI_EXPORTED bool mlirTypeIsATensor(MlirType type);
+/// Returns the TypeID of the RankedTensor type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirRankedTensorTypeGetTypeID(void);
+
/// Checks whether the given type is a ranked tensor type.
MLIR_CAPI_EXPORTED bool mlirTypeIsARankedTensor(MlirType type);
+/// Returns the TypeID of the UnrankedTensor type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedTensorTypeGetTypeID(void);
+
/// Checks whether the given type is an unranked tensor type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedTensor(MlirType type);
@@ -264,9 +312,15 @@ mlirUnrankedTensorTypeGetChecked(MlirLocation loc, MlirType elementType);
// Ranked / Unranked MemRef type.
//===----------------------------------------------------------------------===//
+/// Returns the TypeID of the MemRef type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirMemRefTypeGetTypeID(void);
+
/// Checks whether the given type is a MemRef type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAMemRef(MlirType type);
+/// Returns the TypeID of the UnrankedMemRef type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedMemRefTypeGetTypeID(void);
+
/// Checks whether the given type is an UnrankedMemRef type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedMemRef(MlirType type);
@@ -326,6 +380,9 @@ mlirUnrankedMemrefGetMemorySpace(MlirType type);
// Tuple type.
//===----------------------------------------------------------------------===//
+/// Returns the TypeID of the Tuple type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirTupleTypeGetTypeID(void);
+
/// Checks whether the given type is a tuple type.
MLIR_CAPI_EXPORTED bool mlirTypeIsATuple(MlirType type);
@@ -345,6 +402,9 @@ MLIR_CAPI_EXPORTED MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos);
// Function type.
//===----------------------------------------------------------------------===//
+/// Returns the TypeID of the Function type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFunctionTypeGetTypeID(void);
+
/// Checks whether the given type is a function type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFunction(MlirType type);
@@ -373,6 +433,9 @@ MLIR_CAPI_EXPORTED MlirType mlirFunctionTypeGetResult(MlirType type,
// Opaque type.
//===----------------------------------------------------------------------===//
+/// Returns the TypeID of the Opaque type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirOpaqueTypeGetTypeID(void);
+
/// Checks whether the given type is an opaque type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAOpaque(MlirType type);
diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index bec3fc76e39d2..ccca3aa0172e6 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -236,6 +236,29 @@ struct type_caster<MlirPassManager> {
}
};
+/// Casts object <-> MlirTypeID.
+template <>
+struct type_caster<MlirTypeID> {
+ PYBIND11_TYPE_CASTER(MlirTypeID, _("MlirTypeID"));
+ bool load(handle src, bool) {
+ py::object capsule = mlirApiObjectToCapsule(src);
+ value = mlirPythonCapsuleToTypeID(capsule.ptr());
+ return !mlirTypeIDIsNull(value);
+ }
+ static handle cast(MlirTypeID v, return_value_policy, handle) {
+ // A null TypeID maps to None. cast() must hand back an owned reference, so
+ // release() it; returning the py::none() temporary would decref None.
+ if (mlirTypeIDIsNull(v))
+ return py::none().release();
+ py::object capsule =
+ py::reinterpret_steal<py::object>(mlirPythonTypeIDToCapsule(v));
+ return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("TypeID")
+ .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+ .release();
+ }
+};
+
/// Casts object <-> MlirType.
template <>
struct type_caster<MlirType> {
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 7ffa464009fc8..db8390abee925 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -17,6 +17,7 @@
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
@@ -1807,6 +1808,30 @@ PyType PyType::createFromCapsule(py::object capsule) {
rawType);
}
+//------------------------------------------------------------------------------
+// PyTypeID.
+//------------------------------------------------------------------------------
+
+// Returns a new reference to a capsule wrapping the raw MlirTypeID pointer.
+py::object PyTypeID::getCapsule() {
+ return py::reinterpret_steal<py::object>(mlirPythonTypeIDToCapsule(*this));
+}
+
+// Rebuilds a PyTypeID from a capsule. For a capsule of the wrong kind,
+// mlirPythonCapsuleToTypeID yields a null TypeID with a Python error already
+// set; that error is propagated via py::error_already_set.
+PyTypeID PyTypeID::createFromCapsule(py::object capsule) {
+ MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr());
+ if (mlirTypeIDIsNull(mlirTypeID))
+ throw py::error_already_set();
+ return PyTypeID(mlirTypeID);
+}
+
+// Value equality on the underlying TypeIDs (delegates to mlirTypeIDEqual).
+bool PyTypeID::operator==(const PyTypeID &other) const {
+ return mlirTypeIDEqual(typeID, other.typeID);
+}
+
//------------------------------------------------------------------------------
// PyValue and subclases.
//------------------------------------------------------------------------------
@@ -3268,16 +3293,49 @@ void mlir::python::populateIRCore(py::module &m) {
return printAccum.join();
},
"Returns the assembly form of the type.")
- .def("__repr__", [](PyType &self) {
- // Generally, assembly formats are not printed for __repr__ because
- // this can cause exceptionally long debug output and exceptions.
- // However, types are an exception as they typically have compact
- // assembly forms and printing them is useful.
- PyPrintAccumulator printAccum;
- printAccum.parts.append("Type(");
- mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
- printAccum.parts.append(")");
- return printAccum.join();
+ .def("__repr__",
+ [](PyType &self) {
+ // Generally, assembly formats are not printed for __repr__ because
+ // this can cause exceptionally long debug output and exceptions.
+ // However, types are an exception as they typically have compact
+ // assembly forms and printing them is useful.
+ PyPrintAccumulator printAccum;
+ printAccum.parts.append("Type(");
+ mlirTypePrint(self, printAccum.getCallback(),
+ printAccum.getUserData());
+ printAccum.parts.append(")");
+ return printAccum.join();
+ })
+ // Raises ValueError when mlirTypeGetTypeID returns a null TypeID.
+ .def_property_readonly("typeid", [](PyType &self) -> MlirTypeID {
+ MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
+ if (!mlirTypeIDIsNull(mlirTypeID))
+ return mlirTypeID;
+ auto origRepr =
+ pybind11::repr(pybind11::cast(self)).cast<std::string>();
+ throw py::value_error(
+ (origRepr + llvm::Twine(" has no typeid.")).str());
+ });
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyTypeID.
+ //----------------------------------------------------------------------------
+ py::class_<PyTypeID>(m, "TypeID", py::module_local())
+ .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule)
+ .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule)
+ // Note, this tests whether the underlying TypeIDs are the same,
+ // not whether the wrapper MlirTypeIDs are the same, nor whether
+ // the Python objects are the same (i.e., PyTypeID is a value type).
+ .def("__eq__",
+ [](PyTypeID &self, PyTypeID &other) { return self == other; })
+ // Fallback overload: comparing with any non-TypeID object yields False.
+ .def("__eq__",
+ [](PyTypeID &self, const py::object &other) { return false; })
+ // Note, this gives the hash value of the underlying TypeID, not the
+ // hash value of the Python object, nor the hash value of the
+ // MlirTypeID wrapper.
+ .def("__hash__", [](PyTypeID &self) {
+ return static_cast<size_t>(mlirTypeIDHashValue(self));
});
//----------------------------------------------------------------------------
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index ade790ba0ed13..fa529c43444d3 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -20,6 +20,7 @@
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
+#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "llvm/ADT/DenseMap.h"
namespace mlir {
@@ -826,6 +827,29 @@ class PyType : public BaseContextObject {
MlirType type;
};
+/// A TypeID provides an efficient and unique identifier for a specific C++
+/// type. This allows for a C++ type to be compared, hashed, and stored in an
+/// opaque context. This class wraps around the generic MlirTypeID.
+class PyTypeID {
+public:
+ PyTypeID(MlirTypeID typeID) : typeID(typeID) {}
+ // Note, this tests whether the underlying TypeIDs are the same,
+ // not whether the wrapper MlirTypeIDs are the same, nor whether
+ // the PyTypeID objects are the same (i.e., PyTypeID is a value type).
+ bool operator==(const PyTypeID &other) const;
+ operator MlirTypeID() const { return typeID; }
+ MlirTypeID get() { return typeID; }
+
+ /// Gets a capsule wrapping the void* within the MlirTypeID.
+ pybind11::object getCapsule();
+
+ /// Creates a PyTypeID from the MlirTypeID wrapped by a capsule.
+ static PyTypeID createFromCapsule(pybind11::object capsule);
+
+private:
+ MlirTypeID typeID;
+};
+
/// CRTP base classes for Python types that subclass Type and should be
/// castable from it (i.e. via something like IntegerType(t)).
/// By default, type class hierarchies are one level deep (i.e. a
@@ -839,10 +863,13 @@ class PyConcreteType : public BaseTy {
// const char *pyClassName
using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
using IsAFunctionTy = bool (*)(MlirType);
+ using GetTypeIDFunctionTy = MlirTypeID (*)();
+ // Subclasses shadow this with their C-API TypeID getter, if they have one.
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr;
PyConcreteType() = default;
PyConcreteType(PyMlirContextRef contextRef, MlirType t)
: BaseTy(std::move(contextRef), t) {}
PyConcreteType(PyType &orig)
: PyConcreteType(orig.getContext(), castFrom(orig)) {}
@@ -866,6 +893,29 @@ class PyConcreteType : public BaseTy {
return DerivedTy::isaFunction(otherType);
},
pybind11::arg("other"));
+ cls.def_property_readonly_static(
+ "static_typeid", [](py::object & /*class*/) -> MlirTypeID {
+ if (DerivedTy::getTypeIdFunction)
+ return DerivedTy::getTypeIdFunction();
+ throw SetPyError(PyExc_AttributeError,
+ DerivedTy::pyClassName +
+ llvm::Twine(" has no typeid."));
+ });
+ cls.def_property_readonly("typeid", [](PyType &self) {
+ return py::cast(self).attr("typeid").cast<MlirTypeID>();
+ });
+ cls.def("__repr__", [](DerivedTy &self) {
+ PyPrintAccumulator printAccum;
+ printAccum.parts.append(DerivedTy::pyClassName);
+ printAccum.parts.append("(");
+ mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
+ printAccum.parts.append(")");
+ return printAccum.join();
+ });
+ // Let APIs taking DerivedTy also accept a generic PyType. Registered once
+ // per bound class: pybind11 records one more converter on every call.
+ pybind11::implicitly_convertible<PyType, DerivedTy>();
+
DerivedTy::bindDerived(cls);
}
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index cb62a402dc671..f45b30250c174 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -32,6 +32,9 @@ static int mlirTypeIsAIntegerOrFloat(MlirType type) {
class PyIntegerType : public PyConcreteType<PyIntegerType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
+ // TypeID getter read by PyConcreteType to implement `static_typeid`.
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirIntegerTypeGetTypeID;
static constexpr const char *pyClassName = "IntegerType";
using PyConcreteType::PyConcreteType;
@@ -89,6 +92,8 @@ class PyIntegerType : public PyConcreteType<PyIntegerType> {
class PyIndexType : public PyConcreteType<PyIndexType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirIndexTypeGetTypeID;
static constexpr const char *pyClassName = "IndexType";
using PyConcreteType::PyConcreteType;
@@ -107,6 +112,8 @@ class PyIndexType : public PyConcreteType<PyIndexType> {
class PyFloat8E4M3FNType : public PyConcreteType<PyFloat8E4M3FNType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat8E4M3FNTypeGetTypeID;
static constexpr const char *pyClassName = "Float8E4M3FNType";
using PyConcreteType::PyConcreteType;
@@ -125,6 +132,8 @@ class PyFloat8E4M3FNType : public PyConcreteType<PyFloat8E4M3FNType> {
class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat8E5M2TypeGetTypeID;
static constexpr const char *pyClassName = "Float8E5M2Type";
using PyConcreteType::PyConcreteType;
@@ -143,6 +152,8 @@ class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type> {
class PyFloat8E4M3FNUZType : public PyConcreteType<PyFloat8E4M3FNUZType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat8E4M3FNUZTypeGetTypeID;
static constexpr const char *pyClassName = "Float8E4M3FNUZType";
using PyConcreteType::PyConcreteType;
@@ -161,6 +172,8 @@ class PyFloat8E4M3FNUZType : public PyConcreteType<PyFloat8E4M3FNUZType> {
class PyFloat8E4M3B11FNUZType : public PyConcreteType<PyFloat8E4M3B11FNUZType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat8E4M3B11FNUZTypeGetTypeID;
static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
using PyConcreteType::PyConcreteType;
@@ -179,6 +192,8 @@ class PyFloat8E4M3B11FNUZType : public PyConcreteType<PyFloat8E4M3B11FNUZType> {
class PyFloat8E5M2FNUZType : public PyConcreteType<PyFloat8E5M2FNUZType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat8E5M2FNUZTypeGetTypeID;
static constexpr const char *pyClassName = "Float8E5M2FNUZType";
using PyConcreteType::PyConcreteType;
@@ -197,6 +212,8 @@ class PyFloat8E5M2FNUZType : public PyConcreteType<PyFloat8E5M2FNUZType> {
class PyBF16Type : public PyConcreteType<PyBF16Type> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirBFloat16TypeGetTypeID;
static constexpr const char *pyClassName = "BF16Type";
using PyConcreteType::PyConcreteType;
@@ -215,6 +232,8 @@ class PyBF16Type : public PyConcreteType<PyBF16Type> {
class PyF16Type : public PyConcreteType<PyF16Type> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat16TypeGetTypeID;
static constexpr const char *pyClassName = "F16Type";
using PyConcreteType::PyConcreteType;
@@ -233,6 +252,8 @@ class PyF16Type : public PyConcreteType<PyF16Type> {
class PyF32Type : public PyConcreteType<PyF32Type> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat32TypeGetTypeID;
static constexpr const char *pyClassName = "F32Type";
using PyConcreteType::PyConcreteType;
@@ -251,6 +272,8 @@ class PyF32Type : public PyConcreteType<PyF32Type> {
class PyF64Type : public PyConcreteType<PyF64Type> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat64TypeGetTypeID;
static constexpr const char *pyClassName = "F64Type";
using PyConcreteType::PyConcreteType;
@@ -269,6 +292,8 @@ class PyF64Type : public PyConcreteType<PyF64Type> {
class PyNoneType : public PyConcreteType<PyNoneType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirNoneTypeGetTypeID;
static constexpr const char *pyClassName = "NoneType";
using PyConcreteType::PyConcreteType;
@@ -287,6 +312,8 @@ class PyNoneType : public PyConcreteType<PyNoneType> {
class PyComplexType : public PyConcreteType<PyComplexType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirComplexTypeGetTypeID;
static constexpr const char *pyClassName = "ComplexType";
using PyConcreteType::PyConcreteType;
@@ -417,6 +444,8 @@ class PyShapedType : public PyConcreteType<PyShapedType> {
class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirVectorTypeGetTypeID;
static constexpr const char *pyClassName = "VectorType";
using PyConcreteType::PyConcreteType;
@@ -442,6 +471,8 @@ class PyRankedTensorType
: public PyConcreteType<PyRankedTensorType, PyShapedType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirRankedTensorTypeGetTypeID;
static constexpr const char *pyClassName = "RankedTensorType";
using PyConcreteType::PyConcreteType;
@@ -476,6 +507,8 @@ class PyUnrankedTensorType
: public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirUnrankedTensorTypeGetTypeID;
static constexpr const char *pyClassName = "UnrankedTensorType";
using PyConcreteType::PyConcreteType;
@@ -498,6 +531,8 @@ class PyUnrankedTensorType
class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirMemRefTypeGetTypeID;
static constexpr const char *pyClassName = "MemRefType";
using PyConcreteType::PyConcreteType;
@@ -550,6 +585,8 @@ class PyUnrankedMemRefType
: public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirUnrankedMemRefTypeGetTypeID;
static constexpr const char *pyClassName = "UnrankedMemRefType";
using PyConcreteType::PyConcreteType;
@@ -585,6 +622,8 @@ class PyUnrankedMemRefType
class PyTupleType : public PyConcreteType<PyTupleType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirTupleTypeGetTypeID;
static constexpr const char *pyClassName = "TupleType";
using PyConcreteType::PyConcreteType;
@@ -622,6 +661,8 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
class PyFunctionType : public PyConcreteType<PyFunctionType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFunctionTypeGetTypeID;
static constexpr const char *pyClassName = "FunctionType";
using PyConcreteType::PyConcreteType;
@@ -676,6 +717,8 @@ static MlirStringRef toMlirStringRef(const std::string &s) {
class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirOpaqueTypeGetTypeID;
static constexpr const char *pyClassName = "OpaqueType";
using PyConcreteType::PyConcreteType;
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 90ab847606ee0..1925478c66d41 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -22,6 +22,8 @@ using namespace mlir;
// Integer types.
//===----------------------------------------------------------------------===//
+MlirTypeID mlirIntegerTypeGetTypeID() { return wrap(IntegerType::getTypeID()); }
+
bool mlirTypeIsAInteger(MlirType type) {
return llvm::isa<IntegerType>(unwrap(type));
}
@@ -58,6 +60,8 @@ bool mlirIntegerTypeIsUnsigned(MlirType type) {
// Index type.
//===----------------------------------------------------------------------===//
+MlirTypeID mlirIndexTypeGetTypeID() { return wrap(IndexType::getTypeID()); }
+
bool mlirTypeIsAIndex(MlirType type) {
return llvm::isa<IndexType>(unwrap(type));
}
@@ -70,6 +74,10 @@ MlirType mlirIndexTypeGet(MlirContext ctx) {
// Floating-point types.
//===----------------------------------------------------------------------===//
+MlirTypeID mlirFloat8E5M2TypeGetTypeID() {
+ return wrap(Float8E5M2Type::getTypeID());
+}
+
bool mlirTypeIsAFloat8E5M2(MlirType type) {
return unwrap(type).isFloat8E5M2();
}
@@ -78,6 +86,10 @@ MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) {
return wrap(FloatType::getFloat8E5M2(unwrap(ctx)));
}
+MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() {
+ return wrap(Float8E4M3FNType::getTypeID());
+}
+
bool mlirTypeIsAFloat8E4M3FN(MlirType type) {
return unwrap(type).isFloat8E4M3FN();
}
@@ -86,6 +98,10 @@ MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) {
return wrap(FloatType::getFloat8E4M3FN(unwrap(ctx)));
}
+MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID() {
+ return wrap(Float8E5M2FNUZType::getTypeID());
+}
+
bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) {
return unwrap(type).isFloat8E5M2FNUZ();
}
@@ -94,6 +110,10 @@ MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) {
return wrap(FloatType::getFloat8E5M2FNUZ(unwrap(ctx)));
}
+MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID() {
+ return wrap(Float8E4M3FNUZType::getTypeID());
+}
+
bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) {
return unwrap(type).isFloat8E4M3FNUZ();
}
@@ -102,6 +122,10 @@ MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) {
return wrap(FloatType::getFloat8E4M3FNUZ(unwrap(ctx)));
}
+MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID() {
+ return wrap(Float8E4M3B11FNUZType::getTypeID());
+}
+
bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) {
return unwrap(type).isFloat8E4M3B11FNUZ();
}
@@ -110,24 +134,34 @@ MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) {
return wrap(FloatType::getFloat8E4M3B11FNUZ(unwrap(ctx)));
}
+MlirTypeID mlirBFloat16TypeGetTypeID() {
+ return wrap(BFloat16Type::getTypeID());
+}
+
bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
MlirType mlirBF16TypeGet(MlirContext ctx) {
return wrap(FloatType::getBF16(unwrap(ctx)));
}
+MlirTypeID mlirFloat16TypeGetTypeID() { return wrap(Float16Type::getTypeID()); }
+
bool mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); }
MlirType mlirF16TypeGet(MlirContext ctx) {
return wrap(FloatType::getF16(unwrap(ctx)));
}
+MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); }
+
bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); }
MlirType mlirF32TypeGet(MlirContext ctx) {
return wrap(FloatType::getF32(unwrap(ctx)));
}
+MlirTypeID mlirFloat64TypeGetTypeID() { return wrap(Float64Type::getTypeID()); }
+
bool mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); }
MlirType mlirF64TypeGet(MlirContext ctx) {
@@ -138,6 +172,8 @@ MlirType mlirF64TypeGet(MlirContext ctx) {
// None type.
//===----------------------------------------------------------------------===//
+MlirTypeID mlirNoneTypeGetTypeID() { return wrap(NoneType::getTypeID()); }
+
bool mlirTypeIsANone(MlirType type) {
return llvm::isa<NoneType>(unwrap(type));
}
@@ -150,6 +186,8 @@ MlirType mlirNoneTypeGet(MlirContext ctx) {
// Complex type.
//===----------------------------------------------------------------------===//
+MlirTypeID mlirComplexTypeGetTypeID() { return wrap(ComplexType::getTypeID()); }
+
bool mlirTypeIsAComplex(MlirType type) {
return llvm::isa<ComplexType>(unwrap(type));
}
@@ -214,6 +252,8 @@ int64_t mlirShapedTypeGetDynamicStrideOrOffset() {
// Vector type.
//===----------------------------------------------------------------------===//
+MlirTypeID mlirVectorTypeGetTypeID() { return wrap(VectorType::getTypeID()); }
+
bool mlirTypeIsAVector(MlirType type) {
return llvm::isa<VectorType>(unwrap(type));
}
@@ -239,10 +279,18 @@ bool mlirTypeIsATensor(MlirType type) {
return llvm::isa<TensorType>(unwrap(type));
}
+MlirTypeID mlirRankedTensorTypeGetTypeID() {
+ return wrap(RankedTensorType::getTypeID());
+}
+
bool mlirTypeIsARankedTensor(MlirType type) {
return llvm::isa<RankedTensorType>(unwrap(type));
}
+MlirTypeID mlirUnrankedTensorTypeGetTypeID() {
+ return wrap(UnrankedTensorType::getTypeID());
+}
+
bool mlirTypeIsAUnrankedTensor(MlirType type) {
return llvm::isa<UnrankedTensorType>(unwrap(type));
}
@@ -280,6 +328,8 @@ MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc,
// Ranked / Unranked MemRef type.
//===----------------------------------------------------------------------===//
+MlirTypeID mlirMemRefTypeGetTypeID() { return wrap(MemRefType::getTypeID()); }
+
bool mlirTypeIsAMemRef(MlirType type) {
return llvm::isa<MemRefType>(unwrap(type));
}
@@ -337,6 +387,10 @@ MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
return wrap(llvm::cast<MemRefType>(unwrap(type)).getMemorySpace());
}
+MlirTypeID mlirUnrankedMemRefTypeGetTypeID() {
+ return wrap(UnrankedMemRefType::getTypeID());
+}
+
bool mlirTypeIsAUnrankedMemRef(MlirType type) {
return llvm::isa<UnrankedMemRefType>(unwrap(type));
}
@@ -362,6 +416,8 @@ MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type) {
// Tuple type.
//===----------------------------------------------------------------------===//
+MlirTypeID mlirTupleTypeGetTypeID() { return wrap(TupleType::getTypeID()); }
+
bool mlirTypeIsATuple(MlirType type) {
return llvm::isa<TupleType>(unwrap(type));
}
@@ -386,6 +442,10 @@ MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) {
// Function type.
//===----------------------------------------------------------------------===//
+MlirTypeID mlirFunctionTypeGetTypeID() {
+ return wrap(FunctionType::getTypeID());
+}
+
bool mlirTypeIsAFunction(MlirType type) {
return llvm::isa<FunctionType>(unwrap(type));
}
@@ -424,6 +484,8 @@ MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos) {
// Opaque type.
//===----------------------------------------------------------------------===//
+MlirTypeID mlirOpaqueTypeGetTypeID() { return wrap(OpaqueType::getTypeID()); }
+
bool mlirTypeIsAOpaque(MlirType type) {
return llvm::isa<OpaqueType>(unwrap(type));
}
diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
index 5d42ddc47a242..ca0d479f1f5fc 100644
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._python_test_ops_gen import *
-from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue
+from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestTensorType
def register_python_test_dialect(context, load=True):
from .._mlir_libs import _mlirPythonTest
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 6cde96e1da10d..2ca79b29f567c 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -299,6 +299,14 @@ def testCustomType():
# The following cast must not assert.
b = test.TestType(a)
+ # Instance custom types should have typeids
+ assert isinstance(b.typeid, TypeID)
+ # Subclasses of ir.Type should not have a static_typeid
+ # CHECK: 'TestType' object has no attribute 'static_typeid'
+ try:
+ b.static_typeid
+ except AttributeError as e:
+ print(e)
i8 = IntegerType.get_signless(8)
try:
@@ -353,6 +361,12 @@ def __str__(self):
# CHECK: False
print(tt.is_null())
+ # Classes of custom types that inherit from concrete types should have
+ # static_typeid
+ assert isinstance(test.TestTensorType.static_typeid, TypeID)
+ # And it should be equal to the in-tree concrete type
+ assert test.TestTensorType.static_typeid == t.type.typeid
+
# CHECK-LABEL: TEST: inferReturnTypeComponents
@run
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index e383a78f40b8a..19e21fff8dba0 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -3,6 +3,7 @@
import gc
from mlir.ir import *
+
def run(f):
print("\nTEST:", f.__name__)
f()
@@ -76,6 +77,7 @@ def testTypeHash():
# CHECK: len(s): 2
print("len(s): ", len(s))
+
# CHECK-LABEL: TEST: testTypeCast
@run
def testTypeCast():
@@ -182,6 +184,7 @@ def testIntegerType():
# CHECK: unsigned: ui64
print("unsigned:", IntegerType.get_unsigned(64))
+
# CHECK-LABEL: TEST: testIndexType
@run
def testIndexType():
@@ -259,7 +262,8 @@ def testConcreteShapedType():
# CHECK: rank: 2
print("rank:", vector.rank)
# CHECK: whether the shaped type has a static shape: True
- print("whether the shaped type has a static shape:", vector.has_static_shape)
+ print("whether the shaped type has a static shape:",
+ vector.has_static_shape)
# CHECK: whether the dim-th dimension is dynamic: False
print("whether the dim-th dimension is dynamic:", vector.is_dynamic_dim(0))
# CHECK: dim size: 3
@@ -311,8 +315,7 @@ def testRankedTensorType():
shape = [2, 3]
loc = Location.unknown()
# CHECK: ranked tensor type: tensor<2x3xf32>
- print("ranked tensor type:",
- RankedTensorType.get(shape, f32))
+ print("ranked tensor type:", RankedTensorType.get(shape, f32))
none = NoneType.get()
try:
@@ -477,8 +480,7 @@ def testTupleType():
@run
def testFunctionType():
with Context() as ctx:
- input_types = [IntegerType.get_signless(32),
- IntegerType.get_signless(16)]
+ input_types = [IntegerType.get_signless(32), IntegerType.get_signless(16)]
result_types = [IndexType.get()]
func = FunctionType.get(input_types, result_types)
# CHECK: INPUTS: [Type(i32), Type(i16)]
@@ -509,3 +511,91 @@ def testShapedTypeConstants():
print(type(ShapedType.get_dynamic_size()))
# CHECK: <class 'int'>
print(type(ShapedType.get_dynamic_stride_or_offset()))
+
+
+# CHECK-LABEL: TEST: testTypeIDs
+@run
+def testTypeIDs():
+ with Context(), Location.unknown():
+ f32 = F32Type.get()
+
+ types = [
+ (IntegerType, IntegerType.get_signless(16)),
+ (IndexType, IndexType.get()),
+ (Float8E4M3FNType, Float8E4M3FNType.get()),
+ (Float8E5M2Type, Float8E5M2Type.get()),
+ (Float8E4M3FNUZType, Float8E4M3FNUZType.get()),
+ (Float8E4M3B11FNUZType, Float8E4M3B11FNUZType.get()),
+ (Float8E5M2FNUZType, Float8E5M2FNUZType.get()),
+ (BF16Type, BF16Type.get()),
+ (F16Type, F16Type.get()),
+ (F32Type, F32Type.get()),
+ (F64Type, F64Type.get()),
+ (NoneType, NoneType.get()),
+ (ComplexType, ComplexType.get(f32)),
+ (VectorType, VectorType.get([2, 3], f32)),
+ (RankedTensorType, RankedTensorType.get([2, 3], f32)),
+ (UnrankedTensorType, UnrankedTensorType.get(f32)),
+ (MemRefType, MemRefType.get([2, 3], f32)),
+ (UnrankedMemRefType, UnrankedMemRefType.get(f32, Attribute.parse("2"))),
+ (TupleType, TupleType.get_tuple([f32])),
+ (FunctionType, FunctionType.get([], [])),
+ (OpaqueType, OpaqueType.get("tensor", "bob")),
+ ]
+
+ # CHECK: IntegerType(i16)
+ # CHECK: IndexType(index)
+ # CHECK: Float8E4M3FNType(f8E4M3FN)
+ # CHECK: Float8E5M2Type(f8E5M2)
+ # CHECK: Float8E4M3FNUZType(f8E4M3FNUZ)
+ # CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ)
+ # CHECK: Float8E5M2FNUZType(f8E5M2FNUZ)
+ # CHECK: BF16Type(bf16)
+ # CHECK: F16Type(f16)
+ # CHECK: F32Type(f32)
+ # CHECK: F64Type(f64)
+ # CHECK: NoneType(none)
+ # CHECK: ComplexType(complex<f32>)
+ # CHECK: VectorType(vector<2x3xf32>)
+ # CHECK: RankedTensorType(tensor<2x3xf32>)
+ # CHECK: UnrankedTensorType(tensor<*xf32>)
+ # CHECK: MemRefType(memref<2x3xf32>)
+ # CHECK: UnrankedMemRefType(memref<*xf32, 2>)
+ # CHECK: TupleType(tuple<f32>)
+ # CHECK: FunctionType(() -> ())
+ # CHECK: OpaqueType(!tensor.bob)
+ for _, t in types:
+ print(repr(t))
+
+ # Test getTypeIdFunction agrees with
+ # mlirTypeGetTypeID(self) for an instance.
+ # CHECK: all equal
+ for t1, t2 in types:
+ tid1, tid2 = t1.static_typeid, Type(t2).typeid
+ assert tid1 == tid2 and hash(tid1) == hash(
+ tid2), f"expected hash and value equality {t1} {t2}"
+ else:
+ print("all equal")
+
+ # Test that storing PyTypeID in python dicts
+ # works as expected.
+ typeid_dict = dict(types)
+ assert len(typeid_dict)
+
+ # CHECK: all equal
+ for t1, t2 in typeid_dict.items():
+ assert t1.static_typeid == t2.typeid and hash(
+ t1.static_typeid) == hash(
+ t2.typeid), f"expected hash and value equality {t1} {t2}"
+ else:
+ print("all equal")
+
+ # CHECK: ShapedType has no typeid.
+ try:
+ print(ShapedType.static_typeid)
+ except AttributeError as e:
+ print(e)
+
+ vector_type = Type.parse("vector<2x3xf32>")
+ # CHECK: True
+ print(ShapedType(vector_type).typeid == vector_type.typeid)
diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
index f17f0821599c5..7edeaac86e45c 100644
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "PythonTestCAPI.h"
+#include "mlir-c/BuiltinTypes.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
namespace py = pybind11;
@@ -40,6 +41,9 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
return cls(mlirPythonTestTestTypeGet(ctx));
},
py::arg("cls"), py::arg("context") = py::none());
+ mlir_type_subclass(m, "TestTensorType", mlirTypeIsARankedTensor,
+ py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("RankedTensorType"));
mlir_value_subclass(m, "TestTensorValue",
mlirTypeIsAPythonTestTestTensorValue)
.def("is_null", [](MlirValue &self) { return mlirValueIsNull(self); });
More information about the Mlir-commits
mailing list