[llvm-branch-commits] [mlir] [mlir][Python] move IRTypes and IRAttributes to public headers (PR #173939)
Maksim Levental via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Dec 29 16:59:38 PST 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/173939
>From dc24520fc192e2774509f15093d815910aeec1f4 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Mon, 29 Dec 2025 16:57:05 -0800
Subject: [PATCH] [mlir][Python] move IRTypes and IRAttributes to public
headers
---
mlir/include/mlir/Bindings/Python/IRCore.h | 17 +-
mlir/include/mlir/Bindings/Python/IRTypes.h | 465 ++++-
mlir/lib/Bindings/Python/IRTypes.cpp | 1573 +++++++----------
mlir/python/CMakeLists.txt | 4 +-
.../python/lib/PythonTestModuleNanobind.cpp | 129 +-
5 files changed, 1138 insertions(+), 1050 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 0f402b4ce15ff..340b16bcdf558 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -979,7 +979,8 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteType : public BaseTy {
PyGlobals::get().registerTypeCaster(
DerivedTy::getTypeIdFunction(),
nanobind::cast<nanobind::callable>(nanobind::cpp_function(
- [](PyType pyType) -> DerivedTy { return pyType; })));
+ [](PyType pyType) -> DerivedTy { return pyType; })),
+ /*replace*/ true);
}
DerivedTy::bindDerived(cls);
@@ -1123,7 +1124,8 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteAttribute : public BaseTy {
nanobind::cast<nanobind::callable>(
nanobind::cpp_function([](PyAttribute pyAttribute) -> DerivedTy {
return pyAttribute;
- })));
+ })),
+ /*replace*/ true);
}
DerivedTy::bindDerived(cls);
@@ -1511,6 +1513,8 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteValue : public PyValue {
// and redefine bindDerived.
using ClassTy = nanobind::class_<DerivedTy, PyValue>;
using IsAFunctionTy = bool (*)(MlirValue);
+ using GetTypeIDFunctionTy = MlirTypeID (*)();
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr;
PyConcreteValue() = default;
PyConcreteValue(PyOperationRef operationRef, MlirValue value)
@@ -1553,6 +1557,15 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteValue : public PyValue {
[](DerivedTy &self) -> nanobind::typed<nanobind::object, DerivedTy> {
return self.maybeDownCast();
});
+
+ if (DerivedTy::getTypeIdFunction) {
+ PyGlobals::get().registerValueCaster(
+ DerivedTy::getTypeIdFunction(),
+ nanobind::cast<nanobind::callable>(nanobind::cpp_function(
+ [](PyValue pyValue) -> DerivedTy { return pyValue; })),
+ /*replace*/ true);
+ }
+
DerivedTy::bindDerived(cls);
}
diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h
index 87e0e10764bd8..db478e8d33f37 100644
--- a/mlir/include/mlir/Bindings/Python/IRTypes.h
+++ b/mlir/include/mlir/Bindings/Python/IRTypes.h
@@ -9,13 +9,14 @@
#ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H
#define MLIR_BINDINGS_PYTHON_IRTYPES_H
+#include "mlir-c/BuiltinTypes.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
/// Shaped Type Interface - ShapedType
-class MLIR_PYTHON_API_EXPORTED PyShapedType
+class MLIR_PYTHON_API_EXPORTED PyShapedType
: public PyConcreteType<PyShapedType> {
public:
static const IsAFunctionTy isaFunction;
@@ -27,6 +28,468 @@ class MLIR_PYTHON_API_EXPORTED PyShapedType
private:
void requireHasRank();
};
+
+/// Checks whether the given type is an integer, BF16, F16, F32, or F64 type.
+inline int mlirTypeIsAIntegerOrFloat(MlirType type) {
+ return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
+ mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
+}
+
+class MLIR_PYTHON_API_EXPORTED PyIntegerType
+ : public PyConcreteType<PyIntegerType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirIntegerTypeGetTypeID;
+ static constexpr const char *pyClassName = "IntegerType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Index Type subclass - IndexType.
+class MLIR_PYTHON_API_EXPORTED PyIndexType
+ : public PyConcreteType<PyIndexType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirIndexTypeGetTypeID;
+ static constexpr const char *pyClassName = "IndexType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+class MLIR_PYTHON_API_EXPORTED PyFloatType
+ : public PyConcreteType<PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
+ static constexpr const char *pyClassName = "FloatType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float4E2M1FNType.
+class MLIR_PYTHON_API_EXPORTED PyFloat4E2M1FNType
+ : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat4E2M1FNTypeGetTypeID;
+ static constexpr const char *pyClassName = "Float4E2M1FNType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float6E2M3FNType.
+class MLIR_PYTHON_API_EXPORTED PyFloat6E2M3FNType
+ : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat6E2M3FNTypeGetTypeID;
+ static constexpr const char *pyClassName = "Float6E2M3FNType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float6E3M2FNType.
+class MLIR_PYTHON_API_EXPORTED PyFloat6E3M2FNType
+ : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat6E3M2FNTypeGetTypeID;
+ static constexpr const char *pyClassName = "Float6E3M2FNType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E4M3FNType.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3FNType
+ : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat8E4M3FNTypeGetTypeID;
+ static constexpr const char *pyClassName = "Float8E4M3FNType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E5M2Type.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E5M2Type
+ : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat8E5M2TypeGetTypeID;
+ static constexpr const char *pyClassName = "Float8E5M2Type";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E4M3Type.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3Type
+ : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat8E4M3TypeGetTypeID;
+ static constexpr const char *pyClassName = "Float8E4M3Type";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E4M3FNUZType.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3FNUZType
+ : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat8E4M3FNUZTypeGetTypeID;
+ static constexpr const char *pyClassName = "Float8E4M3FNUZType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E4M3B11FNUZType.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3B11FNUZType
+ : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat8E4M3B11FNUZTypeGetTypeID;
+ static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E5M2FNUZType.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E5M2FNUZType
+ : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat8E5M2FNUZTypeGetTypeID;
+ static constexpr const char *pyClassName = "Float8E5M2FNUZType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E3M4Type.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E3M4Type
+ : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat8E3M4TypeGetTypeID;
+ static constexpr const char *pyClassName = "Float8E3M4Type";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E8M0FNUType.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E8M0FNUType
+ : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat8E8M0FNUTypeGetTypeID;
+ static constexpr const char *pyClassName = "Float8E8M0FNUType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - BF16Type.
+class MLIR_PYTHON_API_EXPORTED PyBF16Type
+ : public PyConcreteType<PyBF16Type, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirBFloat16TypeGetTypeID;
+ static constexpr const char *pyClassName = "BF16Type";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - F16Type.
+class MLIR_PYTHON_API_EXPORTED PyF16Type
+ : public PyConcreteType<PyF16Type, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat16TypeGetTypeID;
+ static constexpr const char *pyClassName = "F16Type";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - FloatTF32Type.
+class MLIR_PYTHON_API_EXPORTED PyTF32Type
+ : public PyConcreteType<PyTF32Type, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloatTF32TypeGetTypeID;
+ static constexpr const char *pyClassName = "FloatTF32Type";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - F32Type.
+class MLIR_PYTHON_API_EXPORTED PyF32Type
+ : public PyConcreteType<PyF32Type, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat32TypeGetTypeID;
+ static constexpr const char *pyClassName = "F32Type";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - F64Type.
+class MLIR_PYTHON_API_EXPORTED PyF64Type
+ : public PyConcreteType<PyF64Type, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat64TypeGetTypeID;
+ static constexpr const char *pyClassName = "F64Type";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// None Type subclass - NoneType.
+class MLIR_PYTHON_API_EXPORTED PyNoneType : public PyConcreteType<PyNoneType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirNoneTypeGetTypeID;
+ static constexpr const char *pyClassName = "NoneType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Complex Type subclass - ComplexType.
+class MLIR_PYTHON_API_EXPORTED PyComplexType
+ : public PyConcreteType<PyComplexType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirComplexTypeGetTypeID;
+ static constexpr const char *pyClassName = "ComplexType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Vector Type subclass - VectorType.
+class MLIR_PYTHON_API_EXPORTED PyVectorType
+ : public PyConcreteType<PyVectorType, PyShapedType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirVectorTypeGetTypeID;
+ static constexpr const char *pyClassName = "VectorType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+
+private:
+ static PyVectorType
+ getChecked(std::vector<int64_t> shape, PyType &elementType,
+ std::optional<nanobind::list> scalable,
+ std::optional<std::vector<int64_t>> scalableDims,
+ DefaultingPyLocation loc) {
+ if (scalable && scalableDims) {
+ throw nanobind::value_error("'scalable' and 'scalable_dims' kwargs "
+ "are mutually exclusive.");
+ }
+
+ PyMlirContext::ErrorCapture errors(loc->getContext());
+ MlirType type;
+ if (scalable) {
+ if (scalable->size() != shape.size())
+ throw nanobind::value_error("Expected len(scalable) == len(shape).");
+
+ SmallVector<bool> scalableDimFlags = llvm::to_vector(
+ llvm::map_range(*scalable, [](const nanobind::handle &h) {
+ return nanobind::cast<bool>(h);
+ }));
+ type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
+ scalableDimFlags.data(),
+ elementType);
+ } else if (scalableDims) {
+ SmallVector<bool> scalableDimFlags(shape.size(), false);
+ for (int64_t dim : *scalableDims) {
+ if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
+ throw nanobind::value_error(
+ "Scalable dimension index out of bounds.");
+ scalableDimFlags[dim] = true;
+ }
+ type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
+ scalableDimFlags.data(),
+ elementType);
+ } else {
+ type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
+ elementType);
+ }
+ if (mlirTypeIsNull(type))
+ throw MLIRError("Invalid type", errors.take());
+ return PyVectorType(elementType.getContext(), type);
+ }
+
+ static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
+ std::optional<nanobind::list> scalable,
+ std::optional<std::vector<int64_t>> scalableDims,
+ DefaultingPyMlirContext context) {
+ if (scalable && scalableDims) {
+ throw nanobind::value_error("'scalable' and 'scalable_dims' kwargs "
+ "are mutually exclusive.");
+ }
+
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirType type;
+ if (scalable) {
+ if (scalable->size() != shape.size())
+ throw nanobind::value_error("Expected len(scalable) == len(shape).");
+
+ SmallVector<bool> scalableDimFlags = llvm::to_vector(
+ llvm::map_range(*scalable, [](const nanobind::handle &h) {
+ return nanobind::cast<bool>(h);
+ }));
+ type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
+ scalableDimFlags.data(), elementType);
+ } else if (scalableDims) {
+ SmallVector<bool> scalableDimFlags(shape.size(), false);
+ for (int64_t dim : *scalableDims) {
+ if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
+ throw nanobind::value_error(
+ "Scalable dimension index out of bounds.");
+ scalableDimFlags[dim] = true;
+ }
+ type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
+ scalableDimFlags.data(), elementType);
+ } else {
+ type = mlirVectorTypeGet(shape.size(), shape.data(), elementType);
+ }
+ if (mlirTypeIsNull(type))
+ throw MLIRError("Invalid type", errors.take());
+ return PyVectorType(elementType.getContext(), type);
+ }
+};
+
+/// Ranked Tensor Type subclass - RankedTensorType.
+class MLIR_PYTHON_API_EXPORTED PyRankedTensorType
+ : public PyConcreteType<PyRankedTensorType, PyShapedType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirRankedTensorTypeGetTypeID;
+ static constexpr const char *pyClassName = "RankedTensorType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Unranked Tensor Type subclass - UnrankedTensorType.
+class MLIR_PYTHON_API_EXPORTED PyUnrankedTensorType
+ : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirUnrankedTensorTypeGetTypeID;
+ static constexpr const char *pyClassName = "UnrankedTensorType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Ranked MemRef Type subclass - MemRefType.
+class MLIR_PYTHON_API_EXPORTED PyMemRefType
+ : public PyConcreteType<PyMemRefType, PyShapedType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirMemRefTypeGetTypeID;
+ static constexpr const char *pyClassName = "MemRefType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Unranked MemRef Type subclass - UnrankedMemRefType.
+class MLIR_PYTHON_API_EXPORTED PyUnrankedMemRefType
+ : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirUnrankedMemRefTypeGetTypeID;
+ static constexpr const char *pyClassName = "UnrankedMemRefType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Tuple Type subclass - TupleType.
+class MLIR_PYTHON_API_EXPORTED PyTupleType
+ : public PyConcreteType<PyTupleType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirTupleTypeGetTypeID;
+ static constexpr const char *pyClassName = "TupleType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Function Type subclass - FunctionType.
+class MLIR_PYTHON_API_EXPORTED PyFunctionType
+ : public PyConcreteType<PyFunctionType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFunctionTypeGetTypeID;
+ static constexpr const char *pyClassName = "FunctionType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Opaque Type subclass - OpaqueType.
+class MLIR_PYTHON_API_EXPORTED PyOpaqueType
+ : public PyConcreteType<PyOpaqueType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirOpaqueTypeGetTypeID;
+ static constexpr const char *pyClassName = "OpaqueType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 7350046f428c7..951486b818a4e 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -28,492 +28,6 @@ using llvm::Twine;
namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
-
-/// Checks whether the given type is an integer or float type.
-static int mlirTypeIsAIntegerOrFloat(MlirType type) {
- return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
- mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
-}
-
-class PyIntegerType : public PyConcreteType<PyIntegerType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirIntegerTypeGetTypeID;
- static constexpr const char *pyClassName = "IntegerType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get_signless",
- [](unsigned width, DefaultingPyMlirContext context) {
- MlirType t = mlirIntegerTypeGet(context->get(), width);
- return PyIntegerType(context->getRef(), t);
- },
- nb::arg("width"), nb::arg("context") = nb::none(),
- "Create a signless integer type");
- c.def_static(
- "get_signed",
- [](unsigned width, DefaultingPyMlirContext context) {
- MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
- return PyIntegerType(context->getRef(), t);
- },
- nb::arg("width"), nb::arg("context") = nb::none(),
- "Create a signed integer type");
- c.def_static(
- "get_unsigned",
- [](unsigned width, DefaultingPyMlirContext context) {
- MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
- return PyIntegerType(context->getRef(), t);
- },
- nb::arg("width"), nb::arg("context") = nb::none(),
- "Create an unsigned integer type");
- c.def_prop_ro(
- "width",
- [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
- "Returns the width of the integer type");
- c.def_prop_ro(
- "is_signless",
- [](PyIntegerType &self) -> bool {
- return mlirIntegerTypeIsSignless(self);
- },
- "Returns whether this is a signless integer");
- c.def_prop_ro(
- "is_signed",
- [](PyIntegerType &self) -> bool {
- return mlirIntegerTypeIsSigned(self);
- },
- "Returns whether this is a signed integer");
- c.def_prop_ro(
- "is_unsigned",
- [](PyIntegerType &self) -> bool {
- return mlirIntegerTypeIsUnsigned(self);
- },
- "Returns whether this is an unsigned integer");
- }
-};
-
-/// Index Type subclass - IndexType.
-class PyIndexType : public PyConcreteType<PyIndexType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirIndexTypeGetTypeID;
- static constexpr const char *pyClassName = "IndexType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirIndexTypeGet(context->get());
- return PyIndexType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a index type.");
- }
-};
-
-class PyFloatType : public PyConcreteType<PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
- static constexpr const char *pyClassName = "FloatType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
- "Returns the width of the floating-point type");
- }
-};
-
-/// Floating Point Type subclass - Float4E2M1FNType.
-class PyFloat4E2M1FNType
- : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat4E2M1FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float4E2M1FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat4E2M1FNTypeGet(context->get());
- return PyFloat4E2M1FNType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float4_e2m1fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float6E2M3FNType.
-class PyFloat6E2M3FNType
- : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat6E2M3FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float6E2M3FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
- return PyFloat6E2M3FNType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float6_e2m3fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float6E3M2FNType.
-class PyFloat6E3M2FNType
- : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat6E3M2FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float6E3M2FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
- return PyFloat6E3M2FNType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float6_e3m2fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3FNType.
-class PyFloat8E4M3FNType
- : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
- return PyFloat8E4M3FNType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e4m3fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E5M2Type.
-class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E5M2TypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E5M2Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E5M2TypeGet(context->get());
- return PyFloat8E5M2Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e5m2 type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3Type.
-class PyFloat8E4M3Type : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3TypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3TypeGet(context->get());
- return PyFloat8E4M3Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e4m3 type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3FNUZ.
-class PyFloat8E4M3FNUZType
- : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3FNUZTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3FNUZType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
- return PyFloat8E4M3FNUZType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e4m3fnuz type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3B11FNUZ.
-class PyFloat8E4M3B11FNUZType
- : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3B11FNUZTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
- return PyFloat8E4M3B11FNUZType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e4m3b11fnuz type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E5M2FNUZ.
-class PyFloat8E5M2FNUZType
- : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E5M2FNUZTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E5M2FNUZType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
- return PyFloat8E5M2FNUZType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e5m2fnuz type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E3M4Type.
-class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E3M4TypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E3M4Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E3M4TypeGet(context->get());
- return PyFloat8E3M4Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e3m4 type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E8M0FNUType.
-class PyFloat8E8M0FNUType
- : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E8M0FNUTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E8M0FNUType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
- return PyFloat8E8M0FNUType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e8m0fnu type.");
- }
-};
-
-/// Floating Point Type subclass - BF16Type.
-class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirBFloat16TypeGetTypeID;
- static constexpr const char *pyClassName = "BF16Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirBF16TypeGet(context->get());
- return PyBF16Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a bf16 type.");
- }
-};
-
-/// Floating Point Type subclass - F16Type.
-class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat16TypeGetTypeID;
- static constexpr const char *pyClassName = "F16Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirF16TypeGet(context->get());
- return PyF16Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a f16 type.");
- }
-};
-
-/// Floating Point Type subclass - TF32Type.
-class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloatTF32TypeGetTypeID;
- static constexpr const char *pyClassName = "FloatTF32Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirTF32TypeGet(context->get());
- return PyTF32Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a tf32 type.");
- }
-};
-
-/// Floating Point Type subclass - F32Type.
-class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat32TypeGetTypeID;
- static constexpr const char *pyClassName = "F32Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirF32TypeGet(context->get());
- return PyF32Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a f32 type.");
- }
-};
-
-/// Floating Point Type subclass - F64Type.
-class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat64TypeGetTypeID;
- static constexpr const char *pyClassName = "F64Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirF64TypeGet(context->get());
- return PyF64Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a f64 type.");
- }
-};
-
-/// None Type subclass - NoneType.
-class PyNoneType : public PyConcreteType<PyNoneType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirNoneTypeGetTypeID;
- static constexpr const char *pyClassName = "NoneType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirNoneTypeGet(context->get());
- return PyNoneType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a none type.");
- }
-};
-
-/// Complex Type subclass - ComplexType.
-class PyComplexType : public PyConcreteType<PyComplexType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirComplexTypeGetTypeID;
- static constexpr const char *pyClassName = "ComplexType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyType &elementType) {
- // The element must be a floating point or integer scalar type.
- if (mlirTypeIsAIntegerOrFloat(elementType)) {
- MlirType t = mlirComplexTypeGet(elementType);
- return PyComplexType(elementType.getContext(), t);
- }
- throw nb::value_error(
- (Twine("invalid '") +
- nb::cast<std::string>(nb::repr(nb::cast(elementType))) +
- "' and expected floating point or integer type.")
- .str()
- .c_str());
- },
- "Create a complex type");
- c.def_prop_ro(
- "element_type",
- [](PyComplexType &self) -> nb::typed<nb::object, PyType> {
- return PyType(self.getContext(), mlirComplexTypeGetElementType(self))
- .maybeDownCast();
- },
- "Returns element type.");
- }
-};
-
-} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
-} // namespace python
-} // namespace mlir
-
// Shaped Type Interface - ShapedType
void PyShapedType::bindDerived(ClassTy &c) {
c.def_prop_ro(
@@ -627,521 +141,632 @@ void PyShapedType::requireHasRank() {
}
}
-const PyShapedType::IsAFunctionTy PyShapedType::isaFunction = mlirTypeIsAShaped;
+void PyIntegerType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get_signless",
+ [](unsigned width, DefaultingPyMlirContext context) {
+ MlirType t = mlirIntegerTypeGet(context->get(), width);
+ return PyIntegerType(context->getRef(), t);
+ },
+ nanobind::arg("width"), nanobind::arg("context") = nanobind::none(),
+ "Create a signless integer type");
+ c.def_static(
+ "get_signed",
+ [](unsigned width, DefaultingPyMlirContext context) {
+ MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
+ return PyIntegerType(context->getRef(), t);
+ },
+ nanobind::arg("width"), nanobind::arg("context") = nanobind::none(),
+ "Create a signed integer type");
+ c.def_static(
+ "get_unsigned",
+ [](unsigned width, DefaultingPyMlirContext context) {
+ MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
+ return PyIntegerType(context->getRef(), t);
+ },
+ nanobind::arg("width"), nanobind::arg("context") = nanobind::none(),
+ "Create an unsigned integer type");
+ c.def_prop_ro(
+ "width",
+ [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
+ "Returns the width of the integer type");
+ c.def_prop_ro(
+ "is_signless",
+ [](PyIntegerType &self) -> bool {
+ return mlirIntegerTypeIsSignless(self);
+ },
+ "Returns whether this is a signless integer");
+ c.def_prop_ro(
+ "is_signed",
+ [](PyIntegerType &self) -> bool { return mlirIntegerTypeIsSigned(self); },
+ "Returns whether this is a signed integer");
+ c.def_prop_ro(
+ "is_unsigned",
+ [](PyIntegerType &self) -> bool {
+ return mlirIntegerTypeIsUnsigned(self);
+ },
+ "Returns whether this is an unsigned integer");
+}
-namespace mlir {
-namespace python {
-namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+void PyIndexType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirIndexTypeGet(context->get());
+ return PyIndexType(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(), "Create a index type.");
+}
-/// Vector Type subclass - VectorType.
-class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirVectorTypeGetTypeID;
- static constexpr const char *pyClassName = "VectorType";
- using PyConcreteType::PyConcreteType;
+void PyFloatType::bindDerived(ClassTy &c) {
+ c.def_prop_ro(
+ "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
+ "Returns the width of the floating-point type");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static("get", &PyVectorType::getChecked, nb::arg("shape"),
- nb::arg("element_type"), nb::kw_only(),
- nb::arg("scalable") = nb::none(),
- nb::arg("scalable_dims") = nb::none(),
- nb::arg("loc") = nb::none(), "Create a vector type")
- .def_static("get_unchecked", &PyVectorType::get, nb::arg("shape"),
- nb::arg("element_type"), nb::kw_only(),
- nb::arg("scalable") = nb::none(),
- nb::arg("scalable_dims") = nb::none(),
- nb::arg("context") = nb::none(), "Create a vector type")
- .def_prop_ro(
- "scalable",
- [](MlirType self) { return mlirVectorTypeIsScalable(self); })
- .def_prop_ro("scalable_dims", [](MlirType self) {
- std::vector<bool> scalableDims;
- size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
- scalableDims.reserve(rank);
- for (size_t i = 0; i < rank; ++i)
- scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
- return scalableDims;
- });
- }
+void PyFloat4E2M1FNType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat4E2M1FNTypeGet(context->get());
+ return PyFloat4E2M1FNType(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(),
+ "Create a float4_e2m1fn type.");
+}
-private:
- static PyVectorType
- getChecked(std::vector<int64_t> shape, PyType &elementType,
- std::optional<nb::list> scalable,
- std::optional<std::vector<int64_t>> scalableDims,
- DefaultingPyLocation loc) {
- if (scalable && scalableDims) {
- throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
- "are mutually exclusive.");
- }
+void PyFloat6E2M3FNType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
+ return PyFloat6E2M3FNType(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(),
+ "Create a float6_e2m3fn type.");
+}
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirType type;
- if (scalable) {
- if (scalable->size() != shape.size())
- throw nb::value_error("Expected len(scalable) == len(shape).");
+void PyFloat6E3M2FNType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
+ return PyFloat6E3M2FNType(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(),
+ "Create a float6_e3m2fn type.");
+}
- SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
- *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
- type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
- scalableDimFlags.data(),
- elementType);
- } else if (scalableDims) {
- SmallVector<bool> scalableDimFlags(shape.size(), false);
- for (int64_t dim : *scalableDims) {
- if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
- throw nb::value_error("Scalable dimension index out of bounds.");
- scalableDimFlags[dim] = true;
- }
- type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
- scalableDimFlags.data(),
- elementType);
- } else {
- type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
- elementType);
- }
- if (mlirTypeIsNull(type))
- throw MLIRError("Invalid type", errors.take());
- return PyVectorType(elementType.getContext(), type);
- }
+void PyFloat8E4M3FNType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
+ return PyFloat8E4M3FNType(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(),
+ "Create a float8_e4m3fn type.");
+}
- static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
- std::optional<nb::list> scalable,
- std::optional<std::vector<int64_t>> scalableDims,
- DefaultingPyMlirContext context) {
- if (scalable && scalableDims) {
- throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
- "are mutually exclusive.");
- }
+void PyFloat8E5M2Type::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E5M2TypeGet(context->get());
+ return PyFloat8E5M2Type(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(),
+ "Create a float8_e5m2 type.");
+}
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirType type;
- if (scalable) {
- if (scalable->size() != shape.size())
- throw nb::value_error("Expected len(scalable) == len(shape).");
+void PyFloat8E4M3Type::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E4M3TypeGet(context->get());
+ return PyFloat8E4M3Type(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(),
+ "Create a float8_e4m3 type.");
+}
- SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
- *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
- type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
- scalableDimFlags.data(), elementType);
- } else if (scalableDims) {
- SmallVector<bool> scalableDimFlags(shape.size(), false);
- for (int64_t dim : *scalableDims) {
- if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
- throw nb::value_error("Scalable dimension index out of bounds.");
- scalableDimFlags[dim] = true;
- }
- type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
- scalableDimFlags.data(), elementType);
- } else {
- type = mlirVectorTypeGet(shape.size(), shape.data(), elementType);
- }
- if (mlirTypeIsNull(type))
- throw MLIRError("Invalid type", errors.take());
- return PyVectorType(elementType.getContext(), type);
- }
-};
+void PyFloat8E4M3FNUZType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
+ return PyFloat8E4M3FNUZType(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(),
+ "Create a float8_e4m3fnuz type.");
+}
-/// Ranked Tensor Type subclass - RankedTensorType.
-class PyRankedTensorType
- : public PyConcreteType<PyRankedTensorType, PyShapedType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirRankedTensorTypeGetTypeID;
- static constexpr const char *pyClassName = "RankedTensorType";
- using PyConcreteType::PyConcreteType;
+void PyFloat8E4M3B11FNUZType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
+ return PyFloat8E4M3B11FNUZType(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(),
+ "Create a float8_e4m3b11fnuz type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](std::vector<int64_t> shape, PyType &elementType,
- std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirType t = mlirRankedTensorTypeGetChecked(
- loc, shape.size(), shape.data(), elementType,
- encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyRankedTensorType(elementType.getContext(), t);
- },
- nb::arg("shape"), nb::arg("element_type"),
- nb::arg("encoding") = nb::none(), nb::arg("loc") = nb::none(),
- "Create a ranked tensor type");
- c.def_static(
- "get_unchecked",
- [](std::vector<int64_t> shape, PyType &elementType,
- std::optional<PyAttribute> &encodingAttr,
- DefaultingPyMlirContext context) {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirType t = mlirRankedTensorTypeGet(
- shape.size(), shape.data(), elementType,
- encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyRankedTensorType(elementType.getContext(), t);
- },
- nb::arg("shape"), nb::arg("element_type"),
- nb::arg("encoding") = nb::none(), nb::arg("context") = nb::none(),
- "Create a ranked tensor type");
- c.def_prop_ro(
- "encoding",
- [](PyRankedTensorType &self)
- -> std::optional<nb::typed<nb::object, PyAttribute>> {
- MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
- if (mlirAttributeIsNull(encoding))
- return std::nullopt;
- return PyAttribute(self.getContext(), encoding).maybeDownCast();
- });
- }
-};
+void PyFloat8E5M2FNUZType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
+ return PyFloat8E5M2FNUZType(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(),
+ "Create a float8_e5m2fnuz type.");
+}
-/// Unranked Tensor Type subclass - UnrankedTensorType.
-class PyUnrankedTensorType
- : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirUnrankedTensorTypeGetTypeID;
- static constexpr const char *pyClassName = "UnrankedTensorType";
- using PyConcreteType::PyConcreteType;
+void PyFloat8E3M4Type::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E3M4TypeGet(context->get());
+ return PyFloat8E3M4Type(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(),
+ "Create a float8_e3m4 type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyType &elementType, DefaultingPyLocation loc) {
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyUnrankedTensorType(elementType.getContext(), t);
- },
- nb::arg("element_type"), nb::arg("loc") = nb::none(),
- "Create a unranked tensor type");
- c.def_static(
- "get_unchecked",
- [](PyType &elementType, DefaultingPyMlirContext context) {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirType t = mlirUnrankedTensorTypeGet(elementType);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyUnrankedTensorType(elementType.getContext(), t);
- },
- nb::arg("element_type"), nb::arg("context") = nb::none(),
- "Create a unranked tensor type");
- }
-};
+void PyFloat8E8M0FNUType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
+ return PyFloat8E8M0FNUType(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(),
+ "Create a float8_e8m0fnu type.");
+}
-/// Ranked MemRef Type subclass - MemRefType.
-class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirMemRefTypeGetTypeID;
- static constexpr const char *pyClassName = "MemRefType";
- using PyConcreteType::PyConcreteType;
+void PyBF16Type::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirBF16TypeGet(context->get());
+ return PyBF16Type(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(), "Create a bf16 type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](std::vector<int64_t> shape, PyType &elementType,
- PyAttribute *layout, PyAttribute *memorySpace,
- DefaultingPyLocation loc) {
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull();
- MlirAttribute memSpaceAttr =
- memorySpace ? *memorySpace : mlirAttributeGetNull();
- MlirType t =
- mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
- shape.data(), layoutAttr, memSpaceAttr);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyMemRefType(elementType.getContext(), t);
- },
- nb::arg("shape"), nb::arg("element_type"),
- nb::arg("layout") = nb::none(), nb::arg("memory_space") = nb::none(),
- nb::arg("loc") = nb::none(), "Create a memref type")
- .def_static(
- "get_unchecked",
- [](std::vector<int64_t> shape, PyType &elementType,
- PyAttribute *layout, PyAttribute *memorySpace,
- DefaultingPyMlirContext context) {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirAttribute layoutAttr =
- layout ? *layout : mlirAttributeGetNull();
- MlirAttribute memSpaceAttr =
- memorySpace ? *memorySpace : mlirAttributeGetNull();
- MlirType t =
- mlirMemRefTypeGet(elementType, shape.size(), shape.data(),
- layoutAttr, memSpaceAttr);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyMemRefType(elementType.getContext(), t);
- },
- nb::arg("shape"), nb::arg("element_type"),
- nb::arg("layout") = nb::none(),
- nb::arg("memory_space") = nb::none(),
- nb::arg("context") = nb::none(), "Create a memref type")
- .def_prop_ro(
- "layout",
- [](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
- return PyAttribute(self.getContext(),
- mlirMemRefTypeGetLayout(self))
- .maybeDownCast();
- },
- "The layout of the MemRef type.")
- .def(
- "get_strides_and_offset",
- [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
- std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
- int64_t offset;
- if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset(
- self, strides.data(), &offset)))
- throw std::runtime_error(
- "Failed to extract strides and offset from memref.");
- return {strides, offset};
- },
- "The strides and offset of the MemRef type.")
- .def_prop_ro(
- "affine_map",
- [](PyMemRefType &self) -> PyAffineMap {
- MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
- return PyAffineMap(self.getContext(), map);
- },
- "The layout of the MemRef type as an affine map.")
- .def_prop_ro(
- "memory_space",
- [](PyMemRefType &self)
- -> std::optional<nb::typed<nb::object, PyAttribute>> {
- MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
- if (mlirAttributeIsNull(a))
- return std::nullopt;
- return PyAttribute(self.getContext(), a).maybeDownCast();
- },
- "Returns the memory space of the given MemRef type.");
- }
-};
+void PyF16Type::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirF16TypeGet(context->get());
+ return PyF16Type(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(), "Create a f16 type.");
+}
-/// Unranked MemRef Type subclass - UnrankedMemRefType.
-class PyUnrankedMemRefType
- : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirUnrankedMemRefTypeGetTypeID;
- static constexpr const char *pyClassName = "UnrankedMemRefType";
- using PyConcreteType::PyConcreteType;
+void PyTF32Type::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirTF32TypeGet(context->get());
+ return PyTF32Type(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(), "Create a tf32 type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyType &elementType, PyAttribute *memorySpace,
- DefaultingPyLocation loc) {
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirAttribute memSpaceAttr = {};
- if (memorySpace)
- memSpaceAttr = *memorySpace;
+void PyF32Type::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirF32TypeGet(context->get());
+ return PyF32Type(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(), "Create a f32 type.");
+}
- MlirType t =
- mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyUnrankedMemRefType(elementType.getContext(), t);
- },
- nb::arg("element_type"), nb::arg("memory_space").none(),
- nb::arg("loc") = nb::none(), "Create a unranked memref type")
- .def_static(
- "get_unchecked",
- [](PyType &elementType, PyAttribute *memorySpace,
- DefaultingPyMlirContext context) {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirAttribute memSpaceAttr = {};
- if (memorySpace)
- memSpaceAttr = *memorySpace;
+void PyF64Type::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirF64TypeGet(context->get());
+ return PyF64Type(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(), "Create a f64 type.");
+}
- MlirType t = mlirUnrankedMemRefTypeGet(elementType, memSpaceAttr);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyUnrankedMemRefType(elementType.getContext(), t);
- },
- nb::arg("element_type"), nb::arg("memory_space").none(),
- nb::arg("context") = nb::none(), "Create a unranked memref type")
- .def_prop_ro(
- "memory_space",
- [](PyUnrankedMemRefType &self)
- -> std::optional<nb::typed<nb::object, PyAttribute>> {
- MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
- if (mlirAttributeIsNull(a))
- return std::nullopt;
- return PyAttribute(self.getContext(), a).maybeDownCast();
- },
- "Returns the memory space of the given Unranked MemRef type.");
- }
-};
+void PyNoneType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirNoneTypeGet(context->get());
+ return PyNoneType(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(), "Create a none type.");
+}
+
+void PyComplexType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](PyType &elementType) {
+ // The element must be a floating point or integer scalar type.
+ if (mlirTypeIsAIntegerOrFloat(elementType)) {
+ MlirType t = mlirComplexTypeGet(elementType);
+ return PyComplexType(elementType.getContext(), t);
+ }
+ throw nanobind::value_error(
+ (Twine("invalid '") +
+ nanobind::cast<std::string>(
+ nanobind::repr(nanobind::cast(elementType))) +
+ "' and expected floating point or integer type.")
+ .str()
+ .c_str());
+ },
+ "Create a complex type");
+ c.def_prop_ro(
+ "element_type",
+ [](PyComplexType &self) -> nanobind::typed<nanobind::object, PyType> {
+ return PyType(self.getContext(), mlirComplexTypeGetElementType(self))
+ .maybeDownCast();
+ },
+ "Returns element type.");
+}
-/// Tuple Type subclass - TupleType.
-class PyTupleType : public PyConcreteType<PyTupleType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirTupleTypeGetTypeID;
- static constexpr const char *pyClassName = "TupleType";
- using PyConcreteType::PyConcreteType;
+void PyVectorType::bindDerived(ClassTy &c) {
+ c.def_static("get", &PyVectorType::getChecked, nanobind::arg("shape"),
+ nanobind::arg("element_type"), nanobind::kw_only(),
+ nanobind::arg("scalable") = nanobind::none(),
+ nanobind::arg("scalable_dims") = nanobind::none(),
+ nanobind::arg("loc") = nanobind::none(), "Create a vector type")
+ .def_static("get_unchecked", &PyVectorType::get, nanobind::arg("shape"),
+ nanobind::arg("element_type"), nanobind::kw_only(),
+ nanobind::arg("scalable") = nanobind::none(),
+ nanobind::arg("scalable_dims") = nanobind::none(),
+ nanobind::arg("context") = nanobind::none(),
+ "Create a vector type")
+ .def_prop_ro("scalable",
+ [](MlirType self) { return mlirVectorTypeIsScalable(self); })
+ .def_prop_ro("scalable_dims", [](MlirType self) {
+ std::vector<bool> scalableDims;
+ size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
+ scalableDims.reserve(rank);
+ for (size_t i = 0; i < rank; ++i)
+ scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
+ return scalableDims;
+ });
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get_tuple",
- [](const std::vector<PyType> &elements,
- DefaultingPyMlirContext context) {
- std::vector<MlirType> mlirElements;
- mlirElements.reserve(elements.size());
- for (const auto &element : elements)
- mlirElements.push_back(element.get());
- MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
- mlirElements.data());
- return PyTupleType(context->getRef(), t);
- },
- nb::arg("elements"), nb::arg("context") = nb::none(),
- "Create a tuple type");
- c.def_static(
- "get_tuple",
- [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
- MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
- elements.data());
- return PyTupleType(context->getRef(), t);
- },
- nb::arg("elements"), nb::arg("context") = nb::none(),
- // clang-format off
- nb::sig("def get_tuple(elements: Sequence[Type], context: Context | None = None) -> TupleType"),
- // clang-format on
- "Create a tuple type");
- c.def(
- "get_type",
- [](PyTupleType &self, intptr_t pos) -> nb::typed<nb::object, PyType> {
- return PyType(self.getContext(), mlirTupleTypeGetType(self, pos))
- .maybeDownCast();
- },
- nb::arg("pos"), "Returns the pos-th type in the tuple type.");
- c.def_prop_ro(
- "num_types",
- [](PyTupleType &self) -> intptr_t {
- return mlirTupleTypeGetNumTypes(self);
- },
- "Returns the number of types contained in a tuple.");
- }
-};
+void PyRankedTensorType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](std::vector<int64_t> shape, PyType &elementType,
+ std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
+ PyMlirContext::ErrorCapture errors(loc->getContext());
+ MlirType t = mlirRankedTensorTypeGetChecked(
+ loc, shape.size(), shape.data(), elementType,
+ encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
+ return PyRankedTensorType(elementType.getContext(), t);
+ },
+ nanobind::arg("shape"), nanobind::arg("element_type"),
+ nanobind::arg("encoding") = nanobind::none(),
+ nanobind::arg("loc") = nanobind::none(), "Create a ranked tensor type");
+ c.def_static(
+ "get_unchecked",
+ [](std::vector<int64_t> shape, PyType &elementType,
+ std::optional<PyAttribute> &encodingAttr,
+ DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirType t = mlirRankedTensorTypeGet(
+ shape.size(), shape.data(), elementType,
+ encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
+ return PyRankedTensorType(elementType.getContext(), t);
+ },
+ nanobind::arg("shape"), nanobind::arg("element_type"),
+ nanobind::arg("encoding") = nanobind::none(),
+ nanobind::arg("context") = nanobind::none(),
+ "Create a ranked tensor type");
+ c.def_prop_ro(
+ "encoding",
+ [](PyRankedTensorType &self)
+ -> std::optional<nanobind::typed<nanobind::object, PyAttribute>> {
+ MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
+ if (mlirAttributeIsNull(encoding))
+ return std::nullopt;
+ return PyAttribute(self.getContext(), encoding).maybeDownCast();
+ });
+}
-/// Function type.
-class PyFunctionType : public PyConcreteType<PyFunctionType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFunctionTypeGetTypeID;
- static constexpr const char *pyClassName = "FunctionType";
- using PyConcreteType::PyConcreteType;
+void PyUnrankedTensorType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](PyType &elementType, DefaultingPyLocation loc) {
+ PyMlirContext::ErrorCapture errors(loc->getContext());
+ MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
+ return PyUnrankedTensorType(elementType.getContext(), t);
+ },
+ nanobind::arg("element_type"), nanobind::arg("loc") = nanobind::none(),
+ "Create a unranked tensor type");
+ c.def_static(
+ "get_unchecked",
+ [](PyType &elementType, DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirType t = mlirUnrankedTensorTypeGet(elementType);
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
+ return PyUnrankedTensorType(elementType.getContext(), t);
+ },
+ nanobind::arg("element_type"),
+ nanobind::arg("context") = nanobind::none(),
+ "Create a unranked tensor type");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](std::vector<PyType> inputs, std::vector<PyType> results,
- DefaultingPyMlirContext context) {
- std::vector<MlirType> mlirInputs;
- mlirInputs.reserve(inputs.size());
- for (const auto &input : inputs)
- mlirInputs.push_back(input.get());
- std::vector<MlirType> mlirResults;
- mlirResults.reserve(results.size());
- for (const auto &result : results)
- mlirResults.push_back(result.get());
+/// Registers the Python API of the ranked memref type binding: the `get` and
+/// `get_unchecked` constructors, the `layout`, `affine_map` and
+/// `memory_space` properties, and the `get_strides_and_offset` method.
+void PyMemRefType::bindDerived(ClassTy &c) {
+  c.def_static(
+       "get",
+       [](std::vector<int64_t> shape, PyType &elementType,
+          PyAttribute *layout, PyAttribute *memorySpace,
+          DefaultingPyLocation loc) {
+         PyMlirContext::ErrorCapture errors(loc->getContext());
+         // An omitted layout / memory space is forwarded as a null
+         // attribute.
+         MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull();
+         MlirAttribute memSpaceAttr =
+             memorySpace ? *memorySpace : mlirAttributeGetNull();
+         MlirType t =
+             mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
+                                      shape.data(), layoutAttr, memSpaceAttr);
+         if (mlirTypeIsNull(t))
+           throw MLIRError("Invalid type", errors.take());
+         return PyMemRefType(elementType.getContext(), t);
+       },
+       nanobind::arg("shape"), nanobind::arg("element_type"),
+       nanobind::arg("layout") = nanobind::none(),
+       nanobind::arg("memory_space") = nanobind::none(),
+       nanobind::arg("loc") = nanobind::none(), "Create a memref type")
+      .def_static(
+          "get_unchecked",
+          [](std::vector<int64_t> shape, PyType &elementType,
+             PyAttribute *layout, PyAttribute *memorySpace,
+             DefaultingPyMlirContext context) {
+            PyMlirContext::ErrorCapture errors(context->getRef());
+            MlirAttribute layoutAttr =
+                layout ? *layout : mlirAttributeGetNull();
+            MlirAttribute memSpaceAttr =
+                memorySpace ? *memorySpace : mlirAttributeGetNull();
+            MlirType t =
+                mlirMemRefTypeGet(elementType, shape.size(), shape.data(),
+                                  layoutAttr, memSpaceAttr);
+            if (mlirTypeIsNull(t))
+              throw MLIRError("Invalid type", errors.take());
+            return PyMemRefType(elementType.getContext(), t);
+          },
+          nanobind::arg("shape"), nanobind::arg("element_type"),
+          nanobind::arg("layout") = nanobind::none(),
+          nanobind::arg("memory_space") = nanobind::none(),
+          nanobind::arg("context") = nanobind::none(), "Create a memref type")
+      .def_prop_ro(
+          "layout",
+          [](PyMemRefType &self)
+              -> nanobind::typed<nanobind::object, PyAttribute> {
+            return PyAttribute(self.getContext(),
+                               mlirMemRefTypeGetLayout(self))
+                .maybeDownCast();
+          },
+          "The layout of the MemRef type.")
+      .def(
+          "get_strides_and_offset",
+          [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
+            // One stride slot per dimension; the C API fills the buffer.
+            std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
+            // Zero-initialised so the variable never holds an indeterminate
+            // value, even though it is only read after a successful call.
+            int64_t offset = 0;
+            if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset(
+                    self, strides.data(), &offset)))
+              throw std::runtime_error(
+                  "Failed to extract strides and offset from memref.");
+            // Move the buffer into the result instead of copying it.
+            return {std::move(strides), offset};
+          },
+          "The strides and offset of the MemRef type.")
+      .def_prop_ro(
+          "affine_map",
+          [](PyMemRefType &self) -> PyAffineMap {
+            MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
+            return PyAffineMap(self.getContext(), map);
+          },
+          "The layout of the MemRef type as an affine map.")
+      .def_prop_ro(
+          "memory_space",
+          [](PyMemRefType &self)
+              -> std::optional<nanobind::typed<nanobind::object, PyAttribute>> {
+            MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
+            // A null memory-space attribute is reported to Python as None.
+            if (mlirAttributeIsNull(a))
+              return std::nullopt;
+            return PyAttribute(self.getContext(), a).maybeDownCast();
+          },
+          "Returns the memory space of the given MemRef type.");
+}
- MlirType t = mlirFunctionTypeGet(context->get(), inputs.size(),
- mlirInputs.data(), results.size(),
- mlirResults.data());
- return PyFunctionType(context->getRef(), t);
- },
- nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
- "Gets a FunctionType from a list of input and result types");
- c.def_static(
- "get",
- [](std::vector<MlirType> inputs, std::vector<MlirType> results,
- DefaultingPyMlirContext context) {
- MlirType t =
- mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
- results.size(), results.data());
- return PyFunctionType(context->getRef(), t);
- },
- nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
- // clang-format off
- nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: Context | None = None) -> FunctionType"),
- // clang-format on
- "Gets a FunctionType from a list of input and result types");
- c.def_prop_ro(
- "inputs",
- [](PyFunctionType &self) {
- MlirType t = self;
- nb::list types;
- for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
- ++i) {
- types.append(mlirFunctionTypeGetInput(t, i));
- }
- return types;
- },
- "Returns the list of input types in the FunctionType.");
- c.def_prop_ro(
- "results",
- [](PyFunctionType &self) {
- nb::list types;
- for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
- ++i) {
- types.append(mlirFunctionTypeGetResult(self, i));
- }
- return types;
- },
- "Returns the list of result types in the FunctionType.");
- }
-};
+/// Registers the `get` / `get_unchecked` constructors and the `memory_space`
+/// property of the unranked memref type binding.
+void PyUnrankedMemRefType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](PyType &elementType, PyAttribute *memorySpace,
+         DefaultingPyLocation loc) {
+        PyMlirContext::ErrorCapture errors(loc->getContext());
+        // Without a memory space the C API receives a null attribute.
+        MlirAttribute spaceAttr =
+            memorySpace ? *memorySpace : mlirAttributeGetNull();
+        MlirType memrefTy =
+            mlirUnrankedMemRefTypeGetChecked(loc, elementType, spaceAttr);
+        if (!mlirTypeIsNull(memrefTy))
+          return PyUnrankedMemRefType(elementType.getContext(), memrefTy);
+        throw MLIRError("Invalid type", errors.take());
+      },
+      nanobind::arg("element_type"), nanobind::arg("memory_space").none(),
+      nanobind::arg("loc") = nanobind::none(),
+      "Create a unranked memref type");
+
+  c.def_static(
+      "get_unchecked",
+      [](PyType &elementType, PyAttribute *memorySpace,
+         DefaultingPyMlirContext context) {
+        PyMlirContext::ErrorCapture errors(context->getRef());
+        MlirAttribute spaceAttr =
+            memorySpace ? *memorySpace : mlirAttributeGetNull();
+        MlirType memrefTy = mlirUnrankedMemRefTypeGet(elementType, spaceAttr);
+        if (!mlirTypeIsNull(memrefTy))
+          return PyUnrankedMemRefType(elementType.getContext(), memrefTy);
+        throw MLIRError("Invalid type", errors.take());
+      },
+      nanobind::arg("element_type"), nanobind::arg("memory_space").none(),
+      nanobind::arg("context") = nanobind::none(),
+      "Create a unranked memref type");
+
+  c.def_prop_ro(
+      "memory_space",
+      [](PyUnrankedMemRefType &self)
+          -> std::optional<nanobind::typed<nanobind::object, PyAttribute>> {
+        MlirAttribute space = mlirUnrankedMemrefGetMemorySpace(self);
+        // Only a non-null attribute is wrapped; otherwise Python sees None.
+        if (!mlirAttributeIsNull(space))
+          return PyAttribute(self.getContext(), space).maybeDownCast();
+        return std::nullopt;
+      },
+      "Returns the memory space of the given Unranked MemRef type.");
+}
-/// Opaque Type subclass - OpaqueType.
-class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirOpaqueTypeGetTypeID;
- static constexpr const char *pyClassName = "OpaqueType";
- using PyConcreteType::PyConcreteType;
+/// Registers the tuple type binding: two `get_tuple` overloads (one taking
+/// wrapped PyType elements, one taking raw MlirType elements), the `get_type`
+/// accessor and the `num_types` property.
+void PyTupleType::bindDerived(ClassTy &c) {
+  c.def_static(
+       "get_tuple",
+       [](const std::vector<PyType> &elements,
+          DefaultingPyMlirContext context) {
+         // Unwrap every element to the raw handle expected by the C API.
+         std::vector<MlirType> rawElements;
+         rawElements.reserve(elements.size());
+         for (const PyType &wrapped : elements)
+           rawElements.push_back(wrapped.get());
+         MlirType tupleTy = mlirTupleTypeGet(context->get(), elements.size(),
+                                             rawElements.data());
+         return PyTupleType(context->getRef(), tupleTy);
+       },
+       nanobind::arg("elements"),
+       nanobind::arg("context") = nanobind::none(), "Create a tuple type")
+      .def_static(
+          "get_tuple",
+          [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
+            MlirType tupleTy = mlirTupleTypeGet(
+                context->get(), elements.size(), elements.data());
+            return PyTupleType(context->getRef(), tupleTy);
+          },
+          nanobind::arg("elements"),
+          nanobind::arg("context") = nanobind::none(),
+          // clang-format off
+          nanobind::sig("def get_tuple(elements: Sequence[Type], context: Context | None = None) -> TupleType"),
+          // clang-format on
+          "Create a tuple type")
+      .def(
+          "get_type",
+          [](PyTupleType &self,
+             intptr_t pos) -> nanobind::typed<nanobind::object, PyType> {
+            MlirType elementTy = mlirTupleTypeGetType(self, pos);
+            return PyType(self.getContext(), elementTy).maybeDownCast();
+          },
+          nanobind::arg("pos"), "Returns the pos-th type in the tuple type.")
+      .def_prop_ro(
+          "num_types",
+          [](PyTupleType &self) -> intptr_t {
+            return mlirTupleTypeGetNumTypes(self);
+          },
+          "Returns the number of types contained in a tuple.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](const std::string &dialectNamespace, const std::string &typeData,
- DefaultingPyMlirContext context) {
- MlirType type = mlirOpaqueTypeGet(context->get(),
- toMlirStringRef(dialectNamespace),
- toMlirStringRef(typeData));
- return PyOpaqueType(context->getRef(), type);
- },
- nb::arg("dialect_namespace"), nb::arg("buffer"),
- nb::arg("context") = nb::none(),
- "Create an unregistered (opaque) dialect type.");
- c.def_prop_ro(
- "dialect_namespace",
- [](PyOpaqueType &self) {
- MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self);
- return nb::str(stringRef.data, stringRef.length);
- },
- "Returns the dialect namespace for the Opaque type as a string.");
- c.def_prop_ro(
- "data",
- [](PyOpaqueType &self) {
- MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
- return nb::str(stringRef.data, stringRef.length);
- },
- "Returns the data for the Opaque type as a string.");
- }
-};
+/// Registers the function type binding: two `get` overloads (wrapped PyType
+/// lists and raw MlirType lists) plus the `inputs` and `results` properties,
+/// each of which returns a Python list.
+void PyFunctionType::bindDerived(ClassTy &c) {
+  c.def_static(
+       "get",
+       [](std::vector<PyType> inputs, std::vector<PyType> results,
+          DefaultingPyMlirContext context) {
+         // Strips the Python wrappers, leaving the raw C API handles.
+         auto unwrap = [](const std::vector<PyType> &wrapped) {
+           std::vector<MlirType> raw;
+           raw.reserve(wrapped.size());
+           for (const PyType &type : wrapped)
+             raw.push_back(type.get());
+           return raw;
+         };
+         std::vector<MlirType> rawInputs = unwrap(inputs);
+         std::vector<MlirType> rawResults = unwrap(results);
+
+         MlirType fnTy = mlirFunctionTypeGet(
+             context->get(), inputs.size(), rawInputs.data(), results.size(),
+             rawResults.data());
+         return PyFunctionType(context->getRef(), fnTy);
+       },
+       nanobind::arg("inputs"), nanobind::arg("results"),
+       nanobind::arg("context") = nanobind::none(),
+       "Gets a FunctionType from a list of input and result types")
+      .def_static(
+          "get",
+          [](std::vector<MlirType> inputs, std::vector<MlirType> results,
+             DefaultingPyMlirContext context) {
+            MlirType fnTy = mlirFunctionTypeGet(
+                context->get(), inputs.size(), inputs.data(), results.size(),
+                results.data());
+            return PyFunctionType(context->getRef(), fnTy);
+          },
+          nanobind::arg("inputs"), nanobind::arg("results"),
+          nanobind::arg("context") = nanobind::none(),
+          // clang-format off
+          nanobind::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: Context | None = None) -> FunctionType"),
+          // clang-format on
+          "Gets a FunctionType from a list of input and result types")
+      .def_prop_ro(
+          "inputs",
+          [](PyFunctionType &self) {
+            MlirType fnTy = self;
+            const intptr_t numInputs = mlirFunctionTypeGetNumInputs(self);
+            nanobind::list types;
+            for (intptr_t idx = 0; idx < numInputs; ++idx)
+              types.append(mlirFunctionTypeGetInput(fnTy, idx));
+            return types;
+          },
+          "Returns the list of input types in the FunctionType.")
+      .def_prop_ro(
+          "results",
+          [](PyFunctionType &self) {
+            const intptr_t numResults = mlirFunctionTypeGetNumResults(self);
+            nanobind::list types;
+            for (intptr_t idx = 0; idx < numResults; ++idx)
+              types.append(mlirFunctionTypeGetResult(self, idx));
+            return types;
+          },
+          "Returns the list of result types in the FunctionType.");
+}
+
+/// Registers the opaque type binding: the `get` constructor and the
+/// `dialect_namespace` / `data` string properties.
+void PyOpaqueType::bindDerived(ClassTy &c) {
+  c.def_static(
+       "get",
+       [](const std::string &dialectNamespace, const std::string &typeData,
+          DefaultingPyMlirContext context) {
+         MlirStringRef namespaceRef = toMlirStringRef(dialectNamespace);
+         MlirStringRef dataRef = toMlirStringRef(typeData);
+         MlirType opaqueTy =
+             mlirOpaqueTypeGet(context->get(), namespaceRef, dataRef);
+         return PyOpaqueType(context->getRef(), opaqueTy);
+       },
+       nanobind::arg("dialect_namespace"), nanobind::arg("buffer"),
+       nanobind::arg("context") = nanobind::none(),
+       "Create an unregistered (opaque) dialect type.")
+      .def_prop_ro(
+          "dialect_namespace",
+          [](PyOpaqueType &self) {
+            // Built from pointer + length as carried by the MlirStringRef.
+            MlirStringRef ns = mlirOpaqueTypeGetDialectNamespace(self);
+            return nanobind::str(ns.data, ns.length);
+          },
+          "Returns the dialect namespace for the Opaque type as a string.")
+      .def_prop_ro(
+          "data",
+          [](PyOpaqueType &self) {
+            MlirStringRef payload = mlirOpaqueTypeGetData(self);
+            return nanobind::str(payload.data, payload.length);
+          },
+          "Returns the data for the Opaque type as a string.");
+}
+// Out-of-line definition of the PyShapedType::isaFunction static member,
+// bound to the C API predicate mlirTypeIsAShaped.
+const PyShapedType::IsAFunctionTy PyShapedType::isaFunction = mlirTypeIsAShaped;
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 4a9fb127ee08c..582863ffcbb0d 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -535,7 +535,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
IRAffine.cpp
IRAttributes.cpp
IRInterfaces.cpp
- IRTypes.cpp
Pass.cpp
Rewrite.cpp
@@ -846,8 +845,9 @@ declare_mlir_python_extension(MLIRPythonExtension.MLIRPythonSupport
ADD_TO_PARENT MLIRPythonSources.Core
ROOT_DIR "${PYTHON_SOURCE_DIR}"
SOURCES
- IRCore.cpp
Globals.cpp
+ IRCore.cpp
+ IRTypes.cpp
)
################################################################################
diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
index 43573cbc305fa..a296b5e814b4b 100644
--- a/mlir/test/python/lib/PythonTestModuleNanobind.cpp
+++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
@@ -15,6 +15,7 @@
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/Diagnostics.h"
#include "mlir/Bindings/Python/IRCore.h"
+#include "mlir/Bindings/Python/IRTypes.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "nanobind/nanobind.h"
@@ -47,6 +48,49 @@ struct PyTestType
}
};
+/// Test-only type binding layered on PyRankedTensorType. Instances are
+/// recognised through mlirTypeIsARankedIntegerTensor, and the class reports
+/// the ranked tensor TypeID as its own.
+struct PyTestIntegerRankedTensorType
+    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<
+          PyTestIntegerRankedTensorType,
+          mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyRankedTensorType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedIntegerTensor;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirRankedTensorTypeGetTypeID;
+  static constexpr const char *pyClassName = "TestIntegerRankedTensorType";
+  using PyConcreteType::PyConcreteType;
+
+  /// Adds `get(shape, width, context=None)`, which builds a ranked tensor of
+  /// `width`-bit integers with a null encoding attribute.
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](std::vector<int64_t> shape, unsigned width,
+           mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
+               ctx) {
+          MlirType elementTy = mlirIntegerTypeGet(ctx->get(), width);
+          MlirType tensorTy =
+              mlirRankedTensorTypeGet(shape.size(), shape.data(), elementTy,
+                                      mlirAttributeGetNull());
+          return PyTestIntegerRankedTensorType(ctx->getRef(), tensorTy);
+        },
+        nb::arg("shape"), nb::arg("width"),
+        nb::arg("context").none() = nb::none());
+  }
+};
+
+/// Test-only value binding. Values are recognised through
+/// mlirTypeIsAPythonTestTestTensorValue, and the class reports the ranked
+/// tensor TypeID.
+struct PyTestTensorValue
+    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteValue<
+          PyTestTensorValue> {
+  static constexpr IsAFunctionTy isaFunction =
+      mlirTypeIsAPythonTestTestTensorValue;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirRankedTensorTypeGetTypeID;
+  static constexpr const char *pyClassName = "TestTensorValue";
+  using PyConcreteValue::PyConcreteValue;
+
+  /// Adds `is_null()`, a direct forward to mlirValueIsNull.
+  static void bindDerived(ClassTy &c) {
+    auto isNull = [](MlirValue &self) { return mlirValueIsNull(self); };
+    c.def("is_null", isNull);
+  }
+};
+
class PyTestAttr
: public mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteAttribute<
PyTestAttr> {
@@ -73,18 +117,18 @@ class PyTestAttr
NB_MODULE(_mlirPythonTestNanobind, m) {
m.def(
"register_python_test_dialect",
- [](MlirContext context, bool load) {
+ [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
+ context,
+ bool load) {
MlirDialectHandle pythonTestDialect =
mlirGetDialectHandle__python_test__();
- mlirDialectHandleRegisterDialect(pythonTestDialect, context);
+ mlirDialectHandleRegisterDialect(pythonTestDialect,
+ context.get()->get());
if (load) {
- mlirDialectHandleLoadDialect(pythonTestDialect, context);
+ mlirDialectHandleLoadDialect(pythonTestDialect, context.get()->get());
}
},
- nb::arg("context"), nb::arg("load") = true,
- // clang-format off
- nb::sig("def register_python_test_dialect(context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") ", load: bool = True) -> None"));
- // clang-format on
+ nb::arg("context").none() = nb::none(), nb::arg("load") = true);
m.def(
"register_dialect",
@@ -100,73 +144,16 @@ NB_MODULE(_mlirPythonTestNanobind, m) {
m.def(
"test_diagnostics_with_errors_and_notes",
- [](MlirContext ctx) {
- mlir::python::CollectDiagnosticsToStringScope handler(ctx);
- mlirPythonTestEmitDiagnosticWithNote(ctx);
+ [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
+ ctx) {
+ mlir::python::CollectDiagnosticsToStringScope handler(ctx.get()->get());
+ mlirPythonTestEmitDiagnosticWithNote(ctx.get()->get());
throw nb::value_error(handler.takeMessage().c_str());
},
- // clang-format off
- nb::sig("def test_diagnostics_with_errors_and_notes(arg: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") ", /) -> None"));
- // clang-format on
+ nb::arg("context").none() = nb::none());
PyTestAttr::bind(m);
PyTestType::bind(m);
-
- auto typeCls =
- mlir_type_subclass(m, "TestIntegerRankedTensorType",
- mlirTypeIsARankedIntegerTensor,
- nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("RankedTensorType"))
- .def_classmethod(
- "get",
- [](const nb::object &cls, std::vector<int64_t> shape,
- unsigned width, MlirContext ctx) {
- MlirAttribute encoding = mlirAttributeGetNull();
- return cls(mlirRankedTensorTypeGet(
- shape.size(), shape.data(), mlirIntegerTypeGet(ctx, width),
- encoding));
- },
- // clang-format off
- nb::sig("def get(cls: object, shape: collections.abc.Sequence[int], width: int, context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> object"),
- // clang-format on
- nb::arg("cls"), nb::arg("shape"), nb::arg("width"),
- nb::arg("context").none() = nb::none());
-
- assert(nb::hasattr(typeCls.get_class(), "static_typeid") &&
- "TestIntegerRankedTensorType has no static_typeid");
-
- MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID();
-
- nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
- mlirRankedTensorTypeID, nb::arg("replace") = true)(
- nanobind::cpp_function([typeCls](const nb::object &mlirType) {
- return typeCls.get_class()(mlirType);
- }));
-
- auto valueCls = mlir_value_subclass(m, "TestTensorValue",
- mlirTypeIsAPythonTestTestTensorValue)
- .def("is_null", [](MlirValue &self) {
- return mlirValueIsNull(self);
- });
-
- nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)(
- mlirRankedTensorTypeID)(
- nanobind::cpp_function([valueCls](const nb::object &valueObj) {
- std::optional<nb::object> capsule =
- mlirApiObjectToCapsule(valueObj);
- assert(capsule.has_value() && "capsule is not null");
- MlirValue v = mlirPythonCapsuleToValue(capsule.value().ptr());
- MlirType t = mlirValueGetType(v);
- // This is hyper-specific in order to exercise/test registering a
- // value caster from cpp (but only for a single test case; see
- // testTensorValue python_test.py).
- if (mlirShapedTypeHasStaticShape(t) &&
- mlirShapedTypeGetDimSize(t, 0) == 1 &&
- mlirShapedTypeGetDimSize(t, 1) == 2 &&
- mlirShapedTypeGetDimSize(t, 2) == 3)
- return valueCls.get_class()(valueObj);
- return valueObj;
- }));
+ PyTestIntegerRankedTensorType::bind(m);
+ PyTestTensorValue::bind(m);
}
More information about the llvm-branch-commits
mailing list