[llvm-branch-commits] [mlir] [mlir][py] partially use mlir_type_subclass for IRTypes.cpp (PR #171143)
Oleksandr Alex Zinenko via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Dec 8 06:57:29 PST 2025
https://github.com/ftynse created https://github.com/llvm/llvm-project/pull/171143
Port the bindings for non-shaped builtin types in IRTypes.cpp to use the `mlir_type_subclass` mechanism used by non-builtin types. This is part of a longer-term cleanup to only support one subclassing mechanism. Eventually, the `PyConcreteType` mechanism will be removed.
This required surgery in the type casters and the `mlir_type_subclass` logic to avoid circular imports of the `_mlir.ir` module that would otherwise occur when using `mlir_type_subclass` to define classes in the `_mlir.ir` module.
Tests are updated to use the `.get_static_typeid()` function instead of the `.static_typeid` property that was specific to builtin types due to the `PyConcreteType` mechanism. The change should be NFC otherwise.
>From deac26450350ba40b9f9357f68ec3a5e458b43d6 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <git at ozinenko.com>
Date: Mon, 8 Dec 2025 15:50:41 +0100
Subject: [PATCH] [mlir][py] partially use mlir_type_subclass for IRTypes.cpp
Port the bindings for non-shaped builtin types in IRTypes.cpp to use the
`mlir_type_subclass` mechanism used by non-builtin types. This is part of a
longer-term cleanup to only support one subclassing mechanism. Eventually, the
`PyConcreteType` mechanism will be removed.
This required surgery in the type casters and the `mlir_type_subclass` logic
to avoid circular imports of the `_mlir.ir` module that would otherwise occur
when using `mlir_type_subclass` to define classes in the `_mlir.ir` module.
Tests are updated to use the `.get_static_typeid()` function instead of the
`.static_typeid` property that was specific to builtin types due to the
`PyConcreteType` mechanism. The change should be NFC otherwise.
---
.../mlir/Bindings/Python/NanobindAdaptors.h | 41 +-
mlir/lib/Bindings/Python/IRTypes.cpp | 1029 ++++++-----------
mlir/lib/Bindings/Python/MainModule.cpp | 15 +
mlir/test/python/dialects/arith_dialect.py | 8 +-
mlir/test/python/ir/builtin_types.py | 11 +-
mlir/test/python/ir/value.py | 6 +-
6 files changed, 425 insertions(+), 685 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
index 6594670abaaa7..f678f57527e97 100644
--- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
@@ -371,16 +371,22 @@ struct type_caster<MlirTypeID> {
}
return false;
}
- static handle from_cpp(MlirTypeID v, rv_policy,
- cleanup_list *cleanup) noexcept {
+
+ static handle
+ from_cpp_given_module(MlirTypeID v,
+ const nanobind::module_ &module) noexcept {
if (v.ptr == nullptr)
return nanobind::none();
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonTypeIDToCapsule(v));
- return mlir::python::irModule()
- .attr("TypeID")
+ return module.attr("TypeID")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.release();
+ }
+
+ static handle from_cpp(MlirTypeID v, rv_policy,
+ cleanup_list *cleanup) noexcept {
+ return from_cpp_given_module(v, mlir::python::irModule());
};
};
@@ -602,9 +608,12 @@ class mlir_type_subclass : public pure_subclass {
/// Subclasses by looking up the super-class dynamically.
mlir_type_subclass(nanobind::handle scope, const char *typeClassName,
IsAFunctionTy isaFunction,
- GetTypeIDFunctionTy getTypeIDFunction = nullptr)
- : mlir_type_subclass(scope, typeClassName, isaFunction,
- irModule().attr("Type"), getTypeIDFunction) {}
+ GetTypeIDFunctionTy getTypeIDFunction = nullptr,
+ const nanobind::module_ *mlirIrModule = nullptr)
+ : mlir_type_subclass(
+ scope, typeClassName, isaFunction,
+ (mlirIrModule != nullptr ? *mlirIrModule : irModule()).attr("Type"),
+ getTypeIDFunction, mlirIrModule) {}
/// Subclasses with a provided mlir.ir.Type super-class. This must
/// be used if the subclass is being defined in the same extension module
@@ -613,7 +622,8 @@ class mlir_type_subclass : public pure_subclass {
mlir_type_subclass(nanobind::handle scope, const char *typeClassName,
IsAFunctionTy isaFunction,
const nanobind::object &superCls,
- GetTypeIDFunctionTy getTypeIDFunction = nullptr)
+ GetTypeIDFunctionTy getTypeIDFunction = nullptr,
+ const nanobind::module_ *mlirIrModule = nullptr)
: pure_subclass(scope, typeClassName, superCls) {
// Casting constructor. Note that it is hard, if not impossible, to properly
// call chain to parent `__init__` in nanobind due to its special handling
@@ -672,9 +682,18 @@ class mlir_type_subclass : public pure_subclass {
nanobind::sig("def get_static_typeid() -> " MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID"))
// clang-format on
);
- nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
- getTypeIDFunction())(nanobind::cpp_function(
+
+ // Call the caster implementation directly with the "ir" module to avoid a
+ // recursive import through the default caster; steal() takes ownership of
+ // the new reference that from_cpp_given_module returns so it is not leaked.
+ MlirTypeID typeID = getTypeIDFunction();
+ mlirIrModule = mlirIrModule ? mlirIrModule : &irModule();
+ nanobind::object pyTypeID = nanobind::steal(
+ nanobind::detail::type_caster<MlirTypeID>::from_cpp_given_module(
+ typeID, *mlirIrModule));
+
+ mlirIrModule->attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(pyTypeID)(
+ nanobind::cpp_function(
[thisClass = thisClass](const nanobind::object &mlirType) {
return thisClass(mlirType);
}));
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 34c5b8dd86a66..2e4090c358c47 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -18,13 +18,13 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace mlir;
using namespace mlir::python;
using llvm::SmallVector;
-using llvm::Twine;
namespace {
@@ -34,480 +34,368 @@ static int mlirTypeIsAIntegerOrFloat(MlirType type) {
mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
}
-class PyIntegerType : public PyConcreteType<PyIntegerType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirIntegerTypeGetTypeID;
- static constexpr const char *pyClassName = "IntegerType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get_signless",
- [](unsigned width, DefaultingPyMlirContext context) {
- MlirType t = mlirIntegerTypeGet(context->get(), width);
- return PyIntegerType(context->getRef(), t);
- },
- nb::arg("width"), nb::arg("context") = nb::none(),
- "Create a signless integer type");
- c.def_static(
- "get_signed",
- [](unsigned width, DefaultingPyMlirContext context) {
- MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
- return PyIntegerType(context->getRef(), t);
- },
- nb::arg("width"), nb::arg("context") = nb::none(),
- "Create a signed integer type");
- c.def_static(
- "get_unsigned",
- [](unsigned width, DefaultingPyMlirContext context) {
- MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
- return PyIntegerType(context->getRef(), t);
- },
- nb::arg("width"), nb::arg("context") = nb::none(),
- "Create an unsigned integer type");
- c.def_prop_ro(
- "width",
- [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
- "Returns the width of the integer type");
- c.def_prop_ro(
- "is_signless",
- [](PyIntegerType &self) -> bool {
- return mlirIntegerTypeIsSignless(self);
- },
- "Returns whether this is a signless integer");
- c.def_prop_ro(
- "is_signed",
- [](PyIntegerType &self) -> bool {
- return mlirIntegerTypeIsSigned(self);
- },
- "Returns whether this is a signed integer");
- c.def_prop_ro(
- "is_unsigned",
- [](PyIntegerType &self) -> bool {
- return mlirIntegerTypeIsUnsigned(self);
- },
- "Returns whether this is an unsigned integer");
- }
-};
-
-/// Index Type subclass - IndexType.
-class PyIndexType : public PyConcreteType<PyIndexType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirIndexTypeGetTypeID;
- static constexpr const char *pyClassName = "IndexType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirIndexTypeGet(context->get());
- return PyIndexType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a index type.");
- }
-};
-
-class PyFloatType : public PyConcreteType<PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
- static constexpr const char *pyClassName = "FloatType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
- "Returns the width of the floating-point type");
- }
-};
-
-/// Floating Point Type subclass - Float4E2M1FNType.
-class PyFloat4E2M1FNType
- : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat4E2M1FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float4E2M1FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat4E2M1FNTypeGet(context->get());
- return PyFloat4E2M1FNType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float4_e2m1fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float6E2M3FNType.
-class PyFloat6E2M3FNType
- : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat6E2M3FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float6E2M3FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
- return PyFloat6E2M3FNType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float6_e2m3fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float6E3M2FNType.
-class PyFloat6E3M2FNType
- : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat6E3M2FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float6E3M2FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
- return PyFloat6E3M2FNType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float6_e3m2fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3FNType.
-class PyFloat8E4M3FNType
- : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
- return PyFloat8E4M3FNType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e4m3fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E5M2Type.
-class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E5M2TypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E5M2Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E5M2TypeGet(context->get());
- return PyFloat8E5M2Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e5m2 type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3Type.
-class PyFloat8E4M3Type : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3TypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3TypeGet(context->get());
- return PyFloat8E4M3Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e4m3 type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3FNUZ.
-class PyFloat8E4M3FNUZType
- : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3FNUZTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3FNUZType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
- return PyFloat8E4M3FNUZType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e4m3fnuz type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3B11FNUZ.
-class PyFloat8E4M3B11FNUZType
- : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3B11FNUZTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
- return PyFloat8E4M3B11FNUZType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e4m3b11fnuz type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E5M2FNUZ.
-class PyFloat8E5M2FNUZType
- : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E5M2FNUZTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E5M2FNUZType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
- return PyFloat8E5M2FNUZType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e5m2fnuz type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E3M4Type.
-class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E3M4TypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E3M4Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E3M4TypeGet(context->get());
- return PyFloat8E3M4Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e3m4 type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E8M0FNUType.
-class PyFloat8E8M0FNUType
- : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E8M0FNUTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E8M0FNUType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
- return PyFloat8E8M0FNUType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e8m0fnu type.");
- }
-};
-
-/// Floating Point Type subclass - BF16Type.
-class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirBFloat16TypeGetTypeID;
- static constexpr const char *pyClassName = "BF16Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirBF16TypeGet(context->get());
- return PyBF16Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a bf16 type.");
- }
-};
-
-/// Floating Point Type subclass - F16Type.
-class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat16TypeGetTypeID;
- static constexpr const char *pyClassName = "F16Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirF16TypeGet(context->get());
- return PyF16Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a f16 type.");
- }
-};
-
-/// Floating Point Type subclass - TF32Type.
-class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloatTF32TypeGetTypeID;
- static constexpr const char *pyClassName = "FloatTF32Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirTF32TypeGet(context->get());
- return PyTF32Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a tf32 type.");
- }
-};
-
-/// Floating Point Type subclass - F32Type.
-class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat32TypeGetTypeID;
- static constexpr const char *pyClassName = "F32Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirF32TypeGet(context->get());
- return PyF32Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a f32 type.");
- }
-};
+static void populateIRTypesModule(const nanobind::module_ &m) {
+ using namespace nanobind_adaptors;
-/// Floating Point Type subclass - F64Type.
-class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat64TypeGetTypeID;
- static constexpr const char *pyClassName = "F64Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirF64TypeGet(context->get());
- return PyF64Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a f64 type.");
- }
-};
-
-/// None Type subclass - NoneType.
-class PyNoneType : public PyConcreteType<PyNoneType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirNoneTypeGetTypeID;
- static constexpr const char *pyClassName = "NoneType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirNoneTypeGet(context->get());
- return PyNoneType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a none type.");
- }
-};
-
-/// Complex Type subclass - ComplexType.
-class PyComplexType : public PyConcreteType<PyComplexType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirComplexTypeGetTypeID;
- static constexpr const char *pyClassName = "ComplexType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyType &elementType) {
- // The element must be a floating point or integer scalar type.
- if (mlirTypeIsAIntegerOrFloat(elementType)) {
- MlirType t = mlirComplexTypeGet(elementType);
- return PyComplexType(elementType.getContext(), t);
- }
- throw nb::value_error(
- (Twine("invalid '") +
- nb::cast<std::string>(nb::repr(nb::cast(elementType))) +
- "' and expected floating point or integer type.")
- .str()
- .c_str());
- },
- "Create a complex type");
- c.def_prop_ro(
- "element_type",
- [](PyComplexType &self) -> nb::typed<nb::object, PyType> {
- return PyType(self.getContext(), mlirComplexTypeGetElementType(self))
- .maybeDownCast();
- },
- "Returns element type.");
- }
-};
+ mlir_type_subclass integerType(m, "IntegerType", mlirTypeIsAInteger,
+ mlirIntegerTypeGetTypeID, &m);
+ integerType.def_classmethod(
+ "get_signless",
+ [](const nb::object &cls, unsigned width, MlirContext ctx) {
+ return cls(mlirIntegerTypeGet(ctx, width));
+ },
+ nb::arg("cls"), nb::arg("width"), nb::arg("context") = nb::none(),
+ "Create a signless integer type");
+ integerType.def_classmethod(
+ "get_signed",
+ [](const nb::object &cls, unsigned width, MlirContext ctx) {
+ return cls(mlirIntegerTypeSignedGet(ctx, width));
+ },
+ nb::arg("cls"), nb::arg("width"), nb::arg("context") = nb::none(),
+ "Create a signed integer type");
+ integerType.def_classmethod(
+ "get_unsigned",
+ [](const nb::object &cls, unsigned width, MlirContext ctx) {
+ return cls(mlirIntegerTypeUnsignedGet(ctx, width));
+ },
+ nb::arg("cls"), nb::arg("width"), nb::arg("context") = nb::none(),
+ "Create an unsigned integer type");
+ integerType.def_property_readonly(
+ "width", [](MlirType self) { return mlirIntegerTypeGetWidth(self); },
+ "Returns the width of the integer type");
+ integerType.def_property_readonly(
+ "is_signless",
+ [](MlirType self) { return mlirIntegerTypeIsSignless(self); },
+ "Returns whether this is a signless integer");
+ integerType.def_property_readonly(
+ "is_signed", [](MlirType self) { return mlirIntegerTypeIsSigned(self); },
+ "Returns whether this is a signed integer");
+ integerType.def_property_readonly(
+ "is_unsigned",
+ [](MlirType self) { return mlirIntegerTypeIsUnsigned(self); },
+ "Returns whether this is an unsigned integer");
+
+ // IndexType
+ mlir_type_subclass indexType(m, "IndexType", mlirTypeIsAIndex,
+ mlirIndexTypeGetTypeID, &m);
+
+ indexType.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirContext ctx) {
+ return cls(mlirIndexTypeGet(ctx));
+ },
+ nb::arg("cls"), nb::arg("context") = nb::none(), "Create a index type.");
+
+ // FloatType (base class for specific float types)
+ mlir_type_subclass floatType(m, "FloatType", mlirTypeIsAFloat, nullptr, &m);
+ floatType.def_property_readonly(
+ "width", [](MlirType self) { return mlirFloatTypeGetWidth(self); },
+ "Returns the width of the floating-point type");
+
+ // Float4E2M1FNType
+ mlir_type_subclass float4E2M1FNType(
+ m, "Float4E2M1FNType", mlirTypeIsAFloat4E2M1FN, floatType.get_class(),
+ mlirFloat4E2M1FNTypeGetTypeID, &m);
+ float4E2M1FNType.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirContext ctx) {
+ return cls(mlirFloat4E2M1FNTypeGet(ctx));
+ },
+ nb::arg("cls"), nb::arg("context") = nb::none(),
+ "Create a float4_e2m1fn type.");
+
+ // Float6E2M3FNType
+ mlir_type_subclass float6E2M3FNType(
+ m, "Float6E2M3FNType", mlirTypeIsAFloat6E2M3FN, floatType.get_class(),
+ mlirFloat6E2M3FNTypeGetTypeID, &m);
+ float6E2M3FNType.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirContext ctx) {
+ return cls(mlirFloat6E2M3FNTypeGet(ctx));
+ },
+ nb::arg("cls"), nb::arg("context") = nb::none(),
+ "Create a float6_e2m3fn type.");
+
+ // Float6E3M2FNType
+ mlir_type_subclass float6E3M2FNType(
+ m, "Float6E3M2FNType", mlirTypeIsAFloat6E3M2FN, floatType.get_class(),
+ mlirFloat6E3M2FNTypeGetTypeID, &m);
+ float6E3M2FNType.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirContext ctx) {
+ return cls(mlirFloat6E3M2FNTypeGet(ctx));
+ },
+ nb::arg("cls"), nb::arg("context") = nb::none(),
+ "Create a float6_e3m2fn type.");
+
+ // Float8E4M3FNType
+ mlir_type_subclass float8E4M3FNType(
+ m, "Float8E4M3FNType", mlirTypeIsAFloat8E4M3FN, floatType.get_class(),
+ mlirFloat8E4M3FNTypeGetTypeID, &m);
+ float8E4M3FNType.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirContext ctx) {
+ return cls(mlirFloat8E4M3FNTypeGet(ctx));
+ },
+ nb::arg("cls"), nb::arg("context") = nb::none(),
+ "Create a float8_e4m3fn type.");
+
+ // Float8E5M2Type
+ mlir_type_subclass float8E5M2Type(m, "Float8E5M2Type", mlirTypeIsAFloat8E5M2,
+ floatType.get_class(),
+ mlirFloat8E5M2TypeGetTypeID, &m);
+ float8E5M2Type.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirContext ctx) {
+ return cls(mlirFloat8E5M2TypeGet(ctx));
+ },
+ nb::arg("cls"), nb::arg("context") = nb::none(),
+ "Create a float8_e5m2 type.");
+
+ // Float8E4M3Type
+ mlir_type_subclass float8E4M3Type(m, "Float8E4M3Type", mlirTypeIsAFloat8E4M3,
+ floatType.get_class(),
+ mlirFloat8E4M3TypeGetTypeID, &m);
+ float8E4M3Type.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirContext ctx) {
+ return cls(mlirFloat8E4M3TypeGet(ctx));
+ },
+ nb::arg("cls"), nb::arg("context") = nb::none(),
+ "Create a float8_e4m3 type.");
+
+ // Float8E4M3FNUZType
+ mlir_type_subclass float8E4M3FNUZType(
+ m, "Float8E4M3FNUZType", mlirTypeIsAFloat8E4M3FNUZ, floatType.get_class(),
+ mlirFloat8E4M3FNUZTypeGetTypeID, &m);
+ float8E4M3FNUZType.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirContext ctx) {
+ return cls(mlirFloat8E4M3FNUZTypeGet(ctx));
+ },
+ nb::arg("cls"), nb::arg("context") = nb::none(),
+ "Create a float8_e4m3fnuz type.");
+
+ // Float8E4M3B11FNUZType
+ mlir_type_subclass float8E4M3B11FNUZType(
+ m, "Float8E4M3B11FNUZType", mlirTypeIsAFloat8E4M3B11FNUZ,
+ floatType.get_class(), mlirFloat8E4M3B11FNUZTypeGetTypeID, &m);
+ float8E4M3B11FNUZType.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirContext ctx) {
+ return cls(mlirFloat8E4M3B11FNUZTypeGet(ctx));
+ },
+ nb::arg("cls"), nb::arg("context") = nb::none(),
+ "Create a float8_e4m3b11fnuz type.");
+
+ // Float8E5M2FNUZType
+ mlir_type_subclass float8E5M2FNUZType(
+ m, "Float8E5M2FNUZType", mlirTypeIsAFloat8E5M2FNUZ, floatType.get_class(),
+ mlirFloat8E5M2FNUZTypeGetTypeID, &m);
+ float8E5M2FNUZType.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirContext ctx) {
+ return cls(mlirFloat8E5M2FNUZTypeGet(ctx));
+ },
+ nb::arg("cls"), nb::arg("context") = nb::none(),
+ "Create a float8_e5m2fnuz type.");
+
+ // Float8E3M4Type
+ mlir_type_subclass float8E3M4Type(m, "Float8E3M4Type", mlirTypeIsAFloat8E3M4,
+ floatType.get_class(),
+ mlirFloat8E3M4TypeGetTypeID, &m);
+ float8E3M4Type.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirContext ctx) {
+ return cls(mlirFloat8E3M4TypeGet(ctx));
+ },
+ nb::arg("cls"), nb::arg("context") = nb::none(),
+ "Create a float8_e3m4 type.");
+
+ // Float8E8M0FNUType
+ mlir_type_subclass float8E8M0FNUType(
+ m, "Float8E8M0FNUType", mlirTypeIsAFloat8E8M0FNU, floatType.get_class(),
+ mlirFloat8E8M0FNUTypeGetTypeID, &m);
+ float8E8M0FNUType.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirContext ctx) {
+ return cls(mlirFloat8E8M0FNUTypeGet(ctx));
+ },
+ nb::arg("cls"), nb::arg("context") = nb::none(),
+ "Create a float8_e8m0fnu type.");
+
+ // BF16Type
+ mlir_type_subclass bf16Type(m, "BF16Type", mlirTypeIsABF16,
+ floatType.get_class(), mlirBFloat16TypeGetTypeID,
+ &m);
+ bf16Type.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirContext ctx) {
+ return cls(mlirBF16TypeGet(ctx));
+ },
+ nb::arg("cls"), nb::arg("context") = nb::none(), "Create a bf16 type.");
+
+ // F16Type
+ mlir_type_subclass f16Type(m, "F16Type", mlirTypeIsAF16,
+ floatType.get_class(), mlirFloat16TypeGetTypeID,
+ &m);
+ f16Type.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirContext ctx) {
+ return cls(mlirF16TypeGet(ctx));
+ },
+ nb::arg("cls"), nb::arg("context") = nb::none(), "Create a f16 type.");
+
+ // FloatTF32Type
+ mlir_type_subclass tf32Type(m, "FloatTF32Type", mlirTypeIsATF32,
+ floatType.get_class(), mlirFloatTF32TypeGetTypeID,
+ &m);
+ tf32Type.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirContext ctx) {
+ return cls(mlirTF32TypeGet(ctx));
+ },
+ nb::arg("cls"), nb::arg("context") = nb::none(), "Create a tf32 type.");
+
+ // F32Type
+ mlir_type_subclass f32Type(m, "F32Type", mlirTypeIsAF32,
+ floatType.get_class(), mlirFloat32TypeGetTypeID,
+ &m);
+ f32Type.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirContext ctx) {
+ return cls(mlirF32TypeGet(ctx));
+ },
+ nb::arg("cls"), nb::arg("context") = nb::none(), "Create a f32 type.");
+
+ // F64Type
+ mlir_type_subclass f64Type(m, "F64Type", mlirTypeIsAF64,
+ floatType.get_class(), mlirFloat64TypeGetTypeID,
+ &m);
+ f64Type.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirContext ctx) {
+ return cls(mlirF64TypeGet(ctx));
+ },
+ nb::arg("cls"), nb::arg("context") = nb::none(), "Create a f64 type.");
+
+ // NoneType
+ mlir_type_subclass noneType(m, "NoneType", mlirTypeIsANone,
+ mlirNoneTypeGetTypeID, &m);
+ noneType.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirContext ctx) {
+ return cls(mlirNoneTypeGet(ctx));
+ },
+ nb::arg("cls"), nb::arg("context") = nb::none(), "Create a none type.");
+
+ // ComplexType
+ mlir_type_subclass complexType(m, "ComplexType", mlirTypeIsAComplex,
+ mlirComplexTypeGetTypeID, &m);
+ complexType.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirType elementType) {
+ // The element must be a floating point or integer scalar type.
+ if (mlirTypeIsAIntegerOrFloat(elementType)) {
+ return cls(mlirComplexTypeGet(elementType));
+ }
+ throw nb::value_error("Invalid element type for ComplexType: expected "
+ "floating point or integer type.");
+ },
+ "Create a complex type");
+ complexType.def_property_readonly(
+ "element_type",
+ [](MlirType self) { return mlirComplexTypeGetElementType(self); },
+ "Returns element type.");
+
+ // TupleType
+ mlir_type_subclass tupleType(m, "TupleType", mlirTypeIsATuple,
+ mlirTupleTypeGetTypeID, &m);
+ tupleType.def_classmethod(
+ "get_tuple",
+ [](const nb::object &cls, std::vector<MlirType> elements,
+ MlirContext ctx) {
+ return cls(mlirTupleTypeGet(ctx, elements.size(), elements.data()));
+ },
+ nb::arg("cls"), nb::arg("elements"), nb::arg("context") = nb::none(),
+ "Create a tuple type");
+ tupleType.def(
+ "get_type",
+ [](MlirType self, intptr_t pos) {
+ return mlirTupleTypeGetType(self, pos);
+ },
+ nb::arg("pos"), "Returns the pos-th type in the tuple type.");
+ tupleType.def_property_readonly(
+ "num_types", [](MlirType self) { return mlirTupleTypeGetNumTypes(self); },
+ "Returns the number of types contained in a tuple.");
+
+ // FunctionType
+ mlir_type_subclass functionType(m, "FunctionType", mlirTypeIsAFunction,
+ mlirFunctionTypeGetTypeID, &m);
+ functionType.def_classmethod(
+ "get",
+ [](const nb::object &cls, std::vector<MlirType> inputs,
+ std::vector<MlirType> results, MlirContext ctx) {
+ return cls(mlirFunctionTypeGet(ctx, inputs.size(), inputs.data(),
+ results.size(), results.data()));
+ },
+ nb::arg("cls"), nb::arg("inputs"), nb::arg("results"),
+ nb::arg("context") = nb::none(),
+ "Gets a FunctionType from a list of input and result types");
+ functionType.def_property_readonly(
+ "inputs",
+ [](MlirType self) {
+ nb::list types;
+ for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
+ ++i) {
+ types.append(mlirFunctionTypeGetInput(self, i));
+ }
+ return types;
+ },
+ "Returns the list of input types in the FunctionType.");
+ functionType.def_property_readonly(
+ "results",
+ [](MlirType self) {
+ nb::list types;
+ for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
+ ++i) {
+ types.append(mlirFunctionTypeGetResult(self, i));
+ }
+ return types;
+ },
+ "Returns the list of result types in the FunctionType.");
+
+ // OpaqueType
+ mlir_type_subclass opaqueType(m, "OpaqueType", mlirTypeIsAOpaque,
+ mlirOpaqueTypeGetTypeID, &m);
+ opaqueType.def_classmethod(
+ "get",
+ [](const nb::object &cls, const std::string &dialectNamespace,
+ const std::string &typeData, MlirContext ctx) {
+ MlirStringRef dialectNs = mlirStringRefCreate(dialectNamespace.data(),
+ dialectNamespace.size());
+ MlirStringRef data =
+ mlirStringRefCreate(typeData.data(), typeData.size());
+ return cls(mlirOpaqueTypeGet(ctx, dialectNs, data));
+ },
+ nb::arg("cls"), nb::arg("dialect_namespace"), nb::arg("buffer"),
+ nb::arg("context") = nb::none(),
+ "Create an unregistered (opaque) dialect type.");
+ opaqueType.def_property_readonly(
+ "dialect_namespace",
+ [](MlirType self) {
+ MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self);
+ return nb::str(stringRef.data, stringRef.length);
+ },
+ "Returns the dialect namespace for the Opaque type as a string.");
+ opaqueType.def_property_readonly(
+ "data",
+ [](MlirType self) {
+ MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
+ return nb::str(stringRef.data, stringRef.length);
+ },
+ "Returns the data for the Opaque type as a string.");
+}
} // namespace
@@ -977,202 +865,17 @@ class PyUnrankedMemRefType
}
};
-/// Tuple Type subclass - TupleType.
-class PyTupleType : public PyConcreteType<PyTupleType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirTupleTypeGetTypeID;
- static constexpr const char *pyClassName = "TupleType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get_tuple",
- [](const std::vector<PyType> &elements,
- DefaultingPyMlirContext context) {
- std::vector<MlirType> mlirElements;
- mlirElements.reserve(elements.size());
- for (const auto &element : elements)
- mlirElements.push_back(element.get());
- MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
- mlirElements.data());
- return PyTupleType(context->getRef(), t);
- },
- nb::arg("elements"), nb::arg("context") = nb::none(),
- "Create a tuple type");
- c.def_static(
- "get_tuple",
- [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
- MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
- elements.data());
- return PyTupleType(context->getRef(), t);
- },
- nb::arg("elements"), nb::arg("context") = nb::none(),
- // clang-format off
- nb::sig("def get_tuple(elements: Sequence[Type], context: Context | None = None) -> TupleType"),
- // clang-format on
- "Create a tuple type");
- c.def(
- "get_type",
- [](PyTupleType &self, intptr_t pos) -> nb::typed<nb::object, PyType> {
- return PyType(self.getContext(), mlirTupleTypeGetType(self, pos))
- .maybeDownCast();
- },
- nb::arg("pos"), "Returns the pos-th type in the tuple type.");
- c.def_prop_ro(
- "num_types",
- [](PyTupleType &self) -> intptr_t {
- return mlirTupleTypeGetNumTypes(self);
- },
- "Returns the number of types contained in a tuple.");
- }
-};
-
-/// Function type.
-class PyFunctionType : public PyConcreteType<PyFunctionType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFunctionTypeGetTypeID;
- static constexpr const char *pyClassName = "FunctionType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](std::vector<PyType> inputs, std::vector<PyType> results,
- DefaultingPyMlirContext context) {
- std::vector<MlirType> mlirInputs;
- mlirInputs.reserve(inputs.size());
- for (const auto &input : inputs)
- mlirInputs.push_back(input.get());
- std::vector<MlirType> mlirResults;
- mlirResults.reserve(results.size());
- for (const auto &result : results)
- mlirResults.push_back(result.get());
-
- MlirType t = mlirFunctionTypeGet(context->get(), inputs.size(),
- mlirInputs.data(), results.size(),
- mlirResults.data());
- return PyFunctionType(context->getRef(), t);
- },
- nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
- "Gets a FunctionType from a list of input and result types");
- c.def_static(
- "get",
- [](std::vector<MlirType> inputs, std::vector<MlirType> results,
- DefaultingPyMlirContext context) {
- MlirType t =
- mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
- results.size(), results.data());
- return PyFunctionType(context->getRef(), t);
- },
- nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
- // clang-format off
- nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: Context | None = None) -> FunctionType"),
- // clang-format on
- "Gets a FunctionType from a list of input and result types");
- c.def_prop_ro(
- "inputs",
- [](PyFunctionType &self) {
- MlirType t = self;
- nb::list types;
- for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
- ++i) {
- types.append(mlirFunctionTypeGetInput(t, i));
- }
- return types;
- },
- "Returns the list of input types in the FunctionType.");
- c.def_prop_ro(
- "results",
- [](PyFunctionType &self) {
- nb::list types;
- for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
- ++i) {
- types.append(mlirFunctionTypeGetResult(self, i));
- }
- return types;
- },
- "Returns the list of result types in the FunctionType.");
- }
-};
-
-static MlirStringRef toMlirStringRef(const std::string &s) {
- return mlirStringRefCreate(s.data(), s.size());
-}
-
-/// Opaque Type subclass - OpaqueType.
-class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirOpaqueTypeGetTypeID;
- static constexpr const char *pyClassName = "OpaqueType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](const std::string &dialectNamespace, const std::string &typeData,
- DefaultingPyMlirContext context) {
- MlirType type = mlirOpaqueTypeGet(context->get(),
- toMlirStringRef(dialectNamespace),
- toMlirStringRef(typeData));
- return PyOpaqueType(context->getRef(), type);
- },
- nb::arg("dialect_namespace"), nb::arg("buffer"),
- nb::arg("context") = nb::none(),
- "Create an unregistered (opaque) dialect type.");
- c.def_prop_ro(
- "dialect_namespace",
- [](PyOpaqueType &self) {
- MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self);
- return nb::str(stringRef.data, stringRef.length);
- },
- "Returns the dialect namespace for the Opaque type as a string.");
- c.def_prop_ro(
- "data",
- [](PyOpaqueType &self) {
- MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
- return nb::str(stringRef.data, stringRef.length);
- },
- "Returns the data for the Opaque type as a string.");
- }
-};
-
} // namespace
void mlir::python::populateIRTypes(nb::module_ &m) {
- PyIntegerType::bind(m);
- PyFloatType::bind(m);
- PyIndexType::bind(m);
- PyFloat4E2M1FNType::bind(m);
- PyFloat6E2M3FNType::bind(m);
- PyFloat6E3M2FNType::bind(m);
- PyFloat8E4M3FNType::bind(m);
- PyFloat8E5M2Type::bind(m);
- PyFloat8E4M3Type::bind(m);
- PyFloat8E4M3FNUZType::bind(m);
- PyFloat8E4M3B11FNUZType::bind(m);
- PyFloat8E5M2FNUZType::bind(m);
- PyFloat8E3M4Type::bind(m);
- PyFloat8E8M0FNUType::bind(m);
- PyBF16Type::bind(m);
- PyF16Type::bind(m);
- PyTF32Type::bind(m);
- PyF32Type::bind(m);
- PyF64Type::bind(m);
- PyNoneType::bind(m);
- PyComplexType::bind(m);
+ // Non-shaped builtin types are defined via mlir_type_subclass.
+ populateIRTypesModule(m);
+
+ // Shaped types are not ported yet and keep using PyConcreteType.
PyShapedType::bind(m);
PyVectorType::bind(m);
PyRankedTensorType::bind(m);
PyUnrankedTensorType::bind(m);
PyMemRefType::bind(m);
PyUnrankedMemRefType::bind(m);
- PyTupleType::bind(m);
- PyFunctionType::bind(m);
- PyOpaqueType::bind(m);
}
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index ba767ad6692cf..fb73beda4cf88 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -145,6 +145,21 @@ NB_MODULE(_mlir, m) {
// Define and populate IR submodule.
auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
+ irModule.def(
+ MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
+ [](MlirTypeID mlirTypeID, bool replace) -> nb::object {
+ return nb::cpp_function([mlirTypeID, replace](
+ nb::callable typeCaster) -> nb::object {
+ PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
+ return typeCaster;
+ });
+ },
+ // clang-format off
+ nb::sig("def register_type_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) "
+ "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"),
+ // clang-format on
+ "typeid"_a, nb::kw_only(), "replace"_a = false,
+ "Register a type caster for casting MLIR types to custom user types.");
populateIRCore(irModule);
populateIRAffine(irModule);
populateIRAttributes(irModule);
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index c9af5e7b46db8..ad318238b77c6 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -54,10 +54,10 @@ def _binary_op(lhs, rhs, op: str) -> "ArithValue":
op = getattr(arith, f"{op}Op")
return op(lhs, rhs).result
- @register_value_caster(F16Type.static_typeid)
- @register_value_caster(F32Type.static_typeid)
- @register_value_caster(F64Type.static_typeid)
- @register_value_caster(IntegerType.static_typeid)
+ @register_value_caster(F16Type.get_static_typeid())
+ @register_value_caster(F32Type.get_static_typeid())
+ @register_value_caster(F64Type.get_static_typeid())
+ @register_value_caster(IntegerType.get_static_typeid())
class ArithValue(Value):
def __init__(self, v):
super().__init__(v)
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 54863253fc770..20509050eda9f 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -185,7 +185,7 @@ def testStandardTypeCasts():
try:
tillegal = IntegerType(Type.parse("f32", ctx))
except ValueError as e:
- # CHECK: ValueError: Cannot cast type to IntegerType (from Type(f32))
+ # CHECK: ValueError: Cannot cast type to IntegerType (from F32Type(f32))
print("ValueError:", e)
else:
print("Exception not produced")
@@ -302,7 +302,7 @@ def testComplexType():
try:
complex_invalid = ComplexType.get(index)
except ValueError as e:
- # CHECK: invalid 'Type(index)' and expected floating point or integer type.
+ # CHECK: Invalid element type for ComplexType: expected floating point or integer type.
print(e)
else:
print("Exception not produced")
@@ -714,7 +714,13 @@ def testTypeIDs():
# mlirTypeGetTypeID(self) for an instance.
# CHECK: all equal
for t1, t2 in types:
- tid1, tid2 = t1.static_typeid, Type(t2).typeid
+ # TODO: remove the alternative once mlir_type_subclass transition is complete.
+ tid1 = (
+ t1.static_typeid
+ if hasattr(t1, "static_typeid")
+ else t1.get_static_typeid()
+ )
+ tid2 = Type(t2).typeid
assert tid1 == tid2 and hash(tid1) == hash(
tid2
), f"expected hash and value equality {t1} {t2}"
@@ -728,7 +734,13 @@ def testTypeIDs():
# CHECK: all equal
for t1, t2 in typeid_dict.items():
- assert t1.static_typeid == t2.typeid and hash(t1.static_typeid) == hash(
+ # TODO: remove the alternative once mlir_type_subclass transition is complete.
+ tid1 = (
+ t1.static_typeid
+ if hasattr(t1, "static_typeid")
+ else t1.get_static_typeid()
+ )
+ assert tid1 == t2.typeid and hash(tid1) == hash(
t2.typeid
), f"expected hash and value equality {t1} {t2}"
else:
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 4a241afb8e89d..9d9b6c2090974 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -361,7 +361,7 @@ def __init__(self, v):
def __str__(self):
return super().__str__().replace(Value.__name__, NOPBlockArg.__name__)
- @register_value_caster(IntegerType.static_typeid)
+ @register_value_caster(IntegerType.get_static_typeid())
def cast_int(v) -> Value:
print("in caster", v.__class__.__name__)
if isinstance(v, OpResult):
@@ -425,7 +425,7 @@ def reduction(arg0, arg1):
try:
- @register_value_caster(IntegerType.static_typeid)
+ @register_value_caster(IntegerType.get_static_typeid())
def dont_cast_int_shouldnt_register(v):
...
@@ -433,7 +433,7 @@ def dont_cast_int_shouldnt_register(v):
# CHECK: Value caster is already registered: {{.*}}cast_int
print(e)
- @register_value_caster(IntegerType.static_typeid, replace=True)
+ @register_value_caster(IntegerType.get_static_typeid(), replace=True)
def dont_cast_int(v) -> OpResult:
assert isinstance(v, OpResult)
print("don't cast", v.result_number, v)
More information about the llvm-branch-commits
mailing list