[llvm-branch-commits] [mlir] [mlir][Python] move IRTypes and IRAttributes to public headers (PR #174118)
Maksim Levental via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Dec 31 14:21:04 PST 2025
https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/174118
None
>From 9e1cebcb757e4be2e07deeb7dbbbe0af71e5b593 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 31 Dec 2025 14:20:39 -0800
Subject: [PATCH] [mlir][Python] move IRTypes and IRAttributes to public
headers
---
.../mlir/Bindings/Python/IRAttributes.h | 593 +++++
mlir/include/mlir/Bindings/Python/IRCore.h | 17 +-
mlir/include/mlir/Bindings/Python/IRTypes.h | 393 ++-
mlir/lib/Bindings/Python/IRAttributes.cpp | 2297 +++++++----------
mlir/lib/Bindings/Python/IRTypes.cpp | 1610 +++++-------
mlir/lib/Bindings/Python/MainModule.cpp | 2 +
mlir/python/CMakeLists.txt | 6 +-
.../python/lib/PythonTestModuleNanobind.cpp | 131 +-
8 files changed, 2632 insertions(+), 2417 deletions(-)
create mode 100644 mlir/include/mlir/Bindings/Python/IRAttributes.h
diff --git a/mlir/include/mlir/Bindings/Python/IRAttributes.h b/mlir/include/mlir/Bindings/Python/IRAttributes.h
new file mode 100644
index 0000000000000..d64e32037664c
--- /dev/null
+++ b/mlir/include/mlir/Bindings/Python/IRAttributes.h
@@ -0,0 +1,593 @@
+//===- IRAttributes.h - Exports builtin and standard attributes -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_IRATTRIBUTES_H
+#define MLIR_BINDINGS_PYTHON_IRATTRIBUTES_H
+
+#include <optional>
+#include <string>
+#include <string_view>
+#include <utility>
+
+#include "mlir-c/BuiltinAttributes.h"
+#include "mlir-c/BuiltinTypes.h"
+#include "mlir/Bindings/Python/IRCore.h"
+#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
+
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+
+/// Description of a memory buffer exchanged through the Python buffer
+/// protocol: data pointer, element size, format string, shape and strides.
+/// Move-only; optionally owns the `Py_buffer` it was built from, which is
+/// handed to `PyBuffer_Release` when this object is destroyed.
+struct nb_buffer_info {
+  void *ptr = nullptr;
+  ssize_t itemsize = 0;
+  ssize_t size = 0;
+  const char *format = nullptr;
+  ssize_t ndim = 0;
+  SmallVector<ssize_t, 4> shape;
+  SmallVector<ssize_t, 4> strides;
+  bool readonly = false;
+
+  nb_buffer_info(
+      void *ptr, ssize_t itemsize, const char *format, ssize_t ndim,
+      SmallVector<ssize_t, 4> shape_in, SmallVector<ssize_t, 4> strides_in,
+      bool readonly = false,
+      std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view_in =
+          std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(nullptr, nullptr));
+
+  /// Wraps `view` and takes responsibility for calling `PyBuffer_Release` on
+  /// it. A NULL `view->strides` (the exporter reported a C-contiguous layout)
+  /// is turned into explicit row-major strides instead of being dereferenced.
+  /// NOTE(review): the deleter only releases the view; if callers allocate
+  /// `view` on the heap, confirm the Py_buffer struct itself is freed.
+  explicit nb_buffer_info(Py_buffer *view)
+      : nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim,
+                       {view->shape, view->shape + view->ndim},
+                       stridesFromView(view), view->readonly != 0,
+                       std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(
+                           view, PyBuffer_Release)) {}
+
+  nb_buffer_info(const nb_buffer_info &) = delete;
+  nb_buffer_info(nb_buffer_info &&) = default;
+  nb_buffer_info &operator=(const nb_buffer_info &) = delete;
+  nb_buffer_info &operator=(nb_buffer_info &&) = default;
+
+private:
+  /// Returns the byte strides of `view`. When the exporter left `strides`
+  /// NULL, the buffer is C-contiguous, so each stride is the item size times
+  /// the product of all later dimensions. Like the shape handling above, this
+  /// assumes `view->shape` is non-NULL whenever `view->ndim > 0`.
+  static SmallVector<ssize_t, 4> stridesFromView(const Py_buffer *view) {
+    if (view->strides)
+      return SmallVector<ssize_t, 4>(view->strides,
+                                     view->strides + view->ndim);
+    SmallVector<ssize_t, 4> contiguous(view->ndim);
+    ssize_t stride = view->itemsize;
+    for (ssize_t dim = view->ndim - 1; dim >= 0; --dim) {
+      contiguous[dim] = stride;
+      stride *= view->shape[dim];
+    }
+    return contiguous;
+  }
+
+  std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view;
+};
+
+/// nanobind object wrapper for any Python object that passes
+/// `PyObject_CheckBuffer`, declared with the nanobind type name "Buffer".
+class MLIR_PYTHON_API_EXPORTED nb_buffer : public nanobind::object {
+  NB_OBJECT_DEFAULT(nb_buffer, object, "Buffer", PyObject_CheckBuffer);
+
+public:
+  /// Defined out of line; presumably acquires a `Py_buffer` view of the
+  /// object and wraps it in an `nb_buffer_info`.
+  nb_buffer_info request() const;
+};
+
+/// Traits type mapping an element type `T` to its buffer format string via a
+/// static `format()` member, as used by `PyDenseElementsAttribute::bufferInfo`.
+/// The primary template deliberately defines nothing; presumably the
+/// supported element types get specializations elsewhere.
+template <typename T>
+struct nb_format_descriptor {
+};
+
+/// Binding for `AffineMapAttr`: wraps attributes accepted by
+/// `mlirAttributeIsAAffineMap`.
+class MLIR_PYTHON_API_EXPORTED PyAffineMapAttribute
+    : public PyConcreteAttribute<PyAffineMapAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "AffineMapAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirAffineMapAttrGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `IntegerSetAttr`: wraps attributes accepted by
+/// `mlirAttributeIsAIntegerSet`.
+class MLIR_PYTHON_API_EXPORTED PyIntegerSetAttribute
+    : public PyConcreteAttribute<PyIntegerSetAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "IntegerSetAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirIntegerSetAttrGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Casts `object` to `T` with nanobind. A `nanobind::cast_error` or a
+/// `std::runtime_error` raised during the cast is rethrown as a
+/// `std::runtime_error` whose message embeds the original error text.
+///
+/// Declared `inline` rather than `static`: this template is defined in a
+/// public header, and `static` would give every including translation unit
+/// its own internal-linkage copy.
+template <typename T>
+inline T pyTryCast(nanobind::handle object) {
+  try {
+    return nanobind::cast<T>(object);
+  } catch (const nanobind::cast_error &err) {
+    throw std::runtime_error(std::string("Invalid attribute when attempting to "
+                                         "create an ArrayAttribute (") +
+                             err.what() + ")");
+  } catch (const std::runtime_error &err) {
+    // The message below suggests this path is typically hit for None.
+    throw std::runtime_error(
+        std::string("Invalid attribute (None?) when attempting "
+                    "to create an ArrayAttribute (") +
+        err.what() + ")");
+  }
+}
+
+/// A python-wrapped dense array attribute with an element type and a derived
+/// implementation class.
+///
+/// `DerivedT` must provide the static hooks used below: `getAttribute` (C
+/// constructor), `getElement` (C element accessor) and `pyIteratorName`, in
+/// addition to what `PyConcreteAttribute` requires (`isaFunction`,
+/// `pyClassName`).
+template <typename EltTy, typename DerivedT>
+class MLIR_PYTHON_API_EXPORTED PyDenseArrayAttribute
+    : public PyConcreteAttribute<DerivedT> {
+public:
+  using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;
+
+  /// Iterator over the elements of a dense array.
+  class PyDenseArrayIterator {
+  public:
+    PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
+
+    /// Return a copy of the iterator.
+    PyDenseArrayIterator dunderIter() { return *this; }
+
+    /// Return the next element, raising StopIteration once exhausted.
+    EltTy dunderNext() {
+      if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
+        throw nanobind::stop_iteration();
+      return DerivedT::getElement(attr.get(), nextIndex++);
+    }
+
+    /// Bind the iterator class.
+    static void bind(nanobind::module_ &m) {
+      nanobind::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName)
+          .def("__iter__", &PyDenseArrayIterator::dunderIter)
+          .def("__next__", &PyDenseArrayIterator::dunderNext);
+    }
+
+  private:
+    /// The referenced dense array attribute.
+    PyAttribute attr;
+    /// The next index to read; uses the same type as the element count
+    /// returned by `mlirDenseArrayGetNumElements`.
+    intptr_t nextIndex = 0;
+  };
+
+  /// Get the element at the given index. No bounds check happens here;
+  /// `__getitem__` below validates the index before calling this.
+  EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
+
+  /// Bind the attribute class.
+  static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
+    // Bind the constructor. Booleans are read with Python truthiness rather
+    // than a strict bool cast.
+    if constexpr (std::is_same_v<EltTy, bool>) {
+      c.def_static(
+          "get",
+          [](const nanobind::sequence &py_values, DefaultingPyMlirContext ctx) {
+            std::vector<bool> values;
+            for (nanobind::handle py_value : py_values) {
+              int is_true = PyObject_IsTrue(py_value.ptr());
+              if (is_true < 0) {
+                throw nanobind::python_error();
+              }
+              values.push_back(is_true);
+            }
+            return getAttribute(values, ctx->getRef());
+          },
+          nanobind::arg("values"), nanobind::arg("context") = nanobind::none(),
+          "Gets a uniqued dense array attribute");
+    } else {
+      c.def_static(
+          "get",
+          [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
+            return getAttribute(values, ctx->getRef());
+          },
+          nanobind::arg("values"), nanobind::arg("context") = nanobind::none(),
+          "Gets a uniqued dense array attribute");
+    }
+    // Bind the array methods.
+    c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
+      intptr_t numElements = mlirDenseArrayGetNumElements(arr);
+      // Follow Python sequence semantics: negative indices count from the
+      // end. Both bounds must be checked, otherwise a negative index would
+      // reach getElement out of range.
+      if (i < 0)
+        i += numElements;
+      if (i < 0 || i >= numElements)
+        throw nanobind::index_error("DenseArray index out of range");
+      return arr.getItem(i);
+    });
+    c.def("__len__", [](const DerivedT &arr) {
+      return mlirDenseArrayGetNumElements(arr);
+    });
+    c.def("__iter__",
+          [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
+    c.def("__add__", [](DerivedT &arr, const nanobind::list &extras) {
+      std::vector<EltTy> values;
+      intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
+      values.reserve(numOldElements + nanobind::len(extras));
+      for (intptr_t i = 0; i < numOldElements; ++i)
+        values.push_back(arr.getItem(i));
+      for (nanobind::handle attr : extras)
+        values.push_back(pyTryCast<EltTy>(attr));
+      return getAttribute(values, arr.getContext());
+    });
+  }
+
+private:
+  /// Builds the attribute through `DerivedT::getAttribute`. Boolean values
+  /// are first widened to `int` because `std::vector<bool>` has no
+  /// contiguous `data()` to pass to the C API.
+  static DerivedT getAttribute(const std::vector<EltTy> &values,
+                               PyMlirContextRef ctx) {
+    if constexpr (std::is_same_v<EltTy, bool>) {
+      std::vector<int> intValues(values.begin(), values.end());
+      MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(),
+                                                  intValues.data());
+      return DerivedT(ctx, attr);
+    } else {
+      MlirAttribute attr =
+          DerivedT::getAttribute(ctx->get(), values.size(), values.data());
+      return DerivedT(ctx, attr);
+    }
+  }
+};
+
+// Concrete dense array attribute bindings, one per element type. Each struct
+// supplies the hooks PyDenseArrayAttribute relies on: `getAttribute`,
+// `getElement` and `pyIteratorName`, plus `isaFunction` and `pyClassName`.
+// NOTE(review): unlike most classes in this header, these structs are not
+// marked MLIR_PYTHON_API_EXPORTED; confirm that is intentional.
+
+/// Dense array of `bool` elements.
+struct PyDenseBoolArrayAttribute
+    : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> {
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+
+  static constexpr const char *pyClassName = "DenseBoolArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
+  static constexpr auto getAttribute = mlirDenseBoolArrayGet;
+  static constexpr auto getElement = mlirDenseBoolArrayGetElement;
+};
+
+/// Dense array of `int8_t` elements.
+struct PyDenseI8ArrayAttribute
+    : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+
+  static constexpr const char *pyClassName = "DenseI8ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
+  static constexpr auto getAttribute = mlirDenseI8ArrayGet;
+  static constexpr auto getElement = mlirDenseI8ArrayGetElement;
+};
+
+/// Dense array of `int16_t` elements.
+struct PyDenseI16ArrayAttribute
+    : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+
+  static constexpr const char *pyClassName = "DenseI16ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
+  static constexpr auto getAttribute = mlirDenseI16ArrayGet;
+  static constexpr auto getElement = mlirDenseI16ArrayGetElement;
+};
+
+/// Dense array of `int32_t` elements.
+struct PyDenseI32ArrayAttribute
+    : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+
+  static constexpr const char *pyClassName = "DenseI32ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
+  static constexpr auto getAttribute = mlirDenseI32ArrayGet;
+  static constexpr auto getElement = mlirDenseI32ArrayGetElement;
+};
+
+/// Dense array of `int64_t` elements.
+struct PyDenseI64ArrayAttribute
+    : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+
+  static constexpr const char *pyClassName = "DenseI64ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
+  static constexpr auto getAttribute = mlirDenseI64ArrayGet;
+  static constexpr auto getElement = mlirDenseI64ArrayGetElement;
+};
+
+/// Dense array of `float` elements.
+struct PyDenseF32ArrayAttribute
+    : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+
+  static constexpr const char *pyClassName = "DenseF32ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
+  static constexpr auto getAttribute = mlirDenseF32ArrayGet;
+  static constexpr auto getElement = mlirDenseF32ArrayGetElement;
+};
+
+/// Dense array of `double` elements.
+struct PyDenseF64ArrayAttribute
+    : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+
+  static constexpr const char *pyClassName = "DenseF64ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
+  static constexpr auto getAttribute = mlirDenseF64ArrayGet;
+  static constexpr auto getElement = mlirDenseF64ArrayGetElement;
+};
+
+/// Binding for `ArrayAttr`: wraps attributes accepted by
+/// `mlirAttributeIsAArray`, with a nested Python iterator class.
+class MLIR_PYTHON_API_EXPORTED PyArrayAttribute
+    : public PyConcreteAttribute<PyArrayAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "ArrayAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirArrayAttrGetTypeID;
+
+  /// Iterator over an `ArrayAttr`. Apart from `dunderIter`, its members are
+  /// defined out of line.
+  class PyArrayAttributeIterator {
+  public:
+    PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
+
+    /// Returns this same iterator object.
+    PyArrayAttributeIterator &dunderIter() { return *this; }
+
+    /// Produces the next element (defined out of line).
+    nanobind::typed<nanobind::object, PyAttribute> dunderNext();
+
+    /// Registers the iterator class on `m` (defined out of line).
+    static void bind(nanobind::module_ &m);
+
+  private:
+    PyAttribute attr;  // The array being iterated.
+    int nextIndex = 0; // Iteration cursor, starting at 0.
+  };
+
+  /// Element accessor, defined out of line.
+  MlirAttribute getItem(intptr_t i) const;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `FloatAttr`: wraps attributes accepted by
+/// `mlirAttributeIsAFloat`.
+class MLIR_PYTHON_API_EXPORTED PyFloatAttribute
+    : public PyConcreteAttribute<PyFloatAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "FloatAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloatAttrGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `IntegerAttr`: wraps attributes accepted by
+/// `mlirAttributeIsAInteger`. No `getTypeIdFunction` is declared here.
+class MLIR_PYTHON_API_EXPORTED PyIntegerAttribute
+    : public PyConcreteAttribute<PyIntegerAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "IntegerAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
+
+  static void bindDerived(ClassTy &c);
+
+private:
+  /// Out-of-line helper; judging by its name, it yields the value as an
+  /// `int64_t` for conversion to a Python int.
+  static int64_t toPyInt(PyIntegerAttribute &self);
+};
+
+/// Binding for `BoolAttr`: wraps attributes accepted by
+/// `mlirAttributeIsABool`. No `getTypeIdFunction` is declared here.
+class MLIR_PYTHON_API_EXPORTED PyBoolAttribute
+    : public PyConcreteAttribute<PyBoolAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "BoolAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `SymbolRefAttr`: wraps attributes accepted by
+/// `mlirAttributeIsASymbolRef`.
+class MLIR_PYTHON_API_EXPORTED PySymbolRefAttribute
+    : public PyConcreteAttribute<PySymbolRefAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "SymbolRefAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
+
+  /// Out-of-line factory; presumably builds the reference from an ordered
+  /// list of symbol names within `context`.
+  static PySymbolRefAttribute fromList(const std::vector<std::string> &symbols,
+                                       PyMlirContext &context);
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `FlatSymbolRefAttr`: wraps attributes accepted by
+/// `mlirAttributeIsAFlatSymbolRef`.
+class MLIR_PYTHON_API_EXPORTED PyFlatSymbolRefAttribute
+    : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "FlatSymbolRefAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `OpaqueAttr`: wraps attributes accepted by
+/// `mlirAttributeIsAOpaque`.
+class MLIR_PYTHON_API_EXPORTED PyOpaqueAttribute
+    : public PyConcreteAttribute<PyOpaqueAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "OpaqueAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirOpaqueAttrGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+// TODO: Support construction of string elements.
+/// Binding for `DenseElementsAttr` (attributes accepted by
+/// `mlirAttributeIsADenseElements`). Besides the factories below, it declares
+/// `slots` plus `bf_getbuffer`/`bf_releasebuffer`, whose signatures match
+/// CPython's buffer-protocol hooks; presumably the attribute's raw data is
+/// exported to Python through them (definitions are out of line).
+class MLIR_PYTHON_API_EXPORTED PyDenseElementsAttribute
+ : public PyConcreteAttribute<PyDenseElementsAttribute> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
+ static constexpr const char *pyClassName = "DenseElementsAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+
+ /// Builds the attribute from a Python list of attributes, optionally with
+ /// an explicit type (defined out of line).
+ static PyDenseElementsAttribute
+ getFromList(const nanobind::list &attributes,
+ std::optional<PyType> explicitType,
+ DefaultingPyMlirContext contextWrapper);
+
+ /// Builds the attribute from an object implementing the buffer protocol;
+ /// `explicitType`/`explicitShape` optionally override what the buffer
+ /// reports (defined out of line).
+ static PyDenseElementsAttribute
+ getFromBuffer(const nb_buffer &array, bool signless,
+ const std::optional<PyType> &explicitType,
+ std::optional<std::vector<int64_t>> explicitShape,
+ DefaultingPyMlirContext contextWrapper);
+
+ /// Builds a splat attribute of `shapedType` from a single element.
+ static PyDenseElementsAttribute getSplat(const PyType &shapedType,
+ PyAttribute &elementAttr);
+
+ intptr_t dunderLen() const;
+
+ std::unique_ptr<nb_buffer_info> accessBuffer();
+
+ static void bindDerived(ClassTy &c);
+
+ static PyType_Slot slots[];
+
+private:
+ static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags);
+ static void bf_releasebuffer(PyObject *, Py_buffer *buffer);
+
+ static bool isUnsignedIntegerFormat(std::string_view format);
+
+ static bool isSignedIntegerFormat(std::string_view format);
+
+ static MlirType
+ getShapedType(std::optional<MlirType> bulkLoadElementType,
+ std::optional<std::vector<int64_t>> explicitShape,
+ Py_buffer &view);
+
+ static MlirAttribute getAttributeFromBuffer(
+ Py_buffer &view, bool signless, std::optional<PyType> explicitType,
+ const std::optional<std::vector<int64_t>> &explicitShape,
+ MlirContext &context);
+
+ // Boolean numpy arrays need special handling: numpy stores one boolean per
+ // byte (8 bits each), whereas MLIR bit-packs 8 booleans into each byte.
+ static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
+ Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
+ MlirContext &context);
+
+ // Inverse of `getBitpackedAttributeFromBooleanBuffer`: unpacks the
+ // bit-packed booleans back into a one-byte-per-element buffer.
+ std::unique_ptr<nb_buffer_info>
+ getBooleanBufferFromBitpackedAttribute() const;
+
+ /// Describes this attribute's raw storage as a read-only buffer of `Type`
+ /// elements laid out with `shapedType`'s shape. `explicitFormat`, when
+ /// non-null, replaces `nb_format_descriptor<Type>::format()`.
+ template <typename Type>
+ std::unique_ptr<nb_buffer_info>
+ bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) {
+ intptr_t rank = mlirShapedTypeGetRank(shapedType);
+ // Prepare the data for the buffer_info.
+ // Buffer is configured for read-only access below.
+ Type *data = static_cast<Type *>(
+ const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
+ // Prepare the shape for the buffer_info.
+ SmallVector<intptr_t, 4> shape;
+ for (intptr_t i = 0; i < rank; ++i)
+ shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
+ // Prepare the strides for the buffer_info.
+ SmallVector<intptr_t, 4> strides;
+ if (mlirDenseElementsAttrIsSplat(*this)) {
+ // Splats are special, only the single value is stored.
+ // All-zero strides make every index alias that single stored value.
+ strides.assign(rank, 0);
+ } else {
+ // Row-major byte strides: stride[i-1] = sizeof(Type) times the product
+ // of dims i..rank-1, computed with a nested loop (O(rank^2)).
+ for (intptr_t i = 1; i < rank; ++i) {
+ intptr_t strideFactor = 1;
+ for (intptr_t j = i; j < rank; ++j)
+ strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
+ strides.push_back(sizeof(Type) * strideFactor);
+ }
+ // Innermost stride is one element. Note this entry is appended even
+ // when rank == 0, leaving one stride while ndim is 0.
+ strides.push_back(sizeof(Type));
+ }
+ const char *format;
+ if (explicitFormat) {
+ format = explicitFormat;
+ } else {
+ format = nb_format_descriptor<Type>::format();
+ }
+ return std::make_unique<nb_buffer_info>(
+ data, sizeof(Type), format, rank, std::move(shape), std::move(strides),
+ /*readonly=*/true);
+ }
+};
+
+/// Binding for `DenseIntElementsAttr`: a refinement of
+/// PyDenseElementsAttribute for integer (and boolean) element data, matched
+/// by `mlirAttributeIsADenseIntElements`, that adds element access.
+class MLIR_PYTHON_API_EXPORTED PyDenseIntElementsAttribute
+    : public PyConcreteAttribute<PyDenseIntElementsAttribute,
+                                 PyDenseElementsAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "DenseIntElementsAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
+
+  /// Element at linear position `pos`; asserts when `pos` is out of range.
+  nanobind::int_ dunderGetItem(intptr_t pos) const;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `DenseResourceElementsAttr`: wraps attributes accepted by
+/// `mlirAttributeIsADenseResourceElements`.
+class MLIR_PYTHON_API_EXPORTED PyDenseResourceElementsAttribute
+    : public PyConcreteAttribute<PyDenseResourceElementsAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "DenseResourceElementsAttr";
+  static constexpr IsAFunctionTy isaFunction =
+      mlirAttributeIsADenseResourceElements;
+
+  /// Out-of-line factory building the attribute of `type` from `buffer`
+  /// under the resource name `name`; `alignment` and `isMutable` are
+  /// forwarded to the definition.
+  static PyDenseResourceElementsAttribute
+  getFromBuffer(const nb_buffer &buffer, const std::string &name,
+                const PyType &type, std::optional<size_t> alignment,
+                bool isMutable, DefaultingPyMlirContext contextWrapper);
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `DictAttr`: wraps attributes accepted by
+/// `mlirAttributeIsADictionary`.
+class MLIR_PYTHON_API_EXPORTED PyDictAttribute
+    : public PyConcreteAttribute<PyDictAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "DictAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirDictionaryAttrGetTypeID;
+
+  /// Entry count; presumably backs `__len__` (defined out of line).
+  intptr_t dunderLen() const;
+
+  /// Key lookup; presumably backs `__contains__` (defined out of line).
+  bool dunderContains(const std::string &name) const;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `DenseFPElementsAttr`: a refinement of
+/// PyDenseElementsAttribute for floating-point element data, matched by
+/// `mlirAttributeIsADenseFPElements`, that adds element access.
+class MLIR_PYTHON_API_EXPORTED PyDenseFPElementsAttribute
+    : public PyConcreteAttribute<PyDenseFPElementsAttribute,
+                                 PyDenseElementsAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "DenseFPElementsAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
+
+  /// Element at linear position `pos` (defined out of line).
+  nanobind::float_ dunderGetItem(intptr_t pos) const;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `TypeAttr`: wraps attributes accepted by
+/// `mlirAttributeIsAType`.
+class MLIR_PYTHON_API_EXPORTED PyTypeAttribute
+    : public PyConcreteAttribute<PyTypeAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "TypeAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirTypeAttrGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `UnitAttr`, a value-less attribute matched by
+/// `mlirAttributeIsAUnit`.
+class MLIR_PYTHON_API_EXPORTED PyUnitAttribute
+    : public PyConcreteAttribute<PyUnitAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "UnitAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirUnitAttrGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `StridedLayoutAttr`: wraps attributes accepted by
+/// `mlirAttributeIsAStridedLayout`.
+class MLIR_PYTHON_API_EXPORTED PyStridedLayoutAttribute
+    : public PyConcreteAttribute<PyStridedLayoutAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "StridedLayoutAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirStridedLayoutAttrGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Module-population entry point, defined out of line; presumably binds the
+/// attribute classes declared in this header onto `m`.
+MLIR_PYTHON_API_EXPORTED void
+populateIRAttributes(nanobind::module_ &m);
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index af6c8dbbb7fa8..1e435a1d442d4 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -989,7 +989,8 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteType : public BaseTy {
PyGlobals::get().registerTypeCaster(
DerivedTy::getTypeIdFunction(),
nanobind::cast<nanobind::callable>(nanobind::cpp_function(
- [](PyType pyType) -> DerivedTy { return pyType; })));
+ [](PyType pyType) -> DerivedTy { return pyType; })),
+ /*replace*/ true);
}
DerivedTy::bindDerived(cls);
@@ -1133,7 +1134,8 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteAttribute : public BaseTy {
nanobind::cast<nanobind::callable>(
nanobind::cpp_function([](PyAttribute pyAttribute) -> DerivedTy {
return pyAttribute;
- })));
+ })),
+ /*replace*/ true);
}
DerivedTy::bindDerived(cls);
@@ -1517,6 +1519,8 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteValue : public PyValue {
// and redefine bindDerived.
using ClassTy = nanobind::class_<DerivedTy, PyValue>;
using IsAFunctionTy = bool (*)(MlirValue);
+ using GetTypeIDFunctionTy = MlirTypeID (*)();
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr;
PyConcreteValue() = default;
PyConcreteValue(PyOperationRef operationRef, MlirValue value)
@@ -1559,6 +1563,15 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteValue : public PyValue {
[](DerivedTy &self) -> nanobind::typed<nanobind::object, DerivedTy> {
return self.maybeDownCast();
});
+
+ if (DerivedTy::getTypeIdFunction) {
+ PyGlobals::get().registerValueCaster(
+ DerivedTy::getTypeIdFunction(),
+ nanobind::cast<nanobind::callable>(nanobind::cpp_function(
+ [](PyValue pyValue) -> DerivedTy { return pyValue; })),
+ /*replace*/ true);
+ }
+
DerivedTy::bindDerived(cls);
}
diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h
index 87e0e10764bd8..a0901fefec5ce 100644
--- a/mlir/include/mlir/Bindings/Python/IRTypes.h
+++ b/mlir/include/mlir/Bindings/Python/IRTypes.h
@@ -9,13 +9,284 @@
#ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H
#define MLIR_BINDINGS_PYTHON_IRTYPES_H
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include <optional>
+#include <vector>
+
+#include "mlir-c/BuiltinTypes.h"
+#include "mlir/Bindings/Python/IRCore.h"
+#include "mlir/Bindings/Python/Nanobind.h"
namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+
+/// `MlirType` predicate defined out of line; as its name suggests, presumably
+/// nonzero for integer and floating-point types.
+MLIR_PYTHON_API_EXPORTED int
+mlirTypeIsAIntegerOrFloat(MlirType type);
+
+/// Binding for `IntegerType`: wraps types accepted by `mlirTypeIsAInteger`.
+class MLIR_PYTHON_API_EXPORTED PyIntegerType
+    : public PyConcreteType<PyIntegerType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "IntegerType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirIntegerTypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `IndexType`: wraps types accepted by `mlirTypeIsAIndex`.
+class MLIR_PYTHON_API_EXPORTED PyIndexType
+    : public PyConcreteType<PyIndexType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "IndexType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirIndexTypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `FloatType`, the common base of the concrete float type
+/// bindings below; matched by `mlirTypeIsAFloat`. Unlike those subclasses,
+/// it declares no `getTypeIdFunction`.
+class MLIR_PYTHON_API_EXPORTED PyFloatType
+    : public PyConcreteType<PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "FloatType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `Float4E2M1FNType` (a PyFloatType subclass), matched by
+/// `mlirTypeIsAFloat4E2M1FN`.
+class MLIR_PYTHON_API_EXPORTED PyFloat4E2M1FNType
+    : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "Float4E2M1FNType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat4E2M1FNTypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `Float6E2M3FNType` (a PyFloatType subclass), matched by
+/// `mlirTypeIsAFloat6E2M3FN`.
+class MLIR_PYTHON_API_EXPORTED PyFloat6E2M3FNType
+    : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "Float6E2M3FNType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat6E2M3FNTypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `Float6E3M2FNType` (a PyFloatType subclass), matched by
+/// `mlirTypeIsAFloat6E3M2FN`.
+class MLIR_PYTHON_API_EXPORTED PyFloat6E3M2FNType
+    : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "Float6E3M2FNType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat6E3M2FNTypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `Float8E4M3FNType` (a PyFloatType subclass), matched by
+/// `mlirTypeIsAFloat8E4M3FN`.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3FNType
+    : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "Float8E4M3FNType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E4M3FNTypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `Float8E5M2Type` (a PyFloatType subclass), matched by
+/// `mlirTypeIsAFloat8E5M2`.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E5M2Type
+    : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "Float8E5M2Type";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E5M2TypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `Float8E4M3Type` (a PyFloatType subclass), matched by
+/// `mlirTypeIsAFloat8E4M3`.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3Type
+    : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "Float8E4M3Type";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E4M3TypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `Float8E4M3FNUZType` (a PyFloatType subclass), matched by
+/// `mlirTypeIsAFloat8E4M3FNUZ`.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3FNUZType
+    : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "Float8E4M3FNUZType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E4M3FNUZTypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `Float8E4M3B11FNUZType` (a PyFloatType subclass), matched by
+/// `mlirTypeIsAFloat8E4M3B11FNUZ`.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3B11FNUZType
+    : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E4M3B11FNUZTypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `Float8E5M2FNUZType` (a PyFloatType subclass), matched by
+/// `mlirTypeIsAFloat8E5M2FNUZ`.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E5M2FNUZType
+    : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "Float8E5M2FNUZType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E5M2FNUZTypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `Float8E3M4Type` (a PyFloatType subclass), matched by
+/// `mlirTypeIsAFloat8E3M4`.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E3M4Type
+    : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "Float8E3M4Type";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E3M4TypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `Float8E8M0FNUType` (a PyFloatType subclass), matched by
+/// `mlirTypeIsAFloat8E8M0FNU`.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E8M0FNUType
+    : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "Float8E8M0FNUType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E8M0FNUTypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `BF16Type` (a PyFloatType subclass), matched by
+/// `mlirTypeIsABF16`.
+class MLIR_PYTHON_API_EXPORTED PyBF16Type
+    : public PyConcreteType<PyBF16Type, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "BF16Type";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirBFloat16TypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `F16Type` (a PyFloatType subclass), matched by
+/// `mlirTypeIsAF16`.
+class MLIR_PYTHON_API_EXPORTED PyF16Type
+    : public PyConcreteType<PyF16Type, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "F16Type";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat16TypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for the TF32 float type (a PyFloatType subclass), matched by
+/// `mlirTypeIsATF32`; exposed to Python as `FloatTF32Type`.
+class MLIR_PYTHON_API_EXPORTED PyTF32Type
+    : public PyConcreteType<PyTF32Type, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "FloatTF32Type";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloatTF32TypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `F32Type` (a PyFloatType subclass), matched by
+/// `mlirTypeIsAF32`.
+class MLIR_PYTHON_API_EXPORTED PyF32Type
+    : public PyConcreteType<PyF32Type, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "F32Type";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat32TypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `F64Type` (a PyFloatType subclass), matched by
+/// `mlirTypeIsAF64`.
+class MLIR_PYTHON_API_EXPORTED PyF64Type
+    : public PyConcreteType<PyF64Type, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "F64Type";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat64TypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `NoneType`: wraps types accepted by `mlirTypeIsANone`.
+class MLIR_PYTHON_API_EXPORTED PyNoneType
+    : public PyConcreteType<PyNoneType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "NoneType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirNoneTypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Binding for `ComplexType`: wraps types accepted by `mlirTypeIsAComplex`.
+class MLIR_PYTHON_API_EXPORTED PyComplexType
+    : public PyConcreteType<PyComplexType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "ComplexType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirComplexTypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
/// Shaped Type Interface - ShapedType
+/// Base binding for shaped types: the vector, tensor and memref classes in
+/// this header are declared as PyConcreteType<..., PyShapedType>.
class MLIR_PYTHON_API_EXPORTED PyShapedType
: public PyConcreteType<PyShapedType> {
public:
static const IsAFunctionTy isaFunction;
@@ -27,6 +298,124 @@ class MLIR_PYTHON_API_EXPORTED PyShapedType
private:
void requireHasRank();
};
+
+/// Vector Type subclass - VectorType.
+/// Exposed to Python as `VectorType` with PyShapedType as its base binding;
+/// matched by mlirTypeIsAVector, TypeID from mlirVectorTypeGetTypeID.
+class MLIR_PYTHON_API_EXPORTED PyVectorType
+ : public PyConcreteType<PyVectorType, PyShapedType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirVectorTypeGetTypeID;
+ static constexpr const char *pyClassName = "VectorType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+
+private:
+ /// Private factory helpers, defined out of line. `getChecked` takes a
+ /// location and `get` a context; both also take optional scalability info
+ /// as `scalable` (a Python list) or `scalableDims` (presumably dimension
+ /// indices). NOTE(review): how the two scalability arguments interact is
+ /// not visible here; confirm against the implementation.
+ static PyVectorType
+ getChecked(std::vector<int64_t> shape, PyType &elementType,
+ std::optional<nanobind::list> scalable,
+ std::optional<std::vector<int64_t>> scalableDims,
+ DefaultingPyLocation loc);
+
+ static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
+ std::optional<nanobind::list> scalable,
+ std::optional<std::vector<int64_t>> scalableDims,
+ DefaultingPyMlirContext context);
+};
+
+/// Ranked Tensor Type subclass - RankedTensorType.
+/// Exposed to Python as `RankedTensorType` with PyShapedType as its base
+/// binding; matched by mlirTypeIsARankedTensor, TypeID from
+/// mlirRankedTensorTypeGetTypeID. `bindDerived` is defined out of line.
+class MLIR_PYTHON_API_EXPORTED PyRankedTensorType
+ : public PyConcreteType<PyRankedTensorType, PyShapedType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirRankedTensorTypeGetTypeID;
+ static constexpr const char *pyClassName = "RankedTensorType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Unranked Tensor Type subclass - UnrankedTensorType.
+/// Exposed to Python as `UnrankedTensorType` with PyShapedType as its base
+/// binding; matched by mlirTypeIsAUnrankedTensor, TypeID from
+/// mlirUnrankedTensorTypeGetTypeID. `bindDerived` is defined out of line.
+class MLIR_PYTHON_API_EXPORTED PyUnrankedTensorType
+ : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirUnrankedTensorTypeGetTypeID;
+ static constexpr const char *pyClassName = "UnrankedTensorType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Ranked MemRef Type subclass - MemRefType.
+/// Exposed to Python as `MemRefType` with PyShapedType as its base binding;
+/// matched by mlirTypeIsAMemRef, TypeID from mlirMemRefTypeGetTypeID.
+/// `bindDerived` is defined out of line.
+class MLIR_PYTHON_API_EXPORTED PyMemRefType
+ : public PyConcreteType<PyMemRefType, PyShapedType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirMemRefTypeGetTypeID;
+ static constexpr const char *pyClassName = "MemRefType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Unranked MemRef Type subclass - UnrankedMemRefType.
+/// Exposed to Python as `UnrankedMemRefType` with PyShapedType as its base
+/// binding; matched by mlirTypeIsAUnrankedMemRef, TypeID from
+/// mlirUnrankedMemRefTypeGetTypeID. `bindDerived` is defined out of line.
+class MLIR_PYTHON_API_EXPORTED PyUnrankedMemRefType
+ : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirUnrankedMemRefTypeGetTypeID;
+ static constexpr const char *pyClassName = "UnrankedMemRefType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Tuple Type subclass - TupleType.
+/// Exposed to Python as `TupleType`; matched by mlirTypeIsATuple, TypeID from
+/// mlirTupleTypeGetTypeID. `bindDerived` is defined out of line.
+class MLIR_PYTHON_API_EXPORTED PyTupleType
+ : public PyConcreteType<PyTupleType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirTupleTypeGetTypeID;
+ static constexpr const char *pyClassName = "TupleType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Function type.
+/// Exposed to Python as `FunctionType`; matched by mlirTypeIsAFunction,
+/// TypeID from mlirFunctionTypeGetTypeID. `bindDerived` is defined out of
+/// line.
+class MLIR_PYTHON_API_EXPORTED PyFunctionType
+ : public PyConcreteType<PyFunctionType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFunctionTypeGetTypeID;
+ static constexpr const char *pyClassName = "FunctionType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Opaque Type subclass - OpaqueType.
+/// Exposed to Python as `OpaqueType`; matched by mlirTypeIsAOpaque, TypeID
+/// from mlirOpaqueTypeGetTypeID. `bindDerived` is defined out of line.
+class MLIR_PYTHON_API_EXPORTED PyOpaqueType
+ : public PyConcreteType<PyOpaqueType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirOpaqueTypeGetTypeID;
+ static constexpr const char *pyClassName = "OpaqueType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Exported entry point for the type bindings; takes the nanobind module to
+/// populate. NOTE(review): presumably registers the Py*Type classes declared
+/// in this header on `m`; the definition is not part of this header.
+MLIR_PYTHON_API_EXPORTED void populateIRTypes(nanobind::module_ &m);
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index f0f0ae9ba741e..3cd3ce5c4c0ee 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -14,6 +14,7 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
+#include "mlir/Bindings/Python/IRAttributes.h"
#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
@@ -125,65 +126,29 @@ namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
-struct nb_buffer_info {
- void *ptr = nullptr;
- ssize_t itemsize = 0;
- ssize_t size = 0;
- const char *format = nullptr;
- ssize_t ndim = 0;
- SmallVector<ssize_t, 4> shape;
- SmallVector<ssize_t, 4> strides;
- bool readonly = false;
-
- nb_buffer_info(
- void *ptr, ssize_t itemsize, const char *format, ssize_t ndim,
- SmallVector<ssize_t, 4> shape_in, SmallVector<ssize_t, 4> strides_in,
- bool readonly = false,
- std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view_in =
- std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(nullptr, nullptr))
- : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim),
- shape(std::move(shape_in)), strides(std::move(strides_in)),
- readonly(readonly), owned_view(std::move(owned_view_in)) {
- size = 1;
- for (ssize_t i = 0; i < ndim; ++i) {
- size *= shape[i];
- }
+/// Records the layout of a Python buffer. `size` is computed as the product of
+/// the first `ndim` entries of `shape`, i.e. an element count, not a byte
+/// count. `owned_view_in` is stored as-is, so its deleter runs when this
+/// object is destroyed.
+nb_buffer_info::nb_buffer_info(
+ void *ptr, ssize_t itemsize, const char *format, ssize_t ndim,
+ SmallVector<ssize_t, 4> shape_in, SmallVector<ssize_t, 4> strides_in,
+ bool readonly,
+ std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view_in)
+ : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim),
+ shape(std::move(shape_in)), strides(std::move(strides_in)),
+ readonly(readonly), owned_view(std::move(owned_view_in)) {
+ size = 1;
+ for (ssize_t i = 0; i < ndim; ++i) {
+ size *= shape[i];
}
+}
- explicit nb_buffer_info(Py_buffer *view)
- : nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim,
- {view->shape, view->shape + view->ndim},
- // TODO(phawkins): check for null strides
- {view->strides, view->strides + view->ndim},
- view->readonly != 0,
- std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(
- view, PyBuffer_Release)) {}
-
- nb_buffer_info(const nb_buffer_info &) = delete;
- nb_buffer_info(nb_buffer_info &&) = default;
- nb_buffer_info &operator=(const nb_buffer_info &) = delete;
- nb_buffer_info &operator=(nb_buffer_info &&) = default;
-
-private:
- std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view;
-};
-
-class nb_buffer : public nb::object {
- NB_OBJECT_DEFAULT(nb_buffer, object, "Buffer", PyObject_CheckBuffer);
-
- nb_buffer_info request() const {
- int flags = PyBUF_STRIDES | PyBUF_FORMAT;
- auto *view = new Py_buffer();
- if (PyObject_GetBuffer(ptr(), view, flags) != 0) {
- delete view;
- throw nb::python_error();
- }
- return nb_buffer_info(view);
+/// Requests a strided buffer view (with format) of the wrapped Python object.
+///
+/// Throws nb::python_error (the pending Python exception) if the object
+/// refuses the request. On success the returned nb_buffer_info owns the view.
+nb_buffer_info nb_buffer::request() const {
+  int flags = PyBUF_STRIDES | PyBUF_FORMAT;
+  auto *view = new Py_buffer();
+  if (PyObject_GetBuffer(ptr(), view, flags) != 0) {
+    // The view was never acquired, so only the allocation must be freed.
+    delete view;
+    throw nb::python_error();
}
-};
-
-template <typename T>
-struct nb_format_descriptor {};
+  // PyBuffer_Release releases the view but does not free the Py_buffer struct
+  // allocated above. A PyBuffer_Release-only deleter (as installed by the
+  // nb_buffer_info(Py_buffer *) constructor removed from this file) leaks the
+  // struct on every request, so hand nb_buffer_info a deleter doing both.
+  void (*releaseAndFree)(Py_buffer *) = [](Py_buffer *pyBuffer) {
+    PyBuffer_Release(pyBuffer);
+    delete pyBuffer;
+  };
+  std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned(view,
+                                                          releaseAndFree);
+  return nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim,
+                        {view->shape, view->shape + view->ndim},
+                        {view->strides, view->strides + view->ndim},
+                        view->readonly != 0, std::move(owned));
+}
template <>
struct nb_format_descriptor<bool> {
@@ -230,1052 +195,719 @@ struct nb_format_descriptor<double> {
static const char *format() { return "d"; }
};
-class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
- static constexpr const char *pyClassName = "AffineMapAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirAffineMapAttrGetTypeID;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyAffineMap &affineMap) {
- MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
- return PyAffineMapAttribute(affineMap.getContext(), attr);
- },
- nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
- c.def_prop_ro(
- "value",
- [](PyAffineMapAttribute &self) {
- return PyAffineMap(self.getContext(),
- mlirAffineMapAttrGetValue(self));
- },
- "Returns the value of the AffineMap attribute");
- }
-};
+/// Registers `AffineMapAttr.get(affine_map)`, which wraps the map in an
+/// attribute tied to the map's context, and the read-only `value` property,
+/// which returns the wrapped AffineMap in the attribute's context.
+void PyAffineMapAttribute::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](PyAffineMap &affineMap) {
+ MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
+ return PyAffineMapAttribute(affineMap.getContext(), attr);
+ },
+ nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
+ c.def_prop_ro(
+ "value",
+ [](PyAffineMapAttribute &self) {
+ return PyAffineMap(self.getContext(), mlirAffineMapAttrGetValue(self));
+ },
+ "Returns the value of the AffineMap attribute");
+}
-class PyIntegerSetAttribute
- : public PyConcreteAttribute<PyIntegerSetAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet;
- static constexpr const char *pyClassName = "IntegerSetAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirIntegerSetAttrGetTypeID;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyIntegerSet &integerSet) {
- MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get());
- return PyIntegerSetAttribute(integerSet.getContext(), attr);
- },
- nb::arg("integer_set"), "Gets an attribute wrapping an IntegerSet.");
- }
-};
+/// Registers `IntegerSetAttr.get(integer_set)`, which wraps the set in an
+/// attribute tied to the set's context.
+void PyIntegerSetAttribute::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](PyIntegerSet &integerSet) {
+ MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get());
+ return PyIntegerSetAttribute(integerSet.getContext(), attr);
+ },
+ nb::arg("integer_set"), "Gets an attribute wrapping an IntegerSet.");
+}
-template <typename T>
-static T pyTryCast(nb::handle object) {
- try {
- return nb::cast<T>(object);
- } catch (nb::cast_error &err) {
- std::string msg = std::string("Invalid attribute when attempting to "
- "create an ArrayAttribute (") +
- err.what() + ")";
- throw std::runtime_error(msg.c_str());
- } catch (std::runtime_error &err) {
- std::string msg = std::string("Invalid attribute (None?) when attempting "
- "to create an ArrayAttribute (") +
- err.what() + ")";
- throw std::runtime_error(msg.c_str());
- }
+/// Returns the element at the cursor, wrapped in a PyAttribute of the
+/// iterated ArrayAttr's context and passed through maybeDownCast, then
+/// advances the cursor. Raises StopIteration once all elements were produced.
+nb::typed<nb::object, PyAttribute>
+PyArrayAttribute::PyArrayAttributeIterator::dunderNext() {
+  // TODO: Throw is an inefficient way to stop iteration.
+  if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
+    throw nb::stop_iteration();
+  MlirAttribute element = mlirArrayAttrGetElement(attr.get(), nextIndex++);
+  return PyAttribute(attr.getContext(), element).maybeDownCast();
}
-/// A python-wrapped dense array attribute with an element type and a derived
-/// implementation class.
-template <typename EltTy, typename DerivedT>
-class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
-public:
- using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;
-
- /// Iterator over the integer elements of a dense array.
- class PyDenseArrayIterator {
- public:
- PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
-
- /// Return a copy of the iterator.
- PyDenseArrayIterator dunderIter() { return *this; }
-
- /// Return the next element.
- EltTy dunderNext() {
- // Throw if the index has reached the end.
- if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
- throw nb::stop_iteration();
- return DerivedT::getElement(attr.get(), nextIndex++);
- }
+/// Registers the `ArrayAttributeIterator` Python class on `m`, binding
+/// `__iter__` to dunderIter and `__next__` to dunderNext.
+void PyArrayAttribute::PyArrayAttributeIterator::bind(nb::module_ &m) {
+ nb::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
+ .def("__iter__", &PyArrayAttributeIterator::dunderIter)
+ .def("__next__", &PyArrayAttributeIterator::dunderNext);
+}
- /// Bind the iterator class.
- static void bind(nb::module_ &m) {
- nb::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName)
- .def("__iter__", &PyDenseArrayIterator::dunderIter)
- .def("__next__", &PyDenseArrayIterator::dunderNext);
- }
+/// Returns the raw element at index `i` without any bounds check; callers are
+/// expected to pass an index in [0, mlirArrayAttrGetNumElements(*this)).
+MlirAttribute PyArrayAttribute::getItem(intptr_t i) const {
+ return mlirArrayAttrGetElement(*this, i);
+}
- private:
- /// The referenced dense array attribute.
- PyAttribute attr;
- /// The next index to read.
- int nextIndex = 0;
- };
+/// Registers ArrayAttr's Python API: the `get` constructor, the sequence
+/// protocol (`__getitem__`, `__len__`, `__iter__`) and `__add__`, which
+/// returns a new ArrayAttr with the Python list's attributes appended.
+void PyArrayAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](const nb::list &attributes, DefaultingPyMlirContext context) {
+        SmallVector<MlirAttribute> mlirAttributes;
+        mlirAttributes.reserve(nb::len(attributes));
+        for (auto attribute : attributes) {
+          mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
+        }
+        MlirAttribute attr = mlirArrayAttrGet(
+            context->get(), mlirAttributes.size(), mlirAttributes.data());
+        return PyArrayAttribute(context->getRef(), attr);
+      },
+      nb::arg("attributes"), nb::arg("context") = nb::none(),
+      "Gets a uniqued Array attribute");
+  c.def("__getitem__",
+        [](PyArrayAttribute &arr,
+           intptr_t i) -> nb::typed<nb::object, PyAttribute> {
+          // Accept Python-style negative indices, then check both bounds so
+          // an out-of-range index never reaches mlirArrayAttrGetElement.
+          intptr_t numElements = mlirArrayAttrGetNumElements(arr);
+          if (i < 0)
+            i += numElements;
+          if (i < 0 || i >= numElements)
+            throw nb::index_error("ArrayAttribute index out of range");
+          return PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast();
+        })
+      .def("__len__",
+           [](const PyArrayAttribute &arr) {
+             return mlirArrayAttrGetNumElements(arr);
+           })
+      .def("__iter__", [](const PyArrayAttribute &arr) {
+        return PyArrayAttributeIterator(arr);
+      });
+  c.def("__add__", [](PyArrayAttribute arr, const nb::list &extras) {
+    std::vector<MlirAttribute> attributes;
+    intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
+    attributes.reserve(numOldElements + nb::len(extras));
+    for (intptr_t i = 0; i < numOldElements; ++i)
+      attributes.push_back(arr.getItem(i));
+    for (nb::handle attr : extras)
+      attributes.push_back(pyTryCast<PyAttribute>(attr));
+    MlirAttribute arrayAttr = mlirArrayAttrGet(
+        arr.getContext()->get(), attributes.size(), attributes.data());
+    return PyArrayAttribute(arr.getContext(), arrayAttr);
+  });
+}
+/// Registers FloatAttr's Python API: `get` (built with
+/// mlirFloatAttrDoubleGetChecked at `loc`, raising MLIRError on failure),
+/// `get_unchecked` (built in `context`, also raising MLIRError on a null
+/// result), the `get_f32`/`get_f64` shortcuts, `value` and `__float__`.
+/// NOTE(review): `get_unchecked` creates the attribute in `context` but wraps
+/// it with `type.getContext()`; confirm callers always pass matching ones.
+void PyFloatAttribute::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](PyType &type, double value, DefaultingPyLocation loc) {
+ PyMlirContext::ErrorCapture errors(loc->getContext());
+ MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
+ if (mlirAttributeIsNull(attr))
+ throw MLIRError("Invalid attribute", errors.take());
+ return PyFloatAttribute(type.getContext(), attr);
+ },
+ nb::arg("type"), nb::arg("value"), nb::arg("loc") = nb::none(),
+ "Gets an uniqued float point attribute associated to a type");
+ c.def_static(
+ "get_unchecked",
+ [](PyType &type, double value, DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirAttribute attr =
+ mlirFloatAttrDoubleGet(context.get()->get(), type, value);
+ if (mlirAttributeIsNull(attr))
+ throw MLIRError("Invalid attribute", errors.take());
+ return PyFloatAttribute(type.getContext(), attr);
+ },
+ nb::arg("type"), nb::arg("value"), nb::arg("context") = nb::none(),
+ "Gets an uniqued float point attribute associated to a type");
+ c.def_static(
+ "get_f32",
+ [](double value, DefaultingPyMlirContext context) {
+ MlirAttribute attr = mlirFloatAttrDoubleGet(
+ context->get(), mlirF32TypeGet(context->get()), value);
+ return PyFloatAttribute(context->getRef(), attr);
+ },
+ nb::arg("value"), nb::arg("context") = nb::none(),
+ "Gets an uniqued float point attribute associated to a f32 type");
+ c.def_static(
+ "get_f64",
+ [](double value, DefaultingPyMlirContext context) {
+ MlirAttribute attr = mlirFloatAttrDoubleGet(
+ context->get(), mlirF64TypeGet(context->get()), value);
+ return PyFloatAttribute(context->getRef(), attr);
+ },
+ nb::arg("value"), nb::arg("context") = nb::none(),
+ "Gets an uniqued float point attribute associated to a f64 type");
+ c.def_prop_ro("value", mlirFloatAttrGetValueDouble,
+ "Returns the value of the float attribute");
+ c.def("__float__", mlirFloatAttrGetValueDouble,
+ "Converts the value of the float attribute to a Python float");
+}
- /// Get the element at the given index.
- EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
-
- /// Bind the attribute class.
- static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
- // Bind the constructor.
- if constexpr (std::is_same_v<EltTy, bool>) {
- c.def_static(
- "get",
- [](const nb::sequence &py_values, DefaultingPyMlirContext ctx) {
- std::vector<bool> values;
- for (nb::handle py_value : py_values) {
- int is_true = PyObject_IsTrue(py_value.ptr());
- if (is_true < 0) {
- throw nb::python_error();
- }
- values.push_back(is_true);
- }
- return getAttribute(values, ctx->getRef());
- },
- nb::arg("values"), nb::arg("context") = nb::none(),
- "Gets a uniqued dense array attribute");
- } else {
- c.def_static(
- "get",
- [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
- return getAttribute(values, ctx->getRef());
- },
- nb::arg("values"), nb::arg("context") = nb::none(),
- "Gets a uniqued dense array attribute");
- }
- // Bind the array methods.
- c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
- if (i >= mlirDenseArrayGetNumElements(arr))
- throw nb::index_error("DenseArray index out of range");
- return arr.getItem(i);
- });
- c.def("__len__", [](const DerivedT &arr) {
- return mlirDenseArrayGetNumElements(arr);
- });
- c.def("__iter__",
- [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
- c.def("__add__", [](DerivedT &arr, const nb::list &extras) {
- std::vector<EltTy> values;
- intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
- values.reserve(numOldElements + nb::len(extras));
- for (intptr_t i = 0; i < numOldElements; ++i)
- values.push_back(arr.getItem(i));
- for (nb::handle attr : extras)
- values.push_back(pyTryCast<EltTy>(attr));
- return getAttribute(values, arr.getContext());
- });
- }
+/// Registers IntegerAttr's Python API: `get(type, value)`, the `value`
+/// property and `__int__` (both converted through toPyInt), and the
+/// `static_typeid` class property.
+void PyIntegerAttribute::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](PyType &type, int64_t value) {
+ MlirAttribute attr = mlirIntegerAttrGet(type, value);
+ return PyIntegerAttribute(type.getContext(), attr);
+ },
+ nb::arg("type"), nb::arg("value"),
+ "Gets an uniqued integer attribute associated to a type");
+ c.def_prop_ro("value", toPyInt, "Returns the value of the integer attribute");
+ c.def("__int__", toPyInt,
+ "Converts the value of the integer attribute to a Python int");
+ c.def_prop_ro_static(
+ "static_typeid",
+ [](nb::object & /*class*/) {
+ return PyTypeID(mlirIntegerAttrGetTypeID());
+ },
+ nb::sig("def static_typeid(/) -> TypeID"));
+}
-private:
- static DerivedT getAttribute(const std::vector<EltTy> &values,
- PyMlirContextRef ctx) {
- if constexpr (std::is_same_v<EltTy, bool>) {
- std::vector<int> intValues(values.begin(), values.end());
- MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(),
- intValues.data());
- return DerivedT(ctx, attr);
- } else {
- MlirAttribute attr =
- DerivedT::getAttribute(ctx->get(), values.size(), values.data());
- return DerivedT(ctx, attr);
- }
- }
-};
+/// Reads an IntegerAttr's value according to its type's signedness: index and
+/// signless types use mlirIntegerAttrGetValueInt, signed types the SInt
+/// accessor, and every remaining (unsigned) type the UInt accessor.
+/// NOTE(review): the result is returned as int64_t. If the UInt accessor
+/// yields uint64_t (not visible here), unsigned values >= 2^63 come back
+/// negative in Python; confirm whether that is intended.
+int64_t PyIntegerAttribute::toPyInt(PyIntegerAttribute &self) {
+ MlirType type = mlirAttributeGetType(self);
+ if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
+ return mlirIntegerAttrGetValueInt(self);
+ if (mlirIntegerTypeIsSigned(type))
+ return mlirIntegerAttrGetValueSInt(self);
+ return mlirIntegerAttrGetValueUInt(self);
+}
-/// Instantiate the python dense array classes.
-struct PyDenseBoolArrayAttribute
- : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
- static constexpr auto getAttribute = mlirDenseBoolArrayGet;
- static constexpr auto getElement = mlirDenseBoolArrayGetElement;
- static constexpr const char *pyClassName = "DenseBoolArrayAttr";
- static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseI8ArrayAttribute
- : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
- static constexpr auto getAttribute = mlirDenseI8ArrayGet;
- static constexpr auto getElement = mlirDenseI8ArrayGetElement;
- static constexpr const char *pyClassName = "DenseI8ArrayAttr";
- static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseI16ArrayAttribute
- : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
- static constexpr auto getAttribute = mlirDenseI16ArrayGet;
- static constexpr auto getElement = mlirDenseI16ArrayGetElement;
- static constexpr const char *pyClassName = "DenseI16ArrayAttr";
- static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseI32ArrayAttribute
- : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
- static constexpr auto getAttribute = mlirDenseI32ArrayGet;
- static constexpr auto getElement = mlirDenseI32ArrayGetElement;
- static constexpr const char *pyClassName = "DenseI32ArrayAttr";
- static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseI64ArrayAttribute
- : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
- static constexpr auto getAttribute = mlirDenseI64ArrayGet;
- static constexpr auto getElement = mlirDenseI64ArrayGetElement;
- static constexpr const char *pyClassName = "DenseI64ArrayAttr";
- static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseF32ArrayAttribute
- : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
- static constexpr auto getAttribute = mlirDenseF32ArrayGet;
- static constexpr auto getElement = mlirDenseF32ArrayGetElement;
- static constexpr const char *pyClassName = "DenseF32ArrayAttr";
- static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseF64ArrayAttribute
- : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
- static constexpr auto getAttribute = mlirDenseF64ArrayGet;
- static constexpr auto getElement = mlirDenseF64ArrayGetElement;
- static constexpr const char *pyClassName = "DenseF64ArrayAttr";
- static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
+/// Registers BoolAttr's Python API: `get(value, context)`, the `value`
+/// property and `__bool__`, both read via mlirBoolAttrGetValue.
+void PyBoolAttribute::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](bool value, DefaultingPyMlirContext context) {
+ MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
+ return PyBoolAttribute(context->getRef(), attr);
+ },
+ nb::arg("value"), nb::arg("context") = nb::none(),
+ "Gets an uniqued bool attribute");
+ c.def_prop_ro("value", mlirBoolAttrGetValue,
+ "Returns the value of the bool attribute");
+ c.def("__bool__", mlirBoolAttrGetValue,
+ "Converts the value of the bool attribute to a Python bool");
+}
-class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
- static constexpr const char *pyClassName = "ArrayAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirArrayAttrGetTypeID;
-
- class PyArrayAttributeIterator {
- public:
- PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
-
- PyArrayAttributeIterator &dunderIter() { return *this; }
-
- nb::typed<nb::object, PyAttribute> dunderNext() {
- // TODO: Throw is an inefficient way to stop iteration.
- if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
- throw nb::stop_iteration();
- return PyAttribute(this->attr.getContext(),
- mlirArrayAttrGetElement(attr.get(), nextIndex++))
- .maybeDownCast();
- }
+/// Builds a SymbolRefAttr from a non-empty list of names: the first entry is
+/// the root reference and each following entry becomes a nested
+/// FlatSymbolRefAttr. Throws std::runtime_error if `symbols` is empty.
+PySymbolRefAttribute
+PySymbolRefAttribute::fromList(const std::vector<std::string> &symbols,
+ PyMlirContext &context) {
+ if (symbols.empty())
+ throw std::runtime_error("SymbolRefAttr must be composed of at least "
+ "one symbol.");
+ MlirStringRef rootSymbol = toMlirStringRef(symbols[0]);
+ SmallVector<MlirAttribute, 3> referenceAttrs;
+ for (size_t i = 1; i < symbols.size(); ++i) {
+ referenceAttrs.push_back(
+ mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
+ }
+ return PySymbolRefAttribute(context.getRef(),
+ mlirSymbolRefAttrGet(context.get(), rootSymbol,
+ referenceAttrs.size(),
+ referenceAttrs.data()));
+}
- static void bind(nb::module_ &m) {
- nb::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
- .def("__iter__", &PyArrayAttributeIterator::dunderIter)
- .def("__next__", &PyArrayAttributeIterator::dunderNext);
- }
+/// Registers SymbolRefAttr's Python API: `get(symbols, context)` (delegating
+/// to fromList) and a `value` property that returns the root name followed
+/// by the root name of each nested reference, as a list[str].
+void PySymbolRefAttribute::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](const std::vector<std::string> &symbols,
+ DefaultingPyMlirContext context) {
+ return PySymbolRefAttribute::fromList(symbols, context.resolve());
+ },
+ nb::arg("symbols"), nb::arg("context") = nb::none(),
+ "Gets a uniqued SymbolRef attribute from a list of symbol names");
+ c.def_prop_ro(
+ "value",
+ [](PySymbolRefAttribute &self) {
+ std::vector<std::string> symbols = {
+ unwrap(mlirSymbolRefAttrGetRootReference(self)).str()};
+ for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self); ++i)
+ symbols.push_back(
+ unwrap(mlirSymbolRefAttrGetRootReference(
+ mlirSymbolRefAttrGetNestedReference(self, i)))
+ .str());
+ return symbols;
+ },
+ "Returns the value of the SymbolRef attribute as a list[str]");
+}
- private:
- PyAttribute attr;
- int nextIndex = 0;
- };
+/// Registers FlatSymbolRefAttr's Python API: `get(value, context)` and a
+/// `value` property returning the referenced symbol name as a str.
+void PyFlatSymbolRefAttribute::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](const std::string &value, DefaultingPyMlirContext context) {
+ MlirAttribute attr =
+ mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
+ return PyFlatSymbolRefAttribute(context->getRef(), attr);
+ },
+ nb::arg("value"), nb::arg("context") = nb::none(),
+ "Gets a uniqued FlatSymbolRef attribute");
+ c.def_prop_ro(
+ "value",
+ [](PyFlatSymbolRefAttribute &self) {
+ MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
+ return nb::str(stringRef.data, stringRef.length);
+ },
+ "Returns the value of the FlatSymbolRef attribute as a string");
+}
- MlirAttribute getItem(intptr_t i) {
- return mlirArrayAttrGetElement(*this, i);
- }
+/// Registers OpaqueAttr's Python API: `get(dialect_namespace, buffer, type,
+/// context)` plus the read-only `dialect_namespace` (str) and `data` (bytes)
+/// properties.
+void PyOpaqueAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](const std::string &dialectNamespace, const nb_buffer &buffer,
+         PyType &type, DefaultingPyMlirContext context) {
+        const nb_buffer_info bufferInfo = buffer.request();
+        // mlirOpaqueAttrGet is handed a (pointer, length) pair, so the buffer
+        // must be one C-contiguous range; strided views (including negative
+        // strides) would otherwise be read as if they were contiguous.
+        if (bufferInfo.size > 0) {
+          ssize_t expectedStride = bufferInfo.itemsize;
+          for (ssize_t dim = bufferInfo.ndim - 1; dim >= 0; --dim) {
+            if (bufferInfo.shape[dim] > 1 &&
+                bufferInfo.strides[dim] != expectedStride)
+              throw nb::value_error("OpaqueAttr buffer must be C-contiguous");
+            expectedStride *= bufferInfo.shape[dim];
+          }
+        }
+        // `size` counts elements, not bytes; scale by the item size so
+        // buffers with multi-byte items are not truncated.
+        intptr_t bufferSize = bufferInfo.size * bufferInfo.itemsize;
+        MlirAttribute attr = mlirOpaqueAttrGet(
+            context->get(), toMlirStringRef(dialectNamespace), bufferSize,
+            static_cast<char *>(bufferInfo.ptr), type);
+        return PyOpaqueAttribute(context->getRef(), attr);
+      },
+      nb::arg("dialect_namespace"), nb::arg("buffer"), nb::arg("type"),
+      nb::arg("context") = nb::none(),
+      // clang-format off
+      nb::sig("def get(dialect_namespace: str, buffer: typing_extensions.Buffer, type: Type, context: Context | None = None) -> OpaqueAttr"),
+      // clang-format on
+      "Gets an Opaque attribute.");
+  c.def_prop_ro(
+      "dialect_namespace",
+      [](PyOpaqueAttribute &self) {
+        MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self);
+        return nb::str(stringRef.data, stringRef.length);
+      },
+      "Returns the dialect namespace for the Opaque attribute as a string");
+  c.def_prop_ro(
+      "data",
+      [](PyOpaqueAttribute &self) {
+        MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
+        return nb::bytes(stringRef.data, stringRef.length);
+      },
+      "Returns the data for the Opaqued attributes as `bytes`");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](const nb::list &attributes, DefaultingPyMlirContext context) {
- SmallVector<MlirAttribute> mlirAttributes;
- mlirAttributes.reserve(nb::len(attributes));
- for (auto attribute : attributes) {
- mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
- }
- MlirAttribute attr = mlirArrayAttrGet(
- context->get(), mlirAttributes.size(), mlirAttributes.data());
- return PyArrayAttribute(context->getRef(), attr);
- },
- nb::arg("attributes"), nb::arg("context") = nb::none(),
- "Gets a uniqued Array attribute");
- c.def(
- "__getitem__",
- [](PyArrayAttribute &arr,
- intptr_t i) -> nb::typed<nb::object, PyAttribute> {
- if (i >= mlirArrayAttrGetNumElements(arr))
- throw nb::index_error("ArrayAttribute index out of range");
- return PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast();
- })
- .def("__len__",
- [](const PyArrayAttribute &arr) {
- return mlirArrayAttrGetNumElements(arr);
- })
- .def("__iter__", [](const PyArrayAttribute &arr) {
- return PyArrayAttributeIterator(arr);
- });
- c.def("__add__", [](PyArrayAttribute arr, const nb::list &extras) {
- std::vector<MlirAttribute> attributes;
- intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
- attributes.reserve(numOldElements + nb::len(extras));
- for (intptr_t i = 0; i < numOldElements; ++i)
- attributes.push_back(arr.getItem(i));
- for (nb::handle attr : extras)
- attributes.push_back(pyTryCast<PyAttribute>(attr));
- MlirAttribute arrayAttr = mlirArrayAttrGet(
- arr.getContext()->get(), attributes.size(), attributes.data());
- return PyArrayAttribute(arr.getContext(), arrayAttr);
- });
+PyDenseElementsAttribute
+PyDenseElementsAttribute::getFromList(const nb::list &attributes,
+ std::optional<PyType> explicitType,
+ DefaultingPyMlirContext contextWrapper) {
+ const size_t numAttributes = nb::len(attributes);
+ if (numAttributes == 0)
+ throw nb::value_error("Attributes list must be non-empty.");
+
+ MlirType shapedType;
+ if (explicitType) {
+ if ((!mlirTypeIsAShaped(*explicitType) ||
+ !mlirShapedTypeHasStaticShape(*explicitType))) {
+
+ std::string message;
+ llvm::raw_string_ostream os(message);
+ os << "Expected a static ShapedType for the shaped_type parameter: "
+ << nb::cast<std::string>(nb::repr(nb::cast(*explicitType)));
+ throw nb::value_error(message.c_str());
+ }
+ shapedType = *explicitType;
+ } else {
+ SmallVector<int64_t> shape = {static_cast<int64_t>(numAttributes)};
+ shapedType = mlirRankedTensorTypeGet(
+ shape.size(), shape.data(),
+ mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
+ mlirAttributeGetNull());
}
-};
-/// Float Point Attribute subclass - FloatAttr.
-class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
- static constexpr const char *pyClassName = "FloatAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloatAttrGetTypeID;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyType &type, double value, DefaultingPyLocation loc) {
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
- if (mlirAttributeIsNull(attr))
- throw MLIRError("Invalid attribute", errors.take());
- return PyFloatAttribute(type.getContext(), attr);
- },
- nb::arg("type"), nb::arg("value"), nb::arg("loc") = nb::none(),
- "Gets an uniqued float point attribute associated to a type");
- c.def_static(
- "get_unchecked",
- [](PyType &type, double value, DefaultingPyMlirContext context) {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirAttribute attr =
- mlirFloatAttrDoubleGet(context.get()->get(), type, value);
- if (mlirAttributeIsNull(attr))
- throw MLIRError("Invalid attribute", errors.take());
- return PyFloatAttribute(type.getContext(), attr);
- },
- nb::arg("type"), nb::arg("value"), nb::arg("context") = nb::none(),
- "Gets an uniqued float point attribute associated to a type");
- c.def_static(
- "get_f32",
- [](double value, DefaultingPyMlirContext context) {
- MlirAttribute attr = mlirFloatAttrDoubleGet(
- context->get(), mlirF32TypeGet(context->get()), value);
- return PyFloatAttribute(context->getRef(), attr);
- },
- nb::arg("value"), nb::arg("context") = nb::none(),
- "Gets an uniqued float point attribute associated to a f32 type");
- c.def_static(
- "get_f64",
- [](double value, DefaultingPyMlirContext context) {
- MlirAttribute attr = mlirFloatAttrDoubleGet(
- context->get(), mlirF64TypeGet(context->get()), value);
- return PyFloatAttribute(context->getRef(), attr);
- },
- nb::arg("value"), nb::arg("context") = nb::none(),
- "Gets an uniqued float point attribute associated to a f64 type");
- c.def_prop_ro("value", mlirFloatAttrGetValueDouble,
- "Returns the value of the float attribute");
- c.def("__float__", mlirFloatAttrGetValueDouble,
- "Converts the value of the float attribute to a Python float");
+ SmallVector<MlirAttribute> mlirAttributes;
+ mlirAttributes.reserve(numAttributes);
+ for (const nb::handle &attribute : attributes) {
+ MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
+ MlirType attrType = mlirAttributeGetType(mlirAttribute);
+ mlirAttributes.push_back(mlirAttribute);
+
+ if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
+ std::string message;
+ llvm::raw_string_ostream os(message);
+ os << "All attributes must be of the same type and match "
+ << "the type parameter: expected="
+ << nb::cast<std::string>(nb::repr(nb::cast(shapedType)))
+ << ", but got=" << nb::cast<std::string>(nb::repr(nb::cast(attrType)));
+ throw nb::value_error(message.c_str());
+ }
}
-};
-/// Integer Attribute subclass - IntegerAttr.
-class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
- static constexpr const char *pyClassName = "IntegerAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyType &type, int64_t value) {
- MlirAttribute attr = mlirIntegerAttrGet(type, value);
- return PyIntegerAttribute(type.getContext(), attr);
- },
- nb::arg("type"), nb::arg("value"),
- "Gets an uniqued integer attribute associated to a type");
- c.def_prop_ro("value", toPyInt,
- "Returns the value of the integer attribute");
- c.def("__int__", toPyInt,
- "Converts the value of the integer attribute to a Python int");
- c.def_prop_ro_static(
- "static_typeid",
- [](nb::object & /*class*/) {
- return PyTypeID(mlirIntegerAttrGetTypeID());
- },
- nanobind::sig("def static_typeid(/) -> TypeID"));
- }
+ MlirAttribute elements = mlirDenseElementsAttrGet(
+ shapedType, mlirAttributes.size(), mlirAttributes.data());
-private:
- static int64_t toPyInt(PyIntegerAttribute &self) {
- MlirType type = mlirAttributeGetType(self);
- if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
- return mlirIntegerAttrGetValueInt(self);
- if (mlirIntegerTypeIsSigned(type))
- return mlirIntegerAttrGetValueSInt(self);
- return mlirIntegerAttrGetValueUInt(self);
- }
-};
+ return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
+}
-/// Bool Attribute subclass - BoolAttr.
-class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
- static constexpr const char *pyClassName = "BoolAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](bool value, DefaultingPyMlirContext context) {
- MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
- return PyBoolAttribute(context->getRef(), attr);
- },
- nb::arg("value"), nb::arg("context") = nb::none(),
- "Gets an uniqued bool attribute");
- c.def_prop_ro("value", mlirBoolAttrGetValue,
- "Returns the value of the bool attribute");
- c.def("__bool__", mlirBoolAttrGetValue,
- "Converts the value of the bool attribute to a Python bool");
+PyDenseElementsAttribute PyDenseElementsAttribute::getFromBuffer(
+ const nb_buffer &array, bool signless,
+ const std::optional<PyType> &explicitType,
+ std::optional<std::vector<int64_t>> explicitShape,
+ DefaultingPyMlirContext contextWrapper) {
+ // Request a contiguous view. In exotic cases, this will cause a copy.
+ int flags = PyBUF_ND;
+ if (!explicitType) {
+ flags |= PyBUF_FORMAT;
}
-};
-
-class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
- static constexpr const char *pyClassName = "SymbolRefAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- static PySymbolRefAttribute fromList(const std::vector<std::string> &symbols,
- PyMlirContext &context) {
- if (symbols.empty())
- throw std::runtime_error("SymbolRefAttr must be composed of at least "
- "one symbol.");
- MlirStringRef rootSymbol = toMlirStringRef(symbols[0]);
- SmallVector<MlirAttribute, 3> referenceAttrs;
- for (size_t i = 1; i < symbols.size(); ++i) {
- referenceAttrs.push_back(
- mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
- }
- return PySymbolRefAttribute(context.getRef(),
- mlirSymbolRefAttrGet(context.get(), rootSymbol,
- referenceAttrs.size(),
- referenceAttrs.data()));
+ Py_buffer view;
+ if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
+ throw nb::python_error();
}
+ auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](const std::vector<std::string> &symbols,
- DefaultingPyMlirContext context) {
- return PySymbolRefAttribute::fromList(symbols, context.resolve());
- },
- nb::arg("symbols"), nb::arg("context") = nb::none(),
- "Gets a uniqued SymbolRef attribute from a list of symbol names");
- c.def_prop_ro(
- "value",
- [](PySymbolRefAttribute &self) {
- std::vector<std::string> symbols = {
- unwrap(mlirSymbolRefAttrGetRootReference(self)).str()};
- for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self);
- ++i)
- symbols.push_back(
- unwrap(mlirSymbolRefAttrGetRootReference(
- mlirSymbolRefAttrGetNestedReference(self, i)))
- .str());
- return symbols;
- },
- "Returns the value of the SymbolRef attribute as a list[str]");
+ MlirContext context = contextWrapper->get();
+ MlirAttribute attr = getAttributeFromBuffer(
+ view, signless, explicitType, std::move(explicitShape), context);
+ if (mlirAttributeIsNull(attr)) {
+ throw std::invalid_argument(
+ "DenseElementsAttr could not be constructed from the given buffer. "
+ "This may mean that the Python buffer layout does not match that "
+ "MLIR expected layout and is a bug.");
}
-};
+ return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
+}
-class PyFlatSymbolRefAttribute
- : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
- static constexpr const char *pyClassName = "FlatSymbolRefAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](const std::string &value, DefaultingPyMlirContext context) {
- MlirAttribute attr =
- mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
- return PyFlatSymbolRefAttribute(context->getRef(), attr);
- },
- nb::arg("value"), nb::arg("context") = nb::none(),
- "Gets a uniqued FlatSymbolRef attribute");
- c.def_prop_ro(
- "value",
- [](PyFlatSymbolRefAttribute &self) {
- MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
- return nb::str(stringRef.data, stringRef.length);
- },
- "Returns the value of the FlatSymbolRef attribute as a string");
+PyDenseElementsAttribute
+PyDenseElementsAttribute::getSplat(const PyType &shapedType,
+ PyAttribute &elementAttr) {
+ auto contextWrapper =
+ PyMlirContext::forContext(mlirTypeGetContext(shapedType));
+ if (!mlirAttributeIsAInteger(elementAttr) &&
+ !mlirAttributeIsAFloat(elementAttr)) {
+ std::string message = "Illegal element type for DenseElementsAttr: ";
+ message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
+ throw nb::value_error(message.c_str());
}
-};
-
-class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
- static constexpr const char *pyClassName = "OpaqueAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirOpaqueAttrGetTypeID;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](const std::string &dialectNamespace, const nb_buffer &buffer,
- PyType &type, DefaultingPyMlirContext context) {
- const nb_buffer_info bufferInfo = buffer.request();
- intptr_t bufferSize = bufferInfo.size;
- MlirAttribute attr = mlirOpaqueAttrGet(
- context->get(), toMlirStringRef(dialectNamespace), bufferSize,
- static_cast<char *>(bufferInfo.ptr), type);
- return PyOpaqueAttribute(context->getRef(), attr);
- },
- nb::arg("dialect_namespace"), nb::arg("buffer"), nb::arg("type"),
- nb::arg("context") = nb::none(),
- // clang-format off
- nb::sig("def get(dialect_namespace: str, buffer: typing_extensions.Buffer, type: Type, context: Context | None = None) -> OpaqueAttr"),
- // clang-format on
- "Gets an Opaque attribute.");
- c.def_prop_ro(
- "dialect_namespace",
- [](PyOpaqueAttribute &self) {
- MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self);
- return nb::str(stringRef.data, stringRef.length);
- },
- "Returns the dialect namespace for the Opaque attribute as a string");
- c.def_prop_ro(
- "data",
- [](PyOpaqueAttribute &self) {
- MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
- return nb::bytes(stringRef.data, stringRef.length);
- },
- "Returns the data for the Opaqued attributes as `bytes`");
+ if (!mlirTypeIsAShaped(shapedType) ||
+ !mlirShapedTypeHasStaticShape(shapedType)) {
+ std::string message =
+ "Expected a static ShapedType for the shaped_type parameter: ";
+ message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
+ throw nb::value_error(message.c_str());
+ }
+ MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
+ MlirType attrType = mlirAttributeGetType(elementAttr);
+ if (!mlirTypeEqual(shapedElementType, attrType)) {
+ std::string message =
+ "Shaped element type and attribute type must be equal: shaped=";
+ message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
+ message.append(", element=");
+ message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
+ throw nb::value_error(message.c_str());
}
-};
-// TODO: Support construction of string elements.
-class PyDenseElementsAttribute
- : public PyConcreteAttribute<PyDenseElementsAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
- static constexpr const char *pyClassName = "DenseElementsAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- static PyDenseElementsAttribute
- getFromList(const nb::list &attributes, std::optional<PyType> explicitType,
- DefaultingPyMlirContext contextWrapper) {
- const size_t numAttributes = nb::len(attributes);
- if (numAttributes == 0)
- throw nb::value_error("Attributes list must be non-empty.");
-
- MlirType shapedType;
- if (explicitType) {
- if ((!mlirTypeIsAShaped(*explicitType) ||
- !mlirShapedTypeHasStaticShape(*explicitType))) {
-
- std::string message;
- llvm::raw_string_ostream os(message);
- os << "Expected a static ShapedType for the shaped_type parameter: "
- << nb::cast<std::string>(nb::repr(nb::cast(*explicitType)));
- throw nb::value_error(message.c_str());
- }
- shapedType = *explicitType;
- } else {
- SmallVector<int64_t> shape = {static_cast<int64_t>(numAttributes)};
- shapedType = mlirRankedTensorTypeGet(
- shape.size(), shape.data(),
- mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
- mlirAttributeGetNull());
- }
+ MlirAttribute elements =
+ mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
+ return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
+}
- SmallVector<MlirAttribute> mlirAttributes;
- mlirAttributes.reserve(numAttributes);
- for (const nb::handle &attribute : attributes) {
- MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
- MlirType attrType = mlirAttributeGetType(mlirAttribute);
- mlirAttributes.push_back(mlirAttribute);
-
- if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
- std::string message;
- llvm::raw_string_ostream os(message);
- os << "All attributes must be of the same type and match "
- << "the type parameter: expected="
- << nb::cast<std::string>(nb::repr(nb::cast(shapedType)))
- << ", but got="
- << nb::cast<std::string>(nb::repr(nb::cast(attrType)));
- throw nb::value_error(message.c_str());
- }
- }
+/// Python `len()` support for DenseElementsAttr (bound as `__len__` in
+/// bindDerived): returns the element count reported by
+/// mlirElementsAttrGetNumElements for this attribute.
+intptr_t PyDenseElementsAttribute::dunderLen() const {
+ return mlirElementsAttrGetNumElements(*this);
+}
- MlirAttribute elements = mlirDenseElementsAttrGet(
- shapedType, mlirAttributes.size(), mlirAttributes.data());
+std::unique_ptr<nb_buffer_info> PyDenseElementsAttribute::accessBuffer() {
+ MlirType shapedType = mlirAttributeGetType(*this);
+ MlirType elementType = mlirShapedTypeGetElementType(shapedType);
+ std::string format;
- return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
+ if (mlirTypeIsAF32(elementType)) {
+ // f32
+ return bufferInfo<float>(shapedType);
}
-
- static PyDenseElementsAttribute
- getFromBuffer(const nb_buffer &array, bool signless,
- const std::optional<PyType> &explicitType,
- std::optional<std::vector<int64_t>> explicitShape,
- DefaultingPyMlirContext contextWrapper) {
- // Request a contiguous view. In exotic cases, this will cause a copy.
- int flags = PyBUF_ND;
- if (!explicitType) {
- flags |= PyBUF_FORMAT;
- }
- Py_buffer view;
- if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
- throw nb::python_error();
- }
- auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
-
- MlirContext context = contextWrapper->get();
- MlirAttribute attr = getAttributeFromBuffer(
- view, signless, explicitType, std::move(explicitShape), context);
- if (mlirAttributeIsNull(attr)) {
- throw std::invalid_argument(
- "DenseElementsAttr could not be constructed from the given buffer. "
- "This may mean that the Python buffer layout does not match that "
- "MLIR expected layout and is a bug.");
- }
- return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
+ if (mlirTypeIsAF64(elementType)) {
+ // f64
+ return bufferInfo<double>(shapedType);
}
-
- static PyDenseElementsAttribute getSplat(const PyType &shapedType,
- PyAttribute &elementAttr) {
- auto contextWrapper =
- PyMlirContext::forContext(mlirTypeGetContext(shapedType));
- if (!mlirAttributeIsAInteger(elementAttr) &&
- !mlirAttributeIsAFloat(elementAttr)) {
- std::string message = "Illegal element type for DenseElementsAttr: ";
- message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
- throw nb::value_error(message.c_str());
+ if (mlirTypeIsAF16(elementType)) {
+ // f16
+ return bufferInfo<uint16_t>(shapedType, "e");
+ }
+ if (mlirTypeIsAIndex(elementType)) {
+ // Same as IndexType::kInternalStorageBitWidth
+ return bufferInfo<int64_t>(shapedType);
+ }
+ if (mlirTypeIsAInteger(elementType) &&
+ mlirIntegerTypeGetWidth(elementType) == 32) {
+ if (mlirIntegerTypeIsSignless(elementType) ||
+ mlirIntegerTypeIsSigned(elementType)) {
+ // i32
+ return bufferInfo<int32_t>(shapedType);
}
- if (!mlirTypeIsAShaped(shapedType) ||
- !mlirShapedTypeHasStaticShape(shapedType)) {
- std::string message =
- "Expected a static ShapedType for the shaped_type parameter: ";
- message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
- throw nb::value_error(message.c_str());
+ if (mlirIntegerTypeIsUnsigned(elementType)) {
+ // unsigned i32
+ return bufferInfo<uint32_t>(shapedType);
}
- MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
- MlirType attrType = mlirAttributeGetType(elementAttr);
- if (!mlirTypeEqual(shapedElementType, attrType)) {
- std::string message =
- "Shaped element type and attribute type must be equal: shaped=";
- message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
- message.append(", element=");
- message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
- throw nb::value_error(message.c_str());
+ } else if (mlirTypeIsAInteger(elementType) &&
+ mlirIntegerTypeGetWidth(elementType) == 64) {
+ if (mlirIntegerTypeIsSignless(elementType) ||
+ mlirIntegerTypeIsSigned(elementType)) {
+ // i64
+ return bufferInfo<int64_t>(shapedType);
}
-
- MlirAttribute elements =
- mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
- return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
- }
-
- intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
-
- std::unique_ptr<nb_buffer_info> accessBuffer() {
- MlirType shapedType = mlirAttributeGetType(*this);
- MlirType elementType = mlirShapedTypeGetElementType(shapedType);
- std::string format;
-
- if (mlirTypeIsAF32(elementType)) {
- // f32
- return bufferInfo<float>(shapedType);
+ if (mlirIntegerTypeIsUnsigned(elementType)) {
+ // unsigned i64
+ return bufferInfo<uint64_t>(shapedType);
}
- if (mlirTypeIsAF64(elementType)) {
- // f64
- return bufferInfo<double>(shapedType);
+ } else if (mlirTypeIsAInteger(elementType) &&
+ mlirIntegerTypeGetWidth(elementType) == 8) {
+ if (mlirIntegerTypeIsSignless(elementType) ||
+ mlirIntegerTypeIsSigned(elementType)) {
+ // i8
+ return bufferInfo<int8_t>(shapedType);
}
- if (mlirTypeIsAF16(elementType)) {
- // f16
- return bufferInfo<uint16_t>(shapedType, "e");
+ if (mlirIntegerTypeIsUnsigned(elementType)) {
+ // unsigned i8
+ return bufferInfo<uint8_t>(shapedType);
}
- if (mlirTypeIsAIndex(elementType)) {
- // Same as IndexType::kInternalStorageBitWidth
- return bufferInfo<int64_t>(shapedType);
+ } else if (mlirTypeIsAInteger(elementType) &&
+ mlirIntegerTypeGetWidth(elementType) == 16) {
+ if (mlirIntegerTypeIsSignless(elementType) ||
+ mlirIntegerTypeIsSigned(elementType)) {
+ // i16
+ return bufferInfo<int16_t>(shapedType);
}
- if (mlirTypeIsAInteger(elementType) &&
- mlirIntegerTypeGetWidth(elementType) == 32) {
- if (mlirIntegerTypeIsSignless(elementType) ||
- mlirIntegerTypeIsSigned(elementType)) {
- // i32
- return bufferInfo<int32_t>(shapedType);
- }
- if (mlirIntegerTypeIsUnsigned(elementType)) {
- // unsigned i32
- return bufferInfo<uint32_t>(shapedType);
- }
- } else if (mlirTypeIsAInteger(elementType) &&
- mlirIntegerTypeGetWidth(elementType) == 64) {
- if (mlirIntegerTypeIsSignless(elementType) ||
- mlirIntegerTypeIsSigned(elementType)) {
- // i64
- return bufferInfo<int64_t>(shapedType);
- }
- if (mlirIntegerTypeIsUnsigned(elementType)) {
- // unsigned i64
- return bufferInfo<uint64_t>(shapedType);
- }
- } else if (mlirTypeIsAInteger(elementType) &&
- mlirIntegerTypeGetWidth(elementType) == 8) {
- if (mlirIntegerTypeIsSignless(elementType) ||
- mlirIntegerTypeIsSigned(elementType)) {
- // i8
- return bufferInfo<int8_t>(shapedType);
- }
- if (mlirIntegerTypeIsUnsigned(elementType)) {
- // unsigned i8
- return bufferInfo<uint8_t>(shapedType);
- }
- } else if (mlirTypeIsAInteger(elementType) &&
- mlirIntegerTypeGetWidth(elementType) == 16) {
- if (mlirIntegerTypeIsSignless(elementType) ||
- mlirIntegerTypeIsSigned(elementType)) {
- // i16
- return bufferInfo<int16_t>(shapedType);
- }
- if (mlirIntegerTypeIsUnsigned(elementType)) {
- // unsigned i16
- return bufferInfo<uint16_t>(shapedType);
- }
- } else if (mlirTypeIsAInteger(elementType) &&
- mlirIntegerTypeGetWidth(elementType) == 1) {
- // i1 / bool
- // We can not send the buffer directly back to Python, because the i1
- // values are bitpacked within MLIR. We call numpy's unpackbits function
- // to convert the bytes.
- return getBooleanBufferFromBitpackedAttribute();
+ if (mlirIntegerTypeIsUnsigned(elementType)) {
+ // unsigned i16
+ return bufferInfo<uint16_t>(shapedType);
}
-
- // TODO: Currently crashes the program.
- // Reported as https://github.com/pybind/pybind11/issues/3336
- throw std::invalid_argument(
- "unsupported data type for conversion to Python buffer");
+ } else if (mlirTypeIsAInteger(elementType) &&
+ mlirIntegerTypeGetWidth(elementType) == 1) {
+ // i1 / bool
+ // We can not send the buffer directly back to Python, because the i1
+ // values are bitpacked within MLIR. We call numpy's unpackbits function
+ // to convert the bytes.
+ return getBooleanBufferFromBitpackedAttribute();
}
- static void bindDerived(ClassTy &c) {
+ // TODO: Currently crashes the program.
+ // Reported as https://github.com/pybind/pybind11/issues/3336
+ throw std::invalid_argument(
+ "unsupported data type for conversion to Python buffer");
+}
+
+void PyDenseElementsAttribute::bindDerived(ClassTy &c) {
#if PY_VERSION_HEX < 0x03090000
- PyTypeObject *tp = reinterpret_cast<PyTypeObject *>(c.ptr());
- tp->tp_as_buffer->bf_getbuffer = PyDenseElementsAttribute::bf_getbuffer;
- tp->tp_as_buffer->bf_releasebuffer =
- PyDenseElementsAttribute::bf_releasebuffer;
+ PyTypeObject *tp = reinterpret_cast<PyTypeObject *>(c.ptr());
+ tp->tp_as_buffer->bf_getbuffer = PyDenseElementsAttribute::bf_getbuffer;
+ tp->tp_as_buffer->bf_releasebuffer =
+ PyDenseElementsAttribute::bf_releasebuffer;
#endif
- c.def("__len__", &PyDenseElementsAttribute::dunderLen)
- .def_static(
- "get", PyDenseElementsAttribute::getFromBuffer, nb::arg("array"),
- nb::arg("signless") = true, nb::arg("type") = nb::none(),
- nb::arg("shape") = nb::none(), nb::arg("context") = nb::none(),
- // clang-format off
+ c.def("__len__", &PyDenseElementsAttribute::dunderLen)
+ .def_static(
+ "get", PyDenseElementsAttribute::getFromBuffer, nb::arg("array"),
+ nb::arg("signless") = true, nb::arg("type") = nb::none(),
+ nb::arg("shape") = nb::none(), nb::arg("context") = nb::none(),
+ // clang-format off
nb::sig("def get(array: typing_extensions.Buffer, signless: bool = True, type: Type | None = None, shape: Sequence[int] | None = None, context: Context | None = None) -> DenseElementsAttr"),
- // clang-format on
- kDenseElementsAttrGetDocstring)
- .def_static("get", PyDenseElementsAttribute::getFromList,
- nb::arg("attrs"), nb::arg("type") = nb::none(),
- nb::arg("context") = nb::none(),
- kDenseElementsAttrGetFromListDocstring)
- .def_static("get_splat", PyDenseElementsAttribute::getSplat,
- nb::arg("shaped_type"), nb::arg("element_attr"),
- "Gets a DenseElementsAttr where all values are the same")
- .def_prop_ro("is_splat",
- [](PyDenseElementsAttribute &self) -> bool {
- return mlirDenseElementsAttrIsSplat(self);
- })
- .def("get_splat_value",
- [](PyDenseElementsAttribute &self)
- -> nb::typed<nb::object, PyAttribute> {
- if (!mlirDenseElementsAttrIsSplat(self))
- throw nb::value_error(
- "get_splat_value called on a non-splat attribute");
- return PyAttribute(self.getContext(),
- mlirDenseElementsAttrGetSplatValue(self))
- .maybeDownCast();
- });
- }
-
- static PyType_Slot slots[];
+ // clang-format on
+ kDenseElementsAttrGetDocstring)
+ .def_static("get", PyDenseElementsAttribute::getFromList,
+ nb::arg("attrs"), nb::arg("type") = nb::none(),
+ nb::arg("context") = nb::none(),
+ kDenseElementsAttrGetFromListDocstring)
+ .def_static("get_splat", PyDenseElementsAttribute::getSplat,
+ nb::arg("shaped_type"), nb::arg("element_attr"),
+ "Gets a DenseElementsAttr where all values are the same")
+ .def_prop_ro("is_splat",
+ [](PyDenseElementsAttribute &self) -> bool {
+ return mlirDenseElementsAttrIsSplat(self);
+ })
+ .def("get_splat_value",
+ [](PyDenseElementsAttribute &self)
+ -> nb::typed<nb::object, PyAttribute> {
+ if (!mlirDenseElementsAttrIsSplat(self))
+ throw nb::value_error(
+ "get_splat_value called on a non-splat attribute");
+ return PyAttribute(self.getContext(),
+ mlirDenseElementsAttrGetSplatValue(self))
+ .maybeDownCast();
+ });
+}
-private:
- static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags);
- static void bf_releasebuffer(PyObject *, Py_buffer *buffer);
+/// Returns true if a buffer format string names an unsigned integer type,
+/// judged solely by its first character being one of the struct-module
+/// unsigned codes 'B', 'H', 'I', 'L' or 'Q'. Returns false for an empty
+/// string. The integer width is not checked here; callers select it from
+/// the buffer's itemsize.
+/// NOTE(review): only format[0] is inspected, so a format that starts with
+/// a byte-order/alignment prefix (e.g. "<I") yields false.
+bool PyDenseElementsAttribute::isUnsignedIntegerFormat(
+ std::string_view format) {
+ if (format.empty())
+ return false;
+ char code = format[0];
+ return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
+ code == 'Q';
+}
- static bool isUnsignedIntegerFormat(std::string_view format) {
- if (format.empty())
- return false;
- char code = format[0];
- return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
- code == 'Q';
- }
+/// Returns true if a buffer format string names a signed integer type,
+/// judged solely by its first character being one of the struct-module
+/// signed codes 'b', 'h', 'i', 'l' or 'q'. Returns false for an empty
+/// string. The integer width is not checked here; callers select it from
+/// the buffer's itemsize.
+/// NOTE(review): only format[0] is inspected, so a format that starts with
+/// a byte-order/alignment prefix (e.g. "<i") yields false.
+bool PyDenseElementsAttribute::isSignedIntegerFormat(std::string_view format) {
+ if (format.empty())
+ return false;
+ char code = format[0];
+ return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
+ code == 'q';
+}
- static bool isSignedIntegerFormat(std::string_view format) {
- if (format.empty())
- return false;
- char code = format[0];
- return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
- code == 'q';
+MlirType PyDenseElementsAttribute::getShapedType(
+ std::optional<MlirType> bulkLoadElementType,
+ std::optional<std::vector<int64_t>> explicitShape, Py_buffer &view) {
+ SmallVector<int64_t> shape;
+ if (explicitShape) {
+ shape.append(explicitShape->begin(), explicitShape->end());
+ } else {
+ shape.append(view.shape, view.shape + view.ndim);
}
- static MlirType
- getShapedType(std::optional<MlirType> bulkLoadElementType,
- std::optional<std::vector<int64_t>> explicitShape,
- Py_buffer &view) {
- SmallVector<int64_t> shape;
+ if (mlirTypeIsAShaped(*bulkLoadElementType)) {
if (explicitShape) {
- shape.append(explicitShape->begin(), explicitShape->end());
- } else {
- shape.append(view.shape, view.shape + view.ndim);
- }
-
- if (mlirTypeIsAShaped(*bulkLoadElementType)) {
- if (explicitShape) {
- throw std::invalid_argument("Shape can only be specified explicitly "
- "when the type is not a shaped type.");
- }
- return *bulkLoadElementType;
+ throw std::invalid_argument("Shape can only be specified explicitly "
+ "when the type is not a shaped type.");
}
- MlirAttribute encodingAttr = mlirAttributeGetNull();
- return mlirRankedTensorTypeGet(shape.size(), shape.data(),
- *bulkLoadElementType, encodingAttr);
+ return *bulkLoadElementType;
}
+ MlirAttribute encodingAttr = mlirAttributeGetNull();
+ return mlirRankedTensorTypeGet(shape.size(), shape.data(),
+ *bulkLoadElementType, encodingAttr);
+}
- static MlirAttribute getAttributeFromBuffer(
- Py_buffer &view, bool signless, std::optional<PyType> explicitType,
- const std::optional<std::vector<int64_t>> &explicitShape,
- MlirContext &context) {
- // Detect format codes that are suitable for bulk loading. This includes
- // all byte aligned integer and floating point types up to 8 bytes.
- // Notably, this excludes exotics types which do not have a direct
- // representation in the buffer protocol (i.e. complex, etc).
- std::optional<MlirType> bulkLoadElementType;
- if (explicitType) {
- bulkLoadElementType = *explicitType;
- } else {
- std::string_view format(view.format);
- if (format == "f") {
- // f32
- assert(view.itemsize == 4 && "mismatched array itemsize");
- bulkLoadElementType = mlirF32TypeGet(context);
- } else if (format == "d") {
- // f64
- assert(view.itemsize == 8 && "mismatched array itemsize");
- bulkLoadElementType = mlirF64TypeGet(context);
- } else if (format == "e") {
- // f16
- assert(view.itemsize == 2 && "mismatched array itemsize");
- bulkLoadElementType = mlirF16TypeGet(context);
- } else if (format == "?") {
- // i1
- // The i1 type needs to be bit-packed, so we will handle it separately
- return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
- context);
- } else if (isSignedIntegerFormat(format)) {
- if (view.itemsize == 4) {
- // i32
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 32)
- : mlirIntegerTypeSignedGet(context, 32);
- } else if (view.itemsize == 8) {
- // i64
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 64)
- : mlirIntegerTypeSignedGet(context, 64);
- } else if (view.itemsize == 1) {
- // i8
- bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
- : mlirIntegerTypeSignedGet(context, 8);
- } else if (view.itemsize == 2) {
- // i16
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 16)
- : mlirIntegerTypeSignedGet(context, 16);
- }
- } else if (isUnsignedIntegerFormat(format)) {
- if (view.itemsize == 4) {
- // unsigned i32
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 32)
- : mlirIntegerTypeUnsignedGet(context, 32);
- } else if (view.itemsize == 8) {
- // unsigned i64
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 64)
- : mlirIntegerTypeUnsignedGet(context, 64);
- } else if (view.itemsize == 1) {
- // i8
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 8)
- : mlirIntegerTypeUnsignedGet(context, 8);
- } else if (view.itemsize == 2) {
- // i16
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 16)
- : mlirIntegerTypeUnsignedGet(context, 16);
- }
+MlirAttribute PyDenseElementsAttribute::getAttributeFromBuffer(
+ Py_buffer &view, bool signless, std::optional<PyType> explicitType,
+ const std::optional<std::vector<int64_t>> &explicitShape,
+ MlirContext &context) {
+ // Detect format codes that are suitable for bulk loading. This includes
+ // all byte aligned integer and floating point types up to 8 bytes.
+ // Notably, this excludes exotics types which do not have a direct
+ // representation in the buffer protocol (i.e. complex, etc).
+ std::optional<MlirType> bulkLoadElementType;
+ if (explicitType) {
+ bulkLoadElementType = *explicitType;
+ } else {
+ std::string_view format(view.format);
+ if (format == "f") {
+ // f32
+ assert(view.itemsize == 4 && "mismatched array itemsize");
+ bulkLoadElementType = mlirF32TypeGet(context);
+ } else if (format == "d") {
+ // f64
+ assert(view.itemsize == 8 && "mismatched array itemsize");
+ bulkLoadElementType = mlirF64TypeGet(context);
+ } else if (format == "e") {
+ // f16
+ assert(view.itemsize == 2 && "mismatched array itemsize");
+ bulkLoadElementType = mlirF16TypeGet(context);
+ } else if (format == "?") {
+ // i1
+ // The i1 type needs to be bit-packed, so we will handle it separately
+ return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
+ context);
+ } else if (isSignedIntegerFormat(format)) {
+ if (view.itemsize == 4) {
+ // i32
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
+ : mlirIntegerTypeSignedGet(context, 32);
+ } else if (view.itemsize == 8) {
+ // i64
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
+ : mlirIntegerTypeSignedGet(context, 64);
+ } else if (view.itemsize == 1) {
+ // i8
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
+ : mlirIntegerTypeSignedGet(context, 8);
+ } else if (view.itemsize == 2) {
+ // i16
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
+ : mlirIntegerTypeSignedGet(context, 16);
}
- if (!bulkLoadElementType) {
- throw std::invalid_argument(
- std::string("unimplemented array format conversion from format: ") +
- std::string(format));
+ } else if (isUnsignedIntegerFormat(format)) {
+ if (view.itemsize == 4) {
+ // unsigned i32
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 32)
+ : mlirIntegerTypeUnsignedGet(context, 32);
+ } else if (view.itemsize == 8) {
+ // unsigned i64
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 64)
+ : mlirIntegerTypeUnsignedGet(context, 64);
+ } else if (view.itemsize == 1) {
+ // i8
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
+ : mlirIntegerTypeUnsignedGet(context, 8);
+ } else if (view.itemsize == 2) {
+ // i16
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 16)
+ : mlirIntegerTypeUnsignedGet(context, 16);
}
}
-
- MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
- return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
- }
-
- // There is a complication for boolean numpy arrays, as numpy represents
- // them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8
- // booleans per byte.
- static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
- Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
- MlirContext &context) {
- if (llvm::endianness::native != llvm::endianness::little) {
- // Given we have no good way of testing the behavior on big-endian
- // systems we will throw
- throw nb::type_error("Constructing a bit-packed MLIR attribute is "
- "unsupported on big-endian systems");
+ if (!bulkLoadElementType) {
+ throw std::invalid_argument(
+ std::string("unimplemented array format conversion from format: ") +
+ std::string(format));
}
- nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> unpackedArray(
- /*data=*/static_cast<uint8_t *>(view.buf),
- /*shape=*/{static_cast<size_t>(view.len)});
-
- nb::module_ numpy = nb::module_::import_("numpy");
- nb::object packbitsFunc = numpy.attr("packbits");
- nb::object packedBooleans =
- packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little");
- nb_buffer_info pythonBuffer = nb::cast<nb_buffer>(packedBooleans).request();
-
- MlirType bitpackedType = getShapedType(mlirIntegerTypeGet(context, 1),
- std::move(explicitShape), view);
- assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8");
- // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of
- // packedBooleans, hence the MlirAttribute will remain valid even when
- // packedBooleans get reclaimed by the end of the function.
- return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
- pythonBuffer.ptr);
}
- // This does the opposite transformation of
- // `getBitpackedAttributeFromBooleanBuffer`
- std::unique_ptr<nb_buffer_info> getBooleanBufferFromBitpackedAttribute() {
- if (llvm::endianness::native != llvm::endianness::little) {
- // Given we have no good way of testing the behavior on big-endian
- // systems we will throw
- throw nb::type_error("Constructing a numpy array from a MLIR attribute "
- "is unsupported on big-endian systems");
- }
+ MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
+ return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
+}
- int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
- int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8);
- uint8_t *bitpackedData = static_cast<uint8_t *>(
- const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
- nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> packedArray(
- /*data=*/bitpackedData,
- /*shape=*/{static_cast<size_t>(numBitpackedBytes)});
-
- nb::module_ numpy = nb::module_::import_("numpy");
- nb::object unpackbitsFunc = numpy.attr("unpackbits");
- nb::object equalFunc = numpy.attr("equal");
- nb::object reshapeFunc = numpy.attr("reshape");
- nb::object unpackedBooleans =
- unpackbitsFunc(nb::cast(packedArray), "bitorder"_a = "little");
-
- // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array.
- // We need to:
- // 1. Slice away the padded bits
- // 2. Make the boolean array have the correct shape
- // 3. Convert the array to a boolean array
- unpackedBooleans = unpackedBooleans[nb::slice(
- nb::int_(0), nb::int_(numBooleans), nb::int_(1))];
- unpackedBooleans = equalFunc(unpackedBooleans, 1);
-
- MlirType shapedType = mlirAttributeGetType(*this);
- intptr_t rank = mlirShapedTypeGetRank(shapedType);
- std::vector<intptr_t> shape(rank);
- for (intptr_t i = 0; i < rank; ++i) {
- shape[i] = mlirShapedTypeGetDimSize(shapedType, i);
- }
- unpackedBooleans = reshapeFunc(unpackedBooleans, shape);
+MlirAttribute PyDenseElementsAttribute::getBitpackedAttributeFromBooleanBuffer(
+ Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
+ MlirContext &context) {
+ if (llvm::endianness::native != llvm::endianness::little) {
+ // The bit-packing path has no test coverage on big-endian hosts, so
+ // reject them explicitly rather than risk producing wrong data.
+ throw nb::type_error("Constructing a bit-packed MLIR attribute is "
+ "unsupported on big-endian systems");
+ }
+ nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> unpackedArray(
+ /*data=*/static_cast<uint8_t *>(view.buf),
+ /*shape=*/{static_cast<size_t>(view.len)});
+ // Pack 8 booleans per byte (bitorder="little": first element in bit 0).
+ nb::module_ numpy = nb::module_::import_("numpy");
+ nb::object packbitsFunc = numpy.attr("packbits");
+ nb::object packedBooleans =
+ packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little");
+ nb_buffer_info pythonBuffer = nb::cast<nb_buffer>(packedBooleans).request();
+
+ MlirType bitpackedType = getShapedType(mlirIntegerTypeGet(context, 1),
+ std::move(explicitShape), view);
+ assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8");
+ // `mlirDenseElementsAttrRawBufferGet` copies the bytes of packedBooleans,
+ // so the returned MlirAttribute stays valid after packedBooleans is
+ // released at the end of this function.
+ return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
+ pythonBuffer.ptr);
+}
- // Make sure the returned nb::buffer_view claims ownership of the data in
- // `pythonBuffer` so it remains valid when Python reads it
- nb_buffer pythonBuffer = nb::cast<nb_buffer>(unpackedBooleans);
- return std::make_unique<nb_buffer_info>(pythonBuffer.request());
+std::unique_ptr<nb_buffer_info>
+PyDenseElementsAttribute::getBooleanBufferFromBitpackedAttribute() const {
+ if (llvm::endianness::native != llvm::endianness::little) {
+ // The bit-unpacking path has no test coverage on big-endian hosts, so
+ // reject them explicitly rather than risk producing wrong data.
+ throw nb::type_error("Constructing a numpy array from a MLIR attribute "
+ "is unsupported on big-endian systems");
}
- template <typename Type>
- std::unique_ptr<nb_buffer_info>
- bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) {
- intptr_t rank = mlirShapedTypeGetRank(shapedType);
- // Prepare the data for the buffer_info.
- // Buffer is configured for read-only access below.
- Type *data = static_cast<Type *>(
- const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
- // Prepare the shape for the buffer_info.
- SmallVector<intptr_t, 4> shape;
- for (intptr_t i = 0; i < rank; ++i)
- shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
- // Prepare the strides for the buffer_info.
- SmallVector<intptr_t, 4> strides;
- if (mlirDenseElementsAttrIsSplat(*this)) {
- // Splats are special, only the single value is stored.
- strides.assign(rank, 0);
- } else {
- for (intptr_t i = 1; i < rank; ++i) {
- intptr_t strideFactor = 1;
- for (intptr_t j = i; j < rank; ++j)
- strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
- strides.push_back(sizeof(Type) * strideFactor);
- }
- strides.push_back(sizeof(Type));
- }
- const char *format;
- if (explicitFormat) {
- format = explicitFormat;
- } else {
- format = nb_format_descriptor<Type>::format();
- }
- return std::make_unique<nb_buffer_info>(
- data, sizeof(Type), format, rank, std::move(shape), std::move(strides),
- /*readonly=*/true);
+ int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
+ int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8);
+ uint8_t *bitpackedData = static_cast<uint8_t *>(
+ const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
+ nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> packedArray(
+ /*data=*/bitpackedData,
+ /*shape=*/{static_cast<size_t>(numBitpackedBytes)});
+ // NOTE(review): splats are not special-cased; confirm raw data size here.
+ nb::module_ numpy = nb::module_::import_("numpy");
+ nb::object unpackbitsFunc = numpy.attr("unpackbits");
+ nb::object equalFunc = numpy.attr("equal");
+ nb::object reshapeFunc = numpy.attr("reshape");
+ nb::object unpackedBooleans =
+ unpackbitsFunc(nb::cast(packedArray), "bitorder"_a = "little");
+
+ // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array.
+ // We need to:
+ // 1. Slice away the padded bits
+ // 2. Convert the 0 / 1 values to a boolean array
+ // 3. Reshape the result to the attribute's shape
+ unpackedBooleans = unpackedBooleans[nb::slice(
+ nb::int_(0), nb::int_(numBooleans), nb::int_(1))];
+ unpackedBooleans = equalFunc(unpackedBooleans, 1);
+
+ MlirType shapedType = mlirAttributeGetType(*this);
+ intptr_t rank = mlirShapedTypeGetRank(shapedType);
+ std::vector<intptr_t> shape(rank);
+ for (intptr_t i = 0; i < rank; ++i) {
+ shape[i] = mlirShapedTypeGetDimSize(shapedType, i);
}
-}; // namespace
+ unpackedBooleans = reshapeFunc(unpackedBooleans, shape);
+
+ // Request the buffer from the unpacked numpy array; the returned buffer
+ // info is expected to keep that data alive while Python reads it.
+ nb_buffer pythonBuffer = nb::cast<nb_buffer>(unpackedBooleans);
+ return std::make_unique<nb_buffer_info>(pythonBuffer.request());
+}
PyType_Slot PyDenseElementsAttribute::slots[] = {
// Python 3.8 doesn't allow setting the buffer protocol slots from a type spec.
@@ -1333,364 +965,294 @@ PyType_Slot PyDenseElementsAttribute::slots[] = {
delete reinterpret_cast<nb_buffer_info *>(view->internal);
}
-/// Refinement of the PyDenseElementsAttribute for attributes containing
-/// integer (and boolean) values. Supports element access.
-class PyDenseIntElementsAttribute
- : public PyConcreteAttribute<PyDenseIntElementsAttribute,
- PyDenseElementsAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
- static constexpr const char *pyClassName = "DenseIntElementsAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- /// Returns the element at the given linear position. Asserts if the index
- /// is out of range.
- nb::int_ dunderGetItem(intptr_t pos) {
- if (pos < 0 || pos >= dunderLen()) {
- throw nb::index_error("attempt to access out of bounds element");
- }
+nb::int_ PyDenseIntElementsAttribute::dunderGetItem(intptr_t pos) const {
+ if (pos < 0 || pos >= dunderLen()) {
+ throw nb::index_error("attempt to access out of bounds element");
+ }
- MlirType type = mlirAttributeGetType(*this);
- type = mlirShapedTypeGetElementType(type);
- // Index type can also appear as a DenseIntElementsAttr and therefore can be
- // casted to integer.
- assert(mlirTypeIsAInteger(type) ||
- mlirTypeIsAIndex(type) && "expected integer/index element type in "
- "dense int elements attribute");
- // Dispatch element extraction to an appropriate C function based on the
- // elemental type of the attribute. nb::int_ is implicitly constructible
- // from any C++ integral type and handles bitwidth correctly.
- // TODO: consider caching the type properties in the constructor to avoid
- // querying them on each element access.
- if (mlirTypeIsAIndex(type)) {
- return nb::int_(mlirDenseElementsAttrGetIndexValue(*this, pos));
+ MlirType type = mlirAttributeGetType(*this);
+ type = mlirShapedTypeGetElementType(type);
+ // Index-typed elements can also appear in a DenseIntElementsAttr; they are
+ // read through the index accessor and returned as a Python int.
+ assert((mlirTypeIsAInteger(type) || mlirTypeIsAIndex(type)) &&
+ "expected integer/index element type in "
+ "dense int elements attribute");
+ // Dispatch element extraction to an appropriate C function based on the
+ // elemental type of the attribute. nb::int_ is implicitly
+ // constructible from any C++ integral type and handles bitwidth correctly.
+ // TODO: consider caching the type properties in the constructor to avoid
+ // querying them on each element access.
+ if (mlirTypeIsAIndex(type)) {
+ return nb::int_(mlirDenseElementsAttrGetIndexValue(*this, pos));
+ }
+ unsigned width = mlirIntegerTypeGetWidth(type);
+ bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
+ if (isUnsigned) {
+ if (width == 1) {
+ return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
}
- unsigned width = mlirIntegerTypeGetWidth(type);
- bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
- if (isUnsigned) {
- if (width == 1) {
- return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
- }
- if (width == 8) {
- return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos));
- }
- if (width == 16) {
- return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos));
- }
- if (width == 32) {
- return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos));
- }
- if (width == 64) {
- return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos));
- }
- } else {
- if (width == 1) {
- return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
- }
- if (width == 8) {
- return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos));
- }
- if (width == 16) {
- return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos));
- }
- if (width == 32) {
- return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos));
- }
- if (width == 64) {
- return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos));
- }
+ if (width == 8) {
+ return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos));
+ }
+ if (width == 16) {
+ return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos));
+ }
+ if (width == 32) {
+ return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos));
+ }
+ if (width == 64) {
+ return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos));
+ }
+ } else {
+ if (width == 1) {
+ return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
+ }
+ if (width == 8) {
+ return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos));
+ }
+ if (width == 16) {
+ return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos));
+ }
+ if (width == 32) {
+ return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos));
+ }
+ if (width == 64) {
+ return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos));
}
- throw nb::type_error("Unsupported integer type");
- }
-
- static void bindDerived(ClassTy &c) {
- c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
}
-};
+ throw nb::type_error("Unsupported integer type");
+}
+void PyDenseIntElementsAttribute::bindDerived(ClassTy &c) {
+ c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
+}
// Check if the python version is less than 3.13. Py_IsFinalizing is a part
// of stable ABI since 3.13 and before it was available as _Py_IsFinalizing.
#if PY_VERSION_HEX < 0x030d0000
#define Py_IsFinalizing _Py_IsFinalizing
#endif
-class PyDenseResourceElementsAttribute
- : public PyConcreteAttribute<PyDenseResourceElementsAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction =
- mlirAttributeIsADenseResourceElements;
- static constexpr const char *pyClassName = "DenseResourceElementsAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- static PyDenseResourceElementsAttribute
- getFromBuffer(const nb_buffer &buffer, const std::string &name,
- const PyType &type, std::optional<size_t> alignment,
- bool isMutable, DefaultingPyMlirContext contextWrapper) {
- if (!mlirTypeIsAShaped(type)) {
- throw std::invalid_argument(
- "Constructing a DenseResourceElementsAttr requires a ShapedType.");
- }
-
- // Do not request any conversions as we must ensure to use caller
- // managed memory.
- int flags = PyBUF_STRIDES;
- std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>();
- if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
- throw nb::python_error();
- }
+PyDenseResourceElementsAttribute
+PyDenseResourceElementsAttribute::getFromBuffer(
+ const nb_buffer &buffer, const std::string &name, const PyType &type,
+ std::optional<size_t> alignment, bool isMutable,
+ DefaultingPyMlirContext contextWrapper) {
+ if (!mlirTypeIsAShaped(type)) {
+ throw std::invalid_argument(
+ "Constructing a DenseResourceElementsAttr requires a ShapedType.");
+ }
- // This scope releaser will only release if we haven't yet transferred
- // ownership.
- auto freeBuffer = llvm::make_scope_exit([&]() {
- if (view)
- PyBuffer_Release(view.get());
- });
+ // Request no conversions: the attribute aliases the caller-managed
+ // memory directly, so a copy must not be introduced.
+ int flags = PyBUF_STRIDES;
+ std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>();
+ if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
+ throw nb::python_error();
+ }
- if (!PyBuffer_IsContiguous(view.get(), 'A')) {
- throw std::invalid_argument("Contiguous buffer is required.");
- }
+ // Release the buffer on any early exit; after `view.release()` hands
+ // ownership to the attribute's deleter, this becomes a no-op.
+ auto freeBuffer = llvm::make_scope_exit([&]() {
+ if (view)
+ PyBuffer_Release(view.get());
+ });
- // Infer alignment to be the stride of one element if not explicit.
- size_t inferredAlignment;
- if (alignment)
- inferredAlignment = *alignment;
- else
- inferredAlignment = view->strides[view->ndim - 1];
-
- // The userData is a Py_buffer* that the deleter owns.
- auto deleter = [](void *userData, const void *data, size_t size,
- size_t align) {
- if (Py_IsFinalizing())
- return;
- assert(Py_IsInitialized() && "expected interpreter to be initialized");
- Py_buffer *ownedView = static_cast<Py_buffer *>(userData);
- nb::gil_scoped_acquire gil;
- PyBuffer_Release(ownedView);
- delete ownedView;
- };
-
- size_t rawBufferSize = view->len;
- MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
- type, toMlirStringRef(name), view->buf, rawBufferSize,
- inferredAlignment, isMutable, deleter, static_cast<void *>(view.get()));
- if (mlirAttributeIsNull(attr)) {
- throw std::invalid_argument(
- "DenseResourceElementsAttr could not be constructed from the given "
- "buffer. "
- "This may mean that the Python buffer layout does not match that "
- "MLIR expected layout and is a bug.");
- }
- view.release();
- return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
+ if (!PyBuffer_IsContiguous(view.get(), 'A')) {
+ throw std::invalid_argument("Contiguous buffer is required.");
}
- static void bindDerived(ClassTy &c) {
- c.def_static("get_from_buffer",
- PyDenseResourceElementsAttribute::getFromBuffer,
- nb::arg("array"), nb::arg("name"), nb::arg("type"),
- nb::arg("alignment") = nb::none(),
- nb::arg("is_mutable") = false, nb::arg("context") = nb::none(),
- // clang-format off
- nb::sig("def get_from_buffer(array: typing_extensions.Buffer, name: str, type: Type, alignment: int | None = None, is_mutable: bool = False, context: Context | None = None) -> DenseResourceElementsAttr"),
- // clang-format on
- kDenseResourceElementsAttrGetFromBufferDocstring);
+ // Default alignment: stride of the innermost dim, or itemsize if 0-d.
+ size_t inferredAlignment;
+ if (alignment)
+ inferredAlignment = *alignment;
+ else
+ inferredAlignment = view->ndim > 0 ? view->strides[view->ndim - 1] : view->itemsize;
+
+ // The userData is a Py_buffer* that the deleter owns.
+ auto deleter = [](void *userData, const void *data, size_t size,
+ size_t align) {
+ if (Py_IsFinalizing())
+ return;
+ assert(Py_IsInitialized() && "expected interpreter to be initialized");
+ Py_buffer *ownedView = static_cast<Py_buffer *>(userData);
+ nb::gil_scoped_acquire gil;
+ PyBuffer_Release(ownedView);
+ delete ownedView;
+ };
+
+ size_t rawBufferSize = view->len;
+ MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
+ type, toMlirStringRef(name), view->buf, rawBufferSize, inferredAlignment,
+ isMutable, deleter, static_cast<void *>(view.get()));
+ if (mlirAttributeIsNull(attr)) {
+ throw std::invalid_argument(
+ "DenseResourceElementsAttr could not be constructed from the given "
+ "buffer. "
+ "This may mean that the Python buffer layout does not match that "
+ "MLIR expected layout and is a bug.");
}
-};
+ view.release();
+ return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
+}
-class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
- static constexpr const char *pyClassName = "DictAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirDictionaryAttrGetTypeID;
+void PyDenseResourceElementsAttribute::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get_from_buffer", PyDenseResourceElementsAttribute::getFromBuffer,
+ nb::arg("array"), nb::arg("name"), nb::arg("type"),
+ nb::arg("alignment") = nb::none(), nb::arg("is_mutable") = false,
+ nb::arg("context") = nb::none(),
+ // clang-format off
+ nb::sig("def get_from_buffer(array: typing_extensions.Buffer, name: str, type: Type, alignment: int | None = None, is_mutable: bool = False, context: Context | None = None) -> DenseResourceElementsAttr"),
+ // clang-format on
+ kDenseResourceElementsAttrGetFromBufferDocstring);
+}
- intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
+intptr_t PyDictAttribute::dunderLen() const {
+ return mlirDictionaryAttrGetNumElements(*this);
+}
- bool dunderContains(const std::string &name) {
- return !mlirAttributeIsNull(
- mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
- }
+bool PyDictAttribute::dunderContains(const std::string &name) const {
+ return !mlirAttributeIsNull(
+ mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
+}
- static void bindDerived(ClassTy &c) {
- c.def("__contains__", &PyDictAttribute::dunderContains);
- c.def("__len__", &PyDictAttribute::dunderLen);
- c.def_static(
- "get",
- [](const nb::dict &attributes, DefaultingPyMlirContext context) {
- SmallVector<MlirNamedAttribute> mlirNamedAttributes;
- mlirNamedAttributes.reserve(attributes.size());
- for (std::pair<nb::handle, nb::handle> it : attributes) {
- auto &mlirAttr = nb::cast<PyAttribute &>(it.second);
- auto name = nb::cast<std::string>(it.first);
- mlirNamedAttributes.push_back(mlirNamedAttributeGet(
- mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
- toMlirStringRef(name)),
- mlirAttr));
- }
+void PyDictAttribute::bindDerived(ClassTy &c) {
+ c.def("__contains__", &PyDictAttribute::dunderContains);
+ c.def("__len__", &PyDictAttribute::dunderLen);
+ c.def_static(
+ "get",
+ [](const nb::dict &attributes, DefaultingPyMlirContext context) {
+ SmallVector<MlirNamedAttribute> mlirNamedAttributes;
+ mlirNamedAttributes.reserve(attributes.size());
+ for (std::pair<nb::handle, nb::handle> it : attributes) {
+ auto &mlirAttr = nb::cast<PyAttribute &>(it.second);
+ auto name = nb::cast<std::string>(it.first);
+ mlirNamedAttributes.push_back(mlirNamedAttributeGet(
+ mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
+ toMlirStringRef(name)),
+ mlirAttr));
+ }
+ MlirAttribute attr =
+ mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
+ mlirNamedAttributes.data());
+ return PyDictAttribute(context->getRef(), attr);
+ },
+ nb::arg("value") = nb::dict(), nb::arg("context") = nb::none(),
+ "Gets an uniqued dict attribute");
+ c.def("__getitem__",
+ [](PyDictAttribute &self,
+ const std::string &name) -> nb::typed<nb::object, PyAttribute> {
MlirAttribute attr =
- mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
- mlirNamedAttributes.data());
- return PyDictAttribute(context->getRef(), attr);
- },
- nb::arg("value") = nb::dict(), nb::arg("context") = nb::none(),
- "Gets an uniqued dict attribute");
- c.def("__getitem__",
- [](PyDictAttribute &self,
- const std::string &name) -> nb::typed<nb::object, PyAttribute> {
- MlirAttribute attr =
- mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
- if (mlirAttributeIsNull(attr))
- throw nb::key_error("attempt to access a non-existent attribute");
- return PyAttribute(self.getContext(), attr).maybeDownCast();
- });
- c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
- if (index < 0 || index >= self.dunderLen()) {
- throw nb::index_error("attempt to access out of bounds attribute");
- }
- MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
- return PyNamedAttribute(
- namedAttr.attribute,
- std::string(mlirIdentifierStr(namedAttr.name).data));
- });
- }
-};
-
-/// Refinement of PyDenseElementsAttribute for attributes containing
-/// floating-point values. Supports element access.
-class PyDenseFPElementsAttribute
- : public PyConcreteAttribute<PyDenseFPElementsAttribute,
- PyDenseElementsAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
- static constexpr const char *pyClassName = "DenseFPElementsAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- nb::float_ dunderGetItem(intptr_t pos) {
- if (pos < 0 || pos >= dunderLen()) {
- throw nb::index_error("attempt to access out of bounds element");
+ mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
+ if (mlirAttributeIsNull(attr))
+ throw nb::key_error("attempt to access a non-existent attribute");
+ return PyAttribute(self.getContext(), attr).maybeDownCast();
+ });
+ c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
+ if (index < 0 || index >= self.dunderLen()) {
+ throw nb::index_error("attempt to access out of bounds attribute");
}
+ MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
+ MlirStringRef nameRef = mlirIdentifierStr(namedAttr.name);
+ return PyNamedAttribute(namedAttr.attribute,
+ std::string(nameRef.data, nameRef.length));
+ });
+}
- MlirType type = mlirAttributeGetType(*this);
- type = mlirShapedTypeGetElementType(type);
- // Dispatch element extraction to an appropriate C function based on the
- // elemental type of the attribute. nb::float_ is implicitly constructible
- // from float and double.
- // TODO: consider caching the type properties in the constructor to avoid
- // querying them on each element access.
- if (mlirTypeIsAF32(type)) {
- return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos));
- }
- if (mlirTypeIsAF64(type)) {
- return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos));
- }
- throw nb::type_error("Unsupported floating-point type");
+nb::float_ PyDenseFPElementsAttribute::dunderGetItem(intptr_t pos) const {
+ if (pos < 0 || pos >= dunderLen()) {
+ throw nb::index_error("attempt to access out of bounds element");
}
- static void bindDerived(ClassTy &c) {
- c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
+ MlirType type = mlirAttributeGetType(*this);
+ type = mlirShapedTypeGetElementType(type);
+ // Dispatch to the C accessor matching the element type. Only f32 and f64
+ // are handled; any other element type raises TypeError below.
+ // nb::float_ is constructible from both float and double.
+ // TODO: consider caching the type properties in the constructor to avoid
+ // querying them on each element access.
+ if (mlirTypeIsAF32(type)) {
+ return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos));
}
-};
-
-class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
- static constexpr const char *pyClassName = "TypeAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirTypeAttrGetTypeID;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](const PyType &value, DefaultingPyMlirContext context) {
- MlirAttribute attr = mlirTypeAttrGet(value.get());
- return PyTypeAttribute(context->getRef(), attr);
- },
- nb::arg("value"), nb::arg("context") = nb::none(),
- "Gets a uniqued Type attribute");
- c.def_prop_ro(
- "value", [](PyTypeAttribute &self) -> nb::typed<nb::object, PyType> {
- return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
- .maybeDownCast();
- });
+ if (mlirTypeIsAF64(type)) {
+ return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos));
}
-};
+ throw nb::type_error("Unsupported floating-point type");
+}
-/// Unit Attribute subclass. Unit attributes don't have values.
-class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
- static constexpr const char *pyClassName = "UnitAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirUnitAttrGetTypeID;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- return PyUnitAttribute(context->getRef(),
- mlirUnitAttrGet(context->get()));
- },
- nb::arg("context") = nb::none(), "Create a Unit attribute.");
- }
-};
+void PyDenseFPElementsAttribute::bindDerived(ClassTy &c) {
+ c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
+}
-/// Strided layout attribute subclass.
-class PyStridedLayoutAttribute
- : public PyConcreteAttribute<PyStridedLayoutAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
- static constexpr const char *pyClassName = "StridedLayoutAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirStridedLayoutAttrGetTypeID;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](int64_t offset, const std::vector<int64_t> &strides,
- DefaultingPyMlirContext ctx) {
- MlirAttribute attr = mlirStridedLayoutAttrGet(
- ctx->get(), offset, strides.size(), strides.data());
- return PyStridedLayoutAttribute(ctx->getRef(), attr);
- },
- nb::arg("offset"), nb::arg("strides"), nb::arg("context") = nb::none(),
- "Gets a strided layout attribute.");
- c.def_static(
- "get_fully_dynamic",
- [](int64_t rank, DefaultingPyMlirContext ctx) {
- auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
- std::vector<int64_t> strides(rank);
- llvm::fill(strides, dynamic);
- MlirAttribute attr = mlirStridedLayoutAttrGet(
- ctx->get(), dynamic, strides.size(), strides.data());
- return PyStridedLayoutAttribute(ctx->getRef(), attr);
- },
- nb::arg("rank"), nb::arg("context") = nb::none(),
- "Gets a strided layout attribute with dynamic offset and strides of "
- "a "
- "given rank.");
- c.def_prop_ro(
- "offset",
- [](PyStridedLayoutAttribute &self) {
- return mlirStridedLayoutAttrGetOffset(self);
- },
- "Returns the value of the float point attribute");
- c.def_prop_ro(
- "strides",
- [](PyStridedLayoutAttribute &self) {
- intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
- std::vector<int64_t> strides(size);
- for (intptr_t i = 0; i < size; i++) {
- strides[i] = mlirStridedLayoutAttrGetStride(self, i);
- }
- return strides;
- },
- "Returns the value of the float point attribute");
- }
-};
+void PyTypeAttribute::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](const PyType &value, DefaultingPyMlirContext context) {
+ MlirAttribute attr = mlirTypeAttrGet(value.get());
+ return PyTypeAttribute(context->getRef(), attr);
+ },
+ nb::arg("value"), nb::arg("context") = nb::none(),
+ "Gets a uniqued Type attribute");
+ c.def_prop_ro(
+ "value", [](PyTypeAttribute &self) -> nb::typed<nb::object, PyType> {
+ return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
+ .maybeDownCast();
+ });
+}
+
+void PyUnitAttribute::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ return PyUnitAttribute(context->getRef(),
+ mlirUnitAttrGet(context->get()));
+ },
+ nb::arg("context") = nb::none(), "Create a Unit attribute.");
+}
+
+void PyStridedLayoutAttribute::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](int64_t offset, const std::vector<int64_t> &strides,
+ DefaultingPyMlirContext ctx) {
+ MlirAttribute attr = mlirStridedLayoutAttrGet(
+ ctx->get(), offset, strides.size(), strides.data());
+ return PyStridedLayoutAttribute(ctx->getRef(), attr);
+ },
+ nb::arg("offset"), nb::arg("strides"), nb::arg("context") = nb::none(),
+ "Gets a strided layout attribute.");
+ c.def_static(
+ "get_fully_dynamic",
+ [](int64_t rank, DefaultingPyMlirContext ctx) {
+ auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
+ std::vector<int64_t> strides(rank);
+ llvm::fill(strides, dynamic);
+ MlirAttribute attr = mlirStridedLayoutAttrGet(
+ ctx->get(), dynamic, strides.size(), strides.data());
+ return PyStridedLayoutAttribute(ctx->getRef(), attr);
+ },
+ nb::arg("rank"), nb::arg("context") = nb::none(),
+ "Gets a strided layout attribute with dynamic offset and strides of "
+ "a "
+ "given rank.");
+ c.def_prop_ro(
+ "offset",
+ [](PyStridedLayoutAttribute &self) {
+ return mlirStridedLayoutAttrGetOffset(self);
+ },
+ "Returns the offset of the strided layout attribute");
+ c.def_prop_ro(
+ "strides",
+ [](PyStridedLayoutAttribute &self) {
+ intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
+ std::vector<int64_t> strides(size);
+ for (intptr_t i = 0; i < size; i++) {
+ strides[i] = mlirStridedLayoutAttrGetStride(self, i);
+ }
+ return strides;
+ },
+ "Returns the strides of the strided layout attribute");
+}
nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute))
@@ -1747,10 +1309,6 @@ nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
throw nb::type_error(msg.c_str());
}
-} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
-} // namespace python
-} // namespace mlir
-
void PyStringAttribute::bindDerived(ClassTy &c) {
c.def_static(
"get",
@@ -1795,9 +1353,6 @@ void PyStringAttribute::bindDerived(ClassTy &c) {
"Returns the value of the string attribute as `bytes`");
}
-namespace mlir {
-namespace python {
-namespace MLIR_BINDINGS_PYTHON_DOMAIN {
void populateIRAttributes(nb::module_ &m) {
PyAffineMapAttribute::bind(m);
PyDenseBoolArrayAttribute::bind(m);
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 7350046f428c7..12fdf2f8e1ecd 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -25,494 +25,269 @@ using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
using llvm::SmallVector;
using llvm::Twine;
-namespace mlir {
-namespace python {
-namespace MLIR_BINDINGS_PYTHON_DOMAIN {
-
-/// Checks whether the given type is an integer or float type.
-static int mlirTypeIsAIntegerOrFloat(MlirType type) {
+int mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::mlirTypeIsAIntegerOrFloat(MlirType type) {
   return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
          mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
 }
-class PyIntegerType : public PyConcreteType<PyIntegerType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirIntegerTypeGetTypeID;
- static constexpr const char *pyClassName = "IntegerType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get_signless",
- [](unsigned width, DefaultingPyMlirContext context) {
- MlirType t = mlirIntegerTypeGet(context->get(), width);
- return PyIntegerType(context->getRef(), t);
- },
- nb::arg("width"), nb::arg("context") = nb::none(),
- "Create a signless integer type");
- c.def_static(
- "get_signed",
- [](unsigned width, DefaultingPyMlirContext context) {
- MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
- return PyIntegerType(context->getRef(), t);
- },
- nb::arg("width"), nb::arg("context") = nb::none(),
- "Create a signed integer type");
- c.def_static(
- "get_unsigned",
- [](unsigned width, DefaultingPyMlirContext context) {
- MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
- return PyIntegerType(context->getRef(), t);
- },
- nb::arg("width"), nb::arg("context") = nb::none(),
- "Create an unsigned integer type");
- c.def_prop_ro(
- "width",
- [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
- "Returns the width of the integer type");
- c.def_prop_ro(
- "is_signless",
- [](PyIntegerType &self) -> bool {
- return mlirIntegerTypeIsSignless(self);
- },
- "Returns whether this is a signless integer");
- c.def_prop_ro(
- "is_signed",
- [](PyIntegerType &self) -> bool {
- return mlirIntegerTypeIsSigned(self);
- },
- "Returns whether this is a signed integer");
- c.def_prop_ro(
- "is_unsigned",
- [](PyIntegerType &self) -> bool {
- return mlirIntegerTypeIsUnsigned(self);
- },
- "Returns whether this is an unsigned integer");
- }
-};
-
-/// Index Type subclass - IndexType.
-class PyIndexType : public PyConcreteType<PyIndexType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirIndexTypeGetTypeID;
- static constexpr const char *pyClassName = "IndexType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirIndexTypeGet(context->get());
- return PyIndexType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a index type.");
- }
-};
-
-class PyFloatType : public PyConcreteType<PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
- static constexpr const char *pyClassName = "FloatType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
- "Returns the width of the floating-point type");
- }
-};
-
-/// Floating Point Type subclass - Float4E2M1FNType.
-class PyFloat4E2M1FNType
- : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat4E2M1FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float4E2M1FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat4E2M1FNTypeGet(context->get());
- return PyFloat4E2M1FNType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float4_e2m1fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float6E2M3FNType.
-class PyFloat6E2M3FNType
- : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat6E2M3FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float6E2M3FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
- return PyFloat6E2M3FNType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float6_e2m3fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float6E3M2FNType.
-class PyFloat6E3M2FNType
- : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat6E3M2FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float6E3M2FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
- return PyFloat6E3M2FNType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float6_e3m2fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3FNType.
-class PyFloat8E4M3FNType
- : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
- return PyFloat8E4M3FNType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e4m3fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E5M2Type.
-class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E5M2TypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E5M2Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E5M2TypeGet(context->get());
- return PyFloat8E5M2Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e5m2 type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3Type.
-class PyFloat8E4M3Type : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3TypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3TypeGet(context->get());
- return PyFloat8E4M3Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e4m3 type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3FNUZ.
-class PyFloat8E4M3FNUZType
- : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3FNUZTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3FNUZType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
- return PyFloat8E4M3FNUZType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e4m3fnuz type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3B11FNUZ.
-class PyFloat8E4M3B11FNUZType
- : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3B11FNUZTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
- return PyFloat8E4M3B11FNUZType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e4m3b11fnuz type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E5M2FNUZ.
-class PyFloat8E5M2FNUZType
- : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E5M2FNUZTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E5M2FNUZType";
- using PyConcreteType::PyConcreteType;
+/// Binds the integer type's factories (one per signedness) and its
+/// read-only width/signedness properties.
+void PyIntegerType::bindDerived(ClassTy &c) {
+  // Static constructors: signless, signed, unsigned.
+  c.def_static(
+      "get_signless",
+      [](unsigned width, DefaultingPyMlirContext ctx) {
+        MlirType intTy = mlirIntegerTypeGet(ctx->get(), width);
+        return PyIntegerType(ctx->getRef(), intTy);
+      },
+      nb::arg("width"), nb::arg("context") = nb::none(),
+      "Create a signless integer type");
+  c.def_static(
+      "get_signed",
+      [](unsigned width, DefaultingPyMlirContext ctx) {
+        MlirType intTy = mlirIntegerTypeSignedGet(ctx->get(), width);
+        return PyIntegerType(ctx->getRef(), intTy);
+      },
+      nb::arg("width"), nb::arg("context") = nb::none(),
+      "Create a signed integer type");
+  c.def_static(
+      "get_unsigned",
+      [](unsigned width, DefaultingPyMlirContext ctx) {
+        MlirType intTy = mlirIntegerTypeUnsignedGet(ctx->get(), width);
+        return PyIntegerType(ctx->getRef(), intTy);
+      },
+      nb::arg("width"), nb::arg("context") = nb::none(),
+      "Create an unsigned integer type");
+  // Read-only queries on an existing integer type.
+  c.def_prop_ro(
+      "width",
+      [](PyIntegerType &t) { return mlirIntegerTypeGetWidth(t); },
+      "Returns the width of the integer type");
+  c.def_prop_ro(
+      "is_signless",
+      [](PyIntegerType &t) -> bool { return mlirIntegerTypeIsSignless(t); },
+      "Returns whether this is a signless integer");
+  c.def_prop_ro(
+      "is_signed",
+      [](PyIntegerType &t) -> bool { return mlirIntegerTypeIsSigned(t); },
+      "Returns whether this is a signed integer");
+  c.def_prop_ro(
+      "is_unsigned",
+      [](PyIntegerType &t) -> bool { return mlirIntegerTypeIsUnsigned(t); },
+      "Returns whether this is an unsigned integer");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
- return PyFloat8E5M2FNUZType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e5m2fnuz type.");
- }
-};
+/// Binds the static `get(context=None)` factory for the index type.
+void PyIndexType::bindDerived(ClassTy &c) {
+  auto getter = [](DefaultingPyMlirContext ctx) {
+    MlirType ty = mlirIndexTypeGet(ctx->get());
+    return PyIndexType(ctx->getRef(), ty);
+  };
+  c.def_static("get", getter, nb::arg("context") = nb::none(),
+               "Create an index type.");
+}
-/// Floating Point Type subclass - Float8E3M4Type.
-class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E3M4TypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E3M4Type";
- using PyConcreteType::PyConcreteType;
+/// Binds the read-only `width` property via mlirFloatTypeGetWidth.
+void PyFloatType::bindDerived(ClassTy &c) {
+  auto w = [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); };
+  c.def_prop_ro("width", w, "Returns the width of the floating-point type");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E3M4TypeGet(context->get());
- return PyFloat8E3M4Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e3m4 type.");
- }
-};
+/// Binds the static `get(context=None)` factory for float4_e2m1fn.
+void PyFloat4E2M1FNType::bindDerived(ClassTy &c) {
+  auto getter = [](DefaultingPyMlirContext ctx) {
+    MlirType ty = mlirFloat4E2M1FNTypeGet(ctx->get());
+    return PyFloat4E2M1FNType(ctx->getRef(), ty);
+  };
+  c.def_static("get", getter, nb::arg("context") = nb::none(),
+               "Create a float4_e2m1fn type.");
+}
-/// Floating Point Type subclass - Float8E8M0FNUType.
-class PyFloat8E8M0FNUType
- : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E8M0FNUTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E8M0FNUType";
- using PyConcreteType::PyConcreteType;
+/// Binds the static `get(context=None)` factory for float6_e2m3fn.
+void PyFloat6E2M3FNType::bindDerived(ClassTy &c) {
+  auto getter = [](DefaultingPyMlirContext ctx) {
+    MlirType ty = mlirFloat6E2M3FNTypeGet(ctx->get());
+    return PyFloat6E2M3FNType(ctx->getRef(), ty);
+  };
+  c.def_static("get", getter, nb::arg("context") = nb::none(),
+               "Create a float6_e2m3fn type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
- return PyFloat8E8M0FNUType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e8m0fnu type.");
- }
-};
+/// Binds the static `get(context=None)` factory for float6_e3m2fn.
+void PyFloat6E3M2FNType::bindDerived(ClassTy &c) {
+  auto getter = [](DefaultingPyMlirContext ctx) {
+    MlirType ty = mlirFloat6E3M2FNTypeGet(ctx->get());
+    return PyFloat6E3M2FNType(ctx->getRef(), ty);
+  };
+  c.def_static("get", getter, nb::arg("context") = nb::none(),
+               "Create a float6_e3m2fn type.");
+}
-/// Floating Point Type subclass - BF16Type.
-class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirBFloat16TypeGetTypeID;
- static constexpr const char *pyClassName = "BF16Type";
- using PyConcreteType::PyConcreteType;
+/// Binds the static `get(context=None)` factory for float8_e4m3fn.
+void PyFloat8E4M3FNType::bindDerived(ClassTy &c) {
+  auto getter = [](DefaultingPyMlirContext ctx) {
+    MlirType ty = mlirFloat8E4M3FNTypeGet(ctx->get());
+    return PyFloat8E4M3FNType(ctx->getRef(), ty);
+  };
+  c.def_static("get", getter, nb::arg("context") = nb::none(),
+               "Create a float8_e4m3fn type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirBF16TypeGet(context->get());
- return PyBF16Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a bf16 type.");
- }
-};
+/// Binds the static `get(context=None)` factory for float8_e5m2.
+void PyFloat8E5M2Type::bindDerived(ClassTy &c) {
+  auto getter = [](DefaultingPyMlirContext ctx) {
+    MlirType ty = mlirFloat8E5M2TypeGet(ctx->get());
+    return PyFloat8E5M2Type(ctx->getRef(), ty);
+  };
+  c.def_static("get", getter, nb::arg("context") = nb::none(),
+               "Create a float8_e5m2 type.");
+}
-/// Floating Point Type subclass - F16Type.
-class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat16TypeGetTypeID;
- static constexpr const char *pyClassName = "F16Type";
- using PyConcreteType::PyConcreteType;
+/// Binds the static `get(context=None)` factory for float8_e4m3.
+void PyFloat8E4M3Type::bindDerived(ClassTy &c) {
+  auto getter = [](DefaultingPyMlirContext ctx) {
+    MlirType ty = mlirFloat8E4M3TypeGet(ctx->get());
+    return PyFloat8E4M3Type(ctx->getRef(), ty);
+  };
+  c.def_static("get", getter, nb::arg("context") = nb::none(),
+               "Create a float8_e4m3 type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirF16TypeGet(context->get());
- return PyF16Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a f16 type.");
- }
-};
+/// Binds the static `get(context=None)` factory for float8_e4m3fnuz.
+void PyFloat8E4M3FNUZType::bindDerived(ClassTy &c) {
+  auto getter = [](DefaultingPyMlirContext ctx) {
+    MlirType ty = mlirFloat8E4M3FNUZTypeGet(ctx->get());
+    return PyFloat8E4M3FNUZType(ctx->getRef(), ty);
+  };
+  c.def_static("get", getter, nb::arg("context") = nb::none(),
+               "Create a float8_e4m3fnuz type.");
+}
-/// Floating Point Type subclass - TF32Type.
-class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloatTF32TypeGetTypeID;
- static constexpr const char *pyClassName = "FloatTF32Type";
- using PyConcreteType::PyConcreteType;
+/// Binds the static `get(context=None)` factory for float8_e4m3b11fnuz.
+void PyFloat8E4M3B11FNUZType::bindDerived(ClassTy &c) {
+  auto getter = [](DefaultingPyMlirContext ctx) {
+    MlirType ty = mlirFloat8E4M3B11FNUZTypeGet(ctx->get());
+    return PyFloat8E4M3B11FNUZType(ctx->getRef(), ty);
+  };
+  c.def_static("get", getter, nb::arg("context") = nb::none(),
+               "Create a float8_e4m3b11fnuz type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirTF32TypeGet(context->get());
- return PyTF32Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a tf32 type.");
- }
-};
+/// Binds the static `get(context=None)` factory for float8_e5m2fnuz.
+void PyFloat8E5M2FNUZType::bindDerived(ClassTy &c) {
+  auto getter = [](DefaultingPyMlirContext ctx) {
+    MlirType ty = mlirFloat8E5M2FNUZTypeGet(ctx->get());
+    return PyFloat8E5M2FNUZType(ctx->getRef(), ty);
+  };
+  c.def_static("get", getter, nb::arg("context") = nb::none(),
+               "Create a float8_e5m2fnuz type.");
+}
-/// Floating Point Type subclass - F32Type.
-class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat32TypeGetTypeID;
- static constexpr const char *pyClassName = "F32Type";
- using PyConcreteType::PyConcreteType;
+/// Binds the static `get(context=None)` factory for float8_e3m4.
+void PyFloat8E3M4Type::bindDerived(ClassTy &c) {
+  auto getter = [](DefaultingPyMlirContext ctx) {
+    MlirType ty = mlirFloat8E3M4TypeGet(ctx->get());
+    return PyFloat8E3M4Type(ctx->getRef(), ty);
+  };
+  c.def_static("get", getter, nb::arg("context") = nb::none(),
+               "Create a float8_e3m4 type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirF32TypeGet(context->get());
- return PyF32Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a f32 type.");
- }
-};
+/// Binds the static `get(context=None)` factory for float8_e8m0fnu.
+void PyFloat8E8M0FNUType::bindDerived(ClassTy &c) {
+  auto getter = [](DefaultingPyMlirContext ctx) {
+    MlirType ty = mlirFloat8E8M0FNUTypeGet(ctx->get());
+    return PyFloat8E8M0FNUType(ctx->getRef(), ty);
+  };
+  c.def_static("get", getter, nb::arg("context") = nb::none(),
+               "Create a float8_e8m0fnu type.");
+}
-/// Floating Point Type subclass - F64Type.
-class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat64TypeGetTypeID;
- static constexpr const char *pyClassName = "F64Type";
- using PyConcreteType::PyConcreteType;
+/// Binds the static `get(context=None)` factory for bf16.
+void PyBF16Type::bindDerived(ClassTy &c) {
+  auto getter = [](DefaultingPyMlirContext ctx) {
+    MlirType ty = mlirBF16TypeGet(ctx->get());
+    return PyBF16Type(ctx->getRef(), ty);
+  };
+  c.def_static("get", getter, nb::arg("context") = nb::none(),
+               "Create a bf16 type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirF64TypeGet(context->get());
- return PyF64Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a f64 type.");
- }
-};
+/// Binds the static `get(context=None)` factory for f16.
+void PyF16Type::bindDerived(ClassTy &c) {
+  auto getter = [](DefaultingPyMlirContext ctx) {
+    MlirType ty = mlirF16TypeGet(ctx->get());
+    return PyF16Type(ctx->getRef(), ty);
+  };
+  c.def_static("get", getter, nb::arg("context") = nb::none(),
+               "Create a f16 type.");
+}
-/// None Type subclass - NoneType.
-class PyNoneType : public PyConcreteType<PyNoneType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirNoneTypeGetTypeID;
- static constexpr const char *pyClassName = "NoneType";
- using PyConcreteType::PyConcreteType;
+/// Binds the static `get(context=None)` factory for tf32.
+void PyTF32Type::bindDerived(ClassTy &c) {
+  auto getter = [](DefaultingPyMlirContext ctx) {
+    MlirType ty = mlirTF32TypeGet(ctx->get());
+    return PyTF32Type(ctx->getRef(), ty);
+  };
+  c.def_static("get", getter, nb::arg("context") = nb::none(),
+               "Create a tf32 type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirNoneTypeGet(context->get());
- return PyNoneType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a none type.");
- }
-};
+/// Binds the static `get(context=None)` factory for f32.
+void PyF32Type::bindDerived(ClassTy &c) {
+  auto getter = [](DefaultingPyMlirContext ctx) {
+    MlirType ty = mlirF32TypeGet(ctx->get());
+    return PyF32Type(ctx->getRef(), ty);
+  };
+  c.def_static("get", getter, nb::arg("context") = nb::none(),
+               "Create a f32 type.");
+}
-/// Complex Type subclass - ComplexType.
-class PyComplexType : public PyConcreteType<PyComplexType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirComplexTypeGetTypeID;
- static constexpr const char *pyClassName = "ComplexType";
- using PyConcreteType::PyConcreteType;
+/// Binds the static `get(context=None)` factory for f64.
+void PyF64Type::bindDerived(ClassTy &c) {
+  auto getter = [](DefaultingPyMlirContext ctx) {
+    MlirType ty = mlirF64TypeGet(ctx->get());
+    return PyF64Type(ctx->getRef(), ty);
+  };
+  c.def_static("get", getter, nb::arg("context") = nb::none(),
+               "Create a f64 type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyType &elementType) {
- // The element must be a floating point or integer scalar type.
- if (mlirTypeIsAIntegerOrFloat(elementType)) {
- MlirType t = mlirComplexTypeGet(elementType);
- return PyComplexType(elementType.getContext(), t);
- }
- throw nb::value_error(
- (Twine("invalid '") +
- nb::cast<std::string>(nb::repr(nb::cast(elementType))) +
- "' and expected floating point or integer type.")
- .str()
- .c_str());
- },
- "Create a complex type");
- c.def_prop_ro(
- "element_type",
- [](PyComplexType &self) -> nb::typed<nb::object, PyType> {
- return PyType(self.getContext(), mlirComplexTypeGetElementType(self))
- .maybeDownCast();
- },
- "Returns element type.");
- }
-};
+/// Binds the static `get(context=None)` factory for the none type.
+void PyNoneType::bindDerived(ClassTy &c) {
+  auto getter = [](DefaultingPyMlirContext ctx) {
+    MlirType ty = mlirNoneTypeGet(ctx->get());
+    return PyNoneType(ctx->getRef(), ty);
+  };
+  c.def_static("get", getter, nb::arg("context") = nb::none(),
+               "Create a none type.");
+}
-} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
-} // namespace python
-} // namespace mlir
+/// Binds `get(element_type)` plus the read-only `element_type` property.
+void PyComplexType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](PyType &elementType) {
+        // Guard: only integer or floating-point scalars are accepted.
+        if (!mlirTypeIsAIntegerOrFloat(elementType))
+          throw nb::value_error(
+              (Twine("invalid '") +
+               nb::cast<std::string>(nb::repr(nb::cast(elementType))) +
+               "' and expected floating point or integer type.")
+                  .str()
+                  .c_str());
+        MlirType complexTy = mlirComplexTypeGet(elementType);
+        return PyComplexType(elementType.getContext(), complexTy);
+      },
+      "Create a complex type");
+  c.def_prop_ro(
+      "element_type",
+      [](PyComplexType &self) -> nb::typed<nb::object, PyType> {
+        MlirType elemTy = mlirComplexTypeGetElementType(self);
+        return PyType(self.getContext(), elemTy).maybeDownCast();
+      },
+      "Returns element type.");
+}
// Shaped Type Interface - ShapedType
void PyShapedType::bindDerived(ClassTy &c) {
@@ -629,522 +404,423 @@ void PyShapedType::requireHasRank() {
const PyShapedType::IsAFunctionTy PyShapedType::isaFunction = mlirTypeIsAShaped;
-namespace mlir {
-namespace python {
-namespace MLIR_BINDINGS_PYTHON_DOMAIN {
-
-/// Vector Type subclass - VectorType.
-class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirVectorTypeGetTypeID;
- static constexpr const char *pyClassName = "VectorType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static("get", &PyVectorType::getChecked, nb::arg("shape"),
- nb::arg("element_type"), nb::kw_only(),
- nb::arg("scalable") = nb::none(),
- nb::arg("scalable_dims") = nb::none(),
- nb::arg("loc") = nb::none(), "Create a vector type")
- .def_static("get_unchecked", &PyVectorType::get, nb::arg("shape"),
- nb::arg("element_type"), nb::kw_only(),
- nb::arg("scalable") = nb::none(),
- nb::arg("scalable_dims") = nb::none(),
- nb::arg("context") = nb::none(), "Create a vector type")
- .def_prop_ro(
- "scalable",
- [](MlirType self) { return mlirVectorTypeIsScalable(self); })
- .def_prop_ro("scalable_dims", [](MlirType self) {
- std::vector<bool> scalableDims;
- size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
- scalableDims.reserve(rank);
- for (size_t i = 0; i < rank; ++i)
- scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
- return scalableDims;
- });
- }
-
-private:
- static PyVectorType
- getChecked(std::vector<int64_t> shape, PyType &elementType,
- std::optional<nb::list> scalable,
- std::optional<std::vector<int64_t>> scalableDims,
- DefaultingPyLocation loc) {
- if (scalable && scalableDims) {
- throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
- "are mutually exclusive.");
- }
-
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirType type;
- if (scalable) {
- if (scalable->size() != shape.size())
- throw nb::value_error("Expected len(scalable) == len(shape).");
-
- SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
- *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
- type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
- scalableDimFlags.data(),
- elementType);
- } else if (scalableDims) {
- SmallVector<bool> scalableDimFlags(shape.size(), false);
- for (int64_t dim : *scalableDims) {
- if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
- throw nb::value_error("Scalable dimension index out of bounds.");
- scalableDimFlags[dim] = true;
- }
- type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
- scalableDimFlags.data(),
- elementType);
- } else {
- type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
- elementType);
- }
- if (mlirTypeIsNull(type))
- throw MLIRError("Invalid type", errors.take());
- return PyVectorType(elementType.getContext(), type);
- }
+void PyVectorType::bindDerived(ClassTy &c) {
+  c.def_static("get", &PyVectorType::getChecked, nb::arg("shape"),
+               nb::arg("element_type"), nb::kw_only(),
+               nb::arg("scalable") = nb::none(),
+               nb::arg("scalable_dims") = nb::none(),
+               nb::arg("loc") = nb::none(), "Create a vector type");
+  c.def_static("get_unchecked", &PyVectorType::get, nb::arg("shape"),
+               nb::arg("element_type"), nb::kw_only(),
+               nb::arg("scalable") = nb::none(),
+               nb::arg("scalable_dims") = nb::none(),
+               nb::arg("context") = nb::none(), "Create a vector type");
+  c.def_prop_ro("scalable",
+                [](MlirType self) { return mlirVectorTypeIsScalable(self); });
+  c.def_prop_ro("scalable_dims", [](MlirType self) {
+    // One flag per dimension, true where that dimension is scalable.
+    size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
+    std::vector<bool> scalableDims(rank);
+    for (size_t dim = 0; dim < rank; ++dim)
+      scalableDims[dim] = mlirVectorTypeIsDimScalable(self, dim);
+    return scalableDims;
+  });
+}
- static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
- std::optional<nb::list> scalable,
- std::optional<std::vector<int64_t>> scalableDims,
- DefaultingPyMlirContext context) {
- if (scalable && scalableDims) {
- throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
- "are mutually exclusive.");
+PyVectorType
+PyVectorType::getChecked(std::vector<int64_t> shape, PyType &elementType,
+                         std::optional<nb::list> scalable,
+                         std::optional<std::vector<int64_t>> scalableDims,
+                         DefaultingPyLocation loc) {
+  if (scalable && scalableDims) {
+    throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
+                          "are mutually exclusive.");
+  }
+  // Capture diagnostics so a null result can be reported via MLIRError.
+  PyMlirContext::ErrorCapture errors(loc->getContext());
+  MlirType type; // Assigned by exactly one of the three branches below.
+  if (scalable) {
+    if (scalable->size() != shape.size())
+      throw nb::value_error("Expected len(scalable) == len(shape).");
+    // Each list entry is cast to bool, giving one flag per dimension.
+    SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
+        *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
+    type = mlirVectorTypeGetScalableChecked(
+        loc, shape.size(), shape.data(), scalableDimFlags.data(), elementType);
+  } else if (scalableDims) { // Indices of the dimensions to mark scalable.
+    SmallVector<bool> scalableDimFlags(shape.size(), false);
+    for (int64_t dim : *scalableDims) {
+      if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
+        throw nb::value_error("Scalable dimension index out of bounds.");
+      scalableDimFlags[dim] = true;
     }
+    type = mlirVectorTypeGetScalableChecked(
+        loc, shape.size(), shape.data(), scalableDimFlags.data(), elementType);
+  } else {
+    type =
+        mlirVectorTypeGetChecked(loc, shape.size(), shape.data(), elementType);
+  }
+  if (mlirTypeIsNull(type))
+    throw MLIRError("Invalid type", errors.take());
+  return PyVectorType(elementType.getContext(), type);
+}
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirType type;
- if (scalable) {
- if (scalable->size() != shape.size())
- throw nb::value_error("Expected len(scalable) == len(shape).");
-
- SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
- *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
- type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
- scalableDimFlags.data(), elementType);
- } else if (scalableDims) {
- SmallVector<bool> scalableDimFlags(shape.size(), false);
- for (int64_t dim : *scalableDims) {
- if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
- throw nb::value_error("Scalable dimension index out of bounds.");
- scalableDimFlags[dim] = true;
- }
- type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
- scalableDimFlags.data(), elementType);
- } else {
- type = mlirVectorTypeGet(shape.size(), shape.data(), elementType);
+/// Builds a vector type from `shape` and `elementType` through the
+/// context-based C API (mlirVectorTypeGet / mlirVectorTypeGetScalable).
+/// Scalability may be given either as `scalable` (one bool per dimension;
+/// its length must equal len(shape)) or as `scalableDims` (indices of the
+/// dimensions to mark scalable), but not both.
+/// Throws nb::value_error for conflicting or malformed scalability options,
+/// and MLIRError (carrying the diagnostics captured on `context`) when the
+/// C API returns a null type.
+PyVectorType PyVectorType::get(std::vector<int64_t> shape, PyType &elementType,
+ std::optional<nb::list> scalable,
+ std::optional<std::vector<int64_t>> scalableDims,
+ DefaultingPyMlirContext context) {
+ if (scalable && scalableDims) {
+ throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
+ "are mutually exclusive.");
+ }
+
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirType type;
+ // Per-dimension flags: each element of `scalable` is converted with
+ // nb::cast<bool>.
+ if (scalable) {
+ if (scalable->size() != shape.size())
+ throw nb::value_error("Expected len(scalable) == len(shape).");
+
+ SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
+ *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
+ type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
+ scalableDimFlags.data(), elementType);
+ } else if (scalableDims) {
+ SmallVector<bool> scalableDimFlags(shape.size(), false);
+ for (int64_t dim : *scalableDims) {
+ // Reject indices outside [0, len(shape)).
+ if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
+ throw nb::value_error("Scalable dimension index out of bounds.");
+ scalableDimFlags[dim] = true;
}
- if (mlirTypeIsNull(type))
- throw MLIRError("Invalid type", errors.take());
- return PyVectorType(elementType.getContext(), type);
- }
-};
-
-/// Ranked Tensor Type subclass - RankedTensorType.
-class PyRankedTensorType
- : public PyConcreteType<PyRankedTensorType, PyShapedType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirRankedTensorTypeGetTypeID;
- static constexpr const char *pyClassName = "RankedTensorType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](std::vector<int64_t> shape, PyType &elementType,
- std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirType t = mlirRankedTensorTypeGetChecked(
- loc, shape.size(), shape.data(), elementType,
- encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyRankedTensorType(elementType.getContext(), t);
- },
- nb::arg("shape"), nb::arg("element_type"),
- nb::arg("encoding") = nb::none(), nb::arg("loc") = nb::none(),
- "Create a ranked tensor type");
- c.def_static(
- "get_unchecked",
- [](std::vector<int64_t> shape, PyType &elementType,
- std::optional<PyAttribute> &encodingAttr,
- DefaultingPyMlirContext context) {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirType t = mlirRankedTensorTypeGet(
- shape.size(), shape.data(), elementType,
- encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyRankedTensorType(elementType.getContext(), t);
- },
- nb::arg("shape"), nb::arg("element_type"),
- nb::arg("encoding") = nb::none(), nb::arg("context") = nb::none(),
- "Create a ranked tensor type");
- c.def_prop_ro(
- "encoding",
- [](PyRankedTensorType &self)
- -> std::optional<nb::typed<nb::object, PyAttribute>> {
- MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
- if (mlirAttributeIsNull(encoding))
- return std::nullopt;
- return PyAttribute(self.getContext(), encoding).maybeDownCast();
- });
- }
-};
-
-/// Unranked Tensor Type subclass - UnrankedTensorType.
-class PyUnrankedTensorType
- : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirUnrankedTensorTypeGetTypeID;
- static constexpr const char *pyClassName = "UnrankedTensorType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyType &elementType, DefaultingPyLocation loc) {
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyUnrankedTensorType(elementType.getContext(), t);
- },
- nb::arg("element_type"), nb::arg("loc") = nb::none(),
- "Create a unranked tensor type");
- c.def_static(
- "get_unchecked",
- [](PyType &elementType, DefaultingPyMlirContext context) {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirType t = mlirUnrankedTensorTypeGet(elementType);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyUnrankedTensorType(elementType.getContext(), t);
- },
- nb::arg("element_type"), nb::arg("context") = nb::none(),
- "Create a unranked tensor type");
- }
-};
-
-/// Ranked MemRef Type subclass - MemRefType.
-class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirMemRefTypeGetTypeID;
- static constexpr const char *pyClassName = "MemRefType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](std::vector<int64_t> shape, PyType &elementType,
- PyAttribute *layout, PyAttribute *memorySpace,
- DefaultingPyLocation loc) {
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull();
- MlirAttribute memSpaceAttr =
- memorySpace ? *memorySpace : mlirAttributeGetNull();
- MlirType t =
- mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
- shape.data(), layoutAttr, memSpaceAttr);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyMemRefType(elementType.getContext(), t);
- },
- nb::arg("shape"), nb::arg("element_type"),
- nb::arg("layout") = nb::none(), nb::arg("memory_space") = nb::none(),
- nb::arg("loc") = nb::none(), "Create a memref type")
- .def_static(
- "get_unchecked",
- [](std::vector<int64_t> shape, PyType &elementType,
- PyAttribute *layout, PyAttribute *memorySpace,
- DefaultingPyMlirContext context) {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirAttribute layoutAttr =
- layout ? *layout : mlirAttributeGetNull();
- MlirAttribute memSpaceAttr =
- memorySpace ? *memorySpace : mlirAttributeGetNull();
- MlirType t =
- mlirMemRefTypeGet(elementType, shape.size(), shape.data(),
- layoutAttr, memSpaceAttr);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyMemRefType(elementType.getContext(), t);
- },
- nb::arg("shape"), nb::arg("element_type"),
- nb::arg("layout") = nb::none(),
- nb::arg("memory_space") = nb::none(),
- nb::arg("context") = nb::none(), "Create a memref type")
- .def_prop_ro(
- "layout",
- [](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
- return PyAttribute(self.getContext(),
- mlirMemRefTypeGetLayout(self))
- .maybeDownCast();
- },
- "The layout of the MemRef type.")
- .def(
- "get_strides_and_offset",
- [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
- std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
- int64_t offset;
- if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset(
- self, strides.data(), &offset)))
- throw std::runtime_error(
- "Failed to extract strides and offset from memref.");
- return {strides, offset};
- },
- "The strides and offset of the MemRef type.")
- .def_prop_ro(
- "affine_map",
- [](PyMemRefType &self) -> PyAffineMap {
- MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
- return PyAffineMap(self.getContext(), map);
- },
- "The layout of the MemRef type as an affine map.")
- .def_prop_ro(
- "memory_space",
- [](PyMemRefType &self)
- -> std::optional<nb::typed<nb::object, PyAttribute>> {
- MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
- if (mlirAttributeIsNull(a))
- return std::nullopt;
- return PyAttribute(self.getContext(), a).maybeDownCast();
- },
- "Returns the memory space of the given MemRef type.");
- }
-};
-
-/// Unranked MemRef Type subclass - UnrankedMemRefType.
-class PyUnrankedMemRefType
- : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirUnrankedMemRefTypeGetTypeID;
- static constexpr const char *pyClassName = "UnrankedMemRefType";
- using PyConcreteType::PyConcreteType;
+ type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
+ scalableDimFlags.data(), elementType);
+ } else {
+ type = mlirVectorTypeGet(shape.size(), shape.data(), elementType);
+ }
+ if (mlirTypeIsNull(type))
+ throw MLIRError("Invalid type", errors.take());
+ return PyVectorType(elementType.getContext(), type);
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyType &elementType, PyAttribute *memorySpace,
- DefaultingPyLocation loc) {
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirAttribute memSpaceAttr = {};
- if (memorySpace)
- memSpaceAttr = *memorySpace;
+void PyRankedTensorType::bindDerived(ClassTy &c) {
+  // `get`: location-based construction through
+  // mlirRankedTensorTypeGetChecked. A missing encoding is passed as a null
+  // attribute. A null result is re-raised as MLIRError carrying the captured
+  // diagnostics.
+  auto createChecked = [](std::vector<int64_t> shape, PyType &elementType,
+                          std::optional<PyAttribute> &encodingAttr,
+                          DefaultingPyLocation loc) {
+    PyMlirContext::ErrorCapture errors(loc->getContext());
+    MlirAttribute encoding = mlirAttributeGetNull();
+    if (encodingAttr)
+      encoding = encodingAttr->get();
+    MlirType tensorType = mlirRankedTensorTypeGetChecked(
+        loc, shape.size(), shape.data(), elementType, encoding);
+    if (mlirTypeIsNull(tensorType))
+      throw MLIRError("Invalid type", errors.take());
+    return PyRankedTensorType(elementType.getContext(), tensorType);
+  };
+
+  // `get_unchecked`: context-based construction through
+  // mlirRankedTensorTypeGet, with the same null-result handling.
+  auto createUnchecked = [](std::vector<int64_t> shape, PyType &elementType,
+                            std::optional<PyAttribute> &encodingAttr,
+                            DefaultingPyMlirContext context) {
+    PyMlirContext::ErrorCapture errors(context->getRef());
+    MlirAttribute encoding = mlirAttributeGetNull();
+    if (encodingAttr)
+      encoding = encodingAttr->get();
+    MlirType tensorType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
+                                                  elementType, encoding);
+    if (mlirTypeIsNull(tensorType))
+      throw MLIRError("Invalid type", errors.take());
+    return PyRankedTensorType(elementType.getContext(), tensorType);
+  };
+
+  // `encoding`: the down-cast encoding attribute, or None when the type has
+  // no encoding.
+  auto encodingOf = [](PyRankedTensorType &self)
+      -> std::optional<nb::typed<nb::object, PyAttribute>> {
+    MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
+    if (mlirAttributeIsNull(encoding))
+      return std::nullopt;
+    return PyAttribute(self.getContext(), encoding).maybeDownCast();
+  };
+
+  c.def_static("get", createChecked, nb::arg("shape"), nb::arg("element_type"),
+               nb::arg("encoding") = nb::none(), nb::arg("loc") = nb::none(),
+               "Create a ranked tensor type")
+      .def_static("get_unchecked", createUnchecked, nb::arg("shape"),
+                  nb::arg("element_type"), nb::arg("encoding") = nb::none(),
+                  nb::arg("context") = nb::none(),
+                  "Create a ranked tensor type")
+      .def_prop_ro("encoding", encodingOf);
+}
- MlirType t =
- mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyUnrankedMemRefType(elementType.getContext(), t);
- },
- nb::arg("element_type"), nb::arg("memory_space").none(),
- nb::arg("loc") = nb::none(), "Create a unranked memref type")
- .def_static(
- "get_unchecked",
- [](PyType &elementType, PyAttribute *memorySpace,
- DefaultingPyMlirContext context) {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirAttribute memSpaceAttr = {};
- if (memorySpace)
- memSpaceAttr = *memorySpace;
+void PyUnrankedTensorType::bindDerived(ClassTy &c) {
+  // `get`: builds the type with mlirUnrankedTensorTypeGetChecked at `loc`.
+  // A null result is reported as MLIRError with the captured diagnostics.
+  auto createChecked = [](PyType &elementType, DefaultingPyLocation loc) {
+    PyMlirContext::ErrorCapture errors(loc->getContext());
+    MlirType unrankedType = mlirUnrankedTensorTypeGetChecked(loc, elementType);
+    if (!mlirTypeIsNull(unrankedType))
+      return PyUnrankedTensorType(elementType.getContext(), unrankedType);
+    throw MLIRError("Invalid type", errors.take());
+  };
+
+  // `get_unchecked`: builds the type with mlirUnrankedTensorTypeGet, using
+  // diagnostics captured on the given (or default) context.
+  auto createUnchecked = [](PyType &elementType,
+                            DefaultingPyMlirContext context) {
+    PyMlirContext::ErrorCapture errors(context->getRef());
+    MlirType unrankedType = mlirUnrankedTensorTypeGet(elementType);
+    if (!mlirTypeIsNull(unrankedType))
+      return PyUnrankedTensorType(elementType.getContext(), unrankedType);
+    throw MLIRError("Invalid type", errors.take());
+  };
+
+  c.def_static("get", createChecked, nb::arg("element_type"),
+               nb::arg("loc") = nb::none(), "Create a unranked tensor type");
+  c.def_static("get_unchecked", createUnchecked, nb::arg("element_type"),
+               nb::arg("context") = nb::none(),
+               "Create a unranked tensor type");
+}
- MlirType t = mlirUnrankedMemRefTypeGet(elementType, memSpaceAttr);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyUnrankedMemRefType(elementType.getContext(), t);
- },
- nb::arg("element_type"), nb::arg("memory_space").none(),
- nb::arg("context") = nb::none(), "Create a unranked memref type")
- .def_prop_ro(
- "memory_space",
- [](PyUnrankedMemRefType &self)
- -> std::optional<nb::typed<nb::object, PyAttribute>> {
- MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
- if (mlirAttributeIsNull(a))
- return std::nullopt;
- return PyAttribute(self.getContext(), a).maybeDownCast();
- },
- "Returns the memory space of the given Unranked MemRef type.");
- }
-};
+void PyMemRefType::bindDerived(ClassTy &c) {
+  // `get`: location-based construction via mlirMemRefTypeGetChecked.
+  // `layout` / `memory_space` left as None are forwarded as null attributes.
+  auto createChecked = [](std::vector<int64_t> shape, PyType &elementType,
+                          PyAttribute *layout, PyAttribute *memorySpace,
+                          DefaultingPyLocation loc) {
+    PyMlirContext::ErrorCapture errors(loc->getContext());
+    MlirAttribute layoutAttr = mlirAttributeGetNull();
+    if (layout)
+      layoutAttr = *layout;
+    MlirAttribute memSpaceAttr = mlirAttributeGetNull();
+    if (memorySpace)
+      memSpaceAttr = *memorySpace;
+    MlirType memrefType = mlirMemRefTypeGetChecked(
+        loc, elementType, shape.size(), shape.data(), layoutAttr, memSpaceAttr);
+    if (mlirTypeIsNull(memrefType))
+      throw MLIRError("Invalid type", errors.take());
+    return PyMemRefType(elementType.getContext(), memrefType);
+  };
+
+  // `get_unchecked`: the same construction through the context-based
+  // mlirMemRefTypeGet.
+  auto createUnchecked = [](std::vector<int64_t> shape, PyType &elementType,
+                            PyAttribute *layout, PyAttribute *memorySpace,
+                            DefaultingPyMlirContext context) {
+    PyMlirContext::ErrorCapture errors(context->getRef());
+    MlirAttribute layoutAttr = mlirAttributeGetNull();
+    if (layout)
+      layoutAttr = *layout;
+    MlirAttribute memSpaceAttr = mlirAttributeGetNull();
+    if (memorySpace)
+      memSpaceAttr = *memorySpace;
+    MlirType memrefType = mlirMemRefTypeGet(elementType, shape.size(),
+                                            shape.data(), layoutAttr,
+                                            memSpaceAttr);
+    if (mlirTypeIsNull(memrefType))
+      throw MLIRError("Invalid type", errors.take());
+    return PyMemRefType(elementType.getContext(), memrefType);
+  };
+
+  // Accessors.
+  auto layoutOf = [](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
+    MlirAttribute layoutAttr = mlirMemRefTypeGetLayout(self);
+    return PyAttribute(self.getContext(), layoutAttr).maybeDownCast();
+  };
+  // One stride per dimension (sized by the shaped-type rank) plus the offset.
+  // Throws std::runtime_error if the C API reports failure.
+  auto stridesAndOffsetOf =
+      [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
+    std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
+    int64_t offset;
+    MlirLogicalResult status =
+        mlirMemRefTypeGetStridesAndOffset(self, strides.data(), &offset);
+    if (mlirLogicalResultIsFailure(status))
+      throw std::runtime_error(
+          "Failed to extract strides and offset from memref.");
+    return {strides, offset};
+  };
+  auto affineMapOf = [](PyMemRefType &self) -> PyAffineMap {
+    MlirAffineMap layoutMap = mlirMemRefTypeGetAffineMap(self);
+    return PyAffineMap(self.getContext(), layoutMap);
+  };
+  // Memory space attribute, or None when the C API returns a null attribute.
+  auto memorySpaceOf = [](PyMemRefType &self)
+      -> std::optional<nb::typed<nb::object, PyAttribute>> {
+    MlirAttribute space = mlirMemRefTypeGetMemorySpace(self);
+    if (mlirAttributeIsNull(space))
+      return std::nullopt;
+    return PyAttribute(self.getContext(), space).maybeDownCast();
+  };
+
+  c.def_static("get", createChecked, nb::arg("shape"), nb::arg("element_type"),
+               nb::arg("layout") = nb::none(),
+               nb::arg("memory_space") = nb::none(),
+               nb::arg("loc") = nb::none(), "Create a memref type");
+  c.def_static("get_unchecked", createUnchecked, nb::arg("shape"),
+               nb::arg("element_type"), nb::arg("layout") = nb::none(),
+               nb::arg("memory_space") = nb::none(),
+               nb::arg("context") = nb::none(), "Create a memref type");
+  c.def_prop_ro("layout", layoutOf, "The layout of the MemRef type.");
+  c.def("get_strides_and_offset", stridesAndOffsetOf,
+        "The strides and offset of the MemRef type.");
+  c.def_prop_ro("affine_map", affineMapOf,
+                "The layout of the MemRef type as an affine map.");
+  c.def_prop_ro("memory_space", memorySpaceOf,
+                "Returns the memory space of the given MemRef type.");
+}
-/// Tuple Type subclass - TupleType.
-class PyTupleType : public PyConcreteType<PyTupleType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirTupleTypeGetTypeID;
- static constexpr const char *pyClassName = "TupleType";
- using PyConcreteType::PyConcreteType;
+void PyUnrankedMemRefType::bindDerived(ClassTy &c) {
+  // Both constructors take `memory_space` as a required argument that may be
+  // None. None is forwarded to the C API as a zero-initialized (null)
+  // attribute.
+  auto createChecked = [](PyType &elementType, PyAttribute *memorySpace,
+                          DefaultingPyLocation loc) {
+    PyMlirContext::ErrorCapture errors(loc->getContext());
+    MlirAttribute spaceAttr = {};
+    if (memorySpace)
+      spaceAttr = *memorySpace;
+
+    MlirType memrefType =
+        mlirUnrankedMemRefTypeGetChecked(loc, elementType, spaceAttr);
+    if (!mlirTypeIsNull(memrefType))
+      return PyUnrankedMemRefType(elementType.getContext(), memrefType);
+    throw MLIRError("Invalid type", errors.take());
+  };
+
+  auto createUnchecked = [](PyType &elementType, PyAttribute *memorySpace,
+                            DefaultingPyMlirContext context) {
+    PyMlirContext::ErrorCapture errors(context->getRef());
+    MlirAttribute spaceAttr = {};
+    if (memorySpace)
+      spaceAttr = *memorySpace;
+
+    MlirType memrefType = mlirUnrankedMemRefTypeGet(elementType, spaceAttr);
+    if (!mlirTypeIsNull(memrefType))
+      return PyUnrankedMemRefType(elementType.getContext(), memrefType);
+    throw MLIRError("Invalid type", errors.take());
+  };
+
+  // `memory_space` property: None when the stored attribute is null.
+  auto memorySpaceOf = [](PyUnrankedMemRefType &self)
+      -> std::optional<nb::typed<nb::object, PyAttribute>> {
+    MlirAttribute space = mlirUnrankedMemrefGetMemorySpace(self);
+    if (mlirAttributeIsNull(space))
+      return std::nullopt;
+    return PyAttribute(self.getContext(), space).maybeDownCast();
+  };
+
+  c.def_static("get", createChecked, nb::arg("element_type"),
+               nb::arg("memory_space").none(), nb::arg("loc") = nb::none(),
+               "Create a unranked memref type")
+      .def_static("get_unchecked", createUnchecked, nb::arg("element_type"),
+                  nb::arg("memory_space").none(),
+                  nb::arg("context") = nb::none(),
+                  "Create a unranked memref type")
+      .def_prop_ro(
+          "memory_space", memorySpaceOf,
+          "Returns the memory space of the given Unranked MemRef type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get_tuple",
- [](const std::vector<PyType> &elements,
- DefaultingPyMlirContext context) {
- std::vector<MlirType> mlirElements;
- mlirElements.reserve(elements.size());
- for (const auto &element : elements)
- mlirElements.push_back(element.get());
- MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
- mlirElements.data());
- return PyTupleType(context->getRef(), t);
- },
- nb::arg("elements"), nb::arg("context") = nb::none(),
- "Create a tuple type");
- c.def_static(
- "get_tuple",
- [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
- MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
- elements.data());
- return PyTupleType(context->getRef(), t);
- },
- nb::arg("elements"), nb::arg("context") = nb::none(),
- // clang-format off
+/// Registers the TupleType Python API: two `get_tuple` overloads (elements as
+/// PyType objects or as MlirType values), `get_type(pos)`, and `num_types`.
+void PyTupleType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get_tuple",
+ [](const std::vector<PyType> &elements, DefaultingPyMlirContext context) {
+ std::vector<MlirType> mlirElements;
+ mlirElements.reserve(elements.size());
+ for (const auto &element : elements)
+ mlirElements.push_back(element.get());
+ MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
+ mlirElements.data());
+ return PyTupleType(context->getRef(), t);
+ },
+ nb::arg("elements"), nb::arg("context") = nb::none(),
+ "Create a tuple type");
+ // Second overload taking MlirType values directly. It is registered after
+ // the PyType overload, so nanobind tries it second.
+ c.def_static(
+ "get_tuple",
+ [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
+ MlirType t =
+ mlirTupleTypeGet(context->get(), elements.size(), elements.data());
+ return PyTupleType(context->getRef(), t);
+ },
+ nb::arg("elements"), nb::arg("context") = nb::none(),
+ // clang-format off
nb::sig("def get_tuple(elements: Sequence[Type], context: Context | None = None) -> TupleType"),
- // clang-format on
- "Create a tuple type");
- c.def(
- "get_type",
- [](PyTupleType &self, intptr_t pos) -> nb::typed<nb::object, PyType> {
- return PyType(self.getContext(), mlirTupleTypeGetType(self, pos))
- .maybeDownCast();
- },
- nb::arg("pos"), "Returns the pos-th type in the tuple type.");
- c.def_prop_ro(
- "num_types",
- [](PyTupleType &self) -> intptr_t {
- return mlirTupleTypeGetNumTypes(self);
- },
- "Returns the number of types contained in a tuple.");
- }
-};
-
-/// Function type.
-class PyFunctionType : public PyConcreteType<PyFunctionType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFunctionTypeGetTypeID;
- static constexpr const char *pyClassName = "FunctionType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](std::vector<PyType> inputs, std::vector<PyType> results,
- DefaultingPyMlirContext context) {
- std::vector<MlirType> mlirInputs;
- mlirInputs.reserve(inputs.size());
- for (const auto &input : inputs)
- mlirInputs.push_back(input.get());
- std::vector<MlirType> mlirResults;
- mlirResults.reserve(results.size());
- for (const auto &result : results)
- mlirResults.push_back(result.get());
+ // clang-format on
+ "Create a tuple type");
+ // NOTE(review): `pos` is passed to mlirTupleTypeGetType without a bounds
+ // check in this binding.
+ c.def(
+ "get_type",
+ [](PyTupleType &self, intptr_t pos) -> nb::typed<nb::object, PyType> {
+ return PyType(self.getContext(), mlirTupleTypeGetType(self, pos))
+ .maybeDownCast();
+ },
+ nb::arg("pos"), "Returns the pos-th type in the tuple type.");
+ c.def_prop_ro(
+ "num_types",
+ [](PyTupleType &self) -> intptr_t {
+ return mlirTupleTypeGetNumTypes(self);
+ },
+ "Returns the number of types contained in a tuple.");
+}
- MlirType t = mlirFunctionTypeGet(context->get(), inputs.size(),
- mlirInputs.data(), results.size(),
- mlirResults.data());
- return PyFunctionType(context->getRef(), t);
- },
- nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
- "Gets a FunctionType from a list of input and result types");
- c.def_static(
- "get",
- [](std::vector<MlirType> inputs, std::vector<MlirType> results,
- DefaultingPyMlirContext context) {
- MlirType t =
- mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
- results.size(), results.data());
- return PyFunctionType(context->getRef(), t);
- },
- nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
- // clang-format off
+/// Registers the FunctionType Python API: two `get` overloads (input/result
+/// types as PyType objects or as MlirType values) and the `inputs` /
+/// `results` list properties.
+void PyFunctionType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](std::vector<PyType> inputs, std::vector<PyType> results,
+ DefaultingPyMlirContext context) {
+ std::vector<MlirType> mlirInputs;
+ mlirInputs.reserve(inputs.size());
+ for (const auto &input : inputs)
+ mlirInputs.push_back(input.get());
+ std::vector<MlirType> mlirResults;
+ mlirResults.reserve(results.size());
+ for (const auto &result : results)
+ mlirResults.push_back(result.get());
+
+ MlirType t = mlirFunctionTypeGet(context->get(), inputs.size(),
+ mlirInputs.data(), results.size(),
+ mlirResults.data());
+ return PyFunctionType(context->getRef(), t);
+ },
+ nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
+ "Gets a FunctionType from a list of input and result types");
+ // Second overload taking MlirType values directly. It is registered after
+ // the PyType overload, so nanobind tries it second.
+ c.def_static(
+ "get",
+ [](std::vector<MlirType> inputs, std::vector<MlirType> results,
+ DefaultingPyMlirContext context) {
+ MlirType t =
+ mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
+ results.size(), results.data());
+ return PyFunctionType(context->getRef(), t);
+ },
+ nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
+ // clang-format off
nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: Context | None = None) -> FunctionType"),
- // clang-format on
- "Gets a FunctionType from a list of input and result types");
- c.def_prop_ro(
- "inputs",
- [](PyFunctionType &self) {
- MlirType t = self;
- nb::list types;
- for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
- ++i) {
- types.append(mlirFunctionTypeGetInput(t, i));
- }
- return types;
- },
- "Returns the list of input types in the FunctionType.");
- c.def_prop_ro(
- "results",
- [](PyFunctionType &self) {
- nb::list types;
- for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
- ++i) {
- types.append(mlirFunctionTypeGetResult(self, i));
- }
- return types;
- },
- "Returns the list of result types in the FunctionType.");
- }
-};
-
-/// Opaque Type subclass - OpaqueType.
-class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirOpaqueTypeGetTypeID;
- static constexpr const char *pyClassName = "OpaqueType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](const std::string &dialectNamespace, const std::string &typeData,
- DefaultingPyMlirContext context) {
- MlirType type = mlirOpaqueTypeGet(context->get(),
- toMlirStringRef(dialectNamespace),
- toMlirStringRef(typeData));
- return PyOpaqueType(context->getRef(), type);
- },
- nb::arg("dialect_namespace"), nb::arg("buffer"),
- nb::arg("context") = nb::none(),
- "Create an unregistered (opaque) dialect type.");
- c.def_prop_ro(
- "dialect_namespace",
- [](PyOpaqueType &self) {
- MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self);
- return nb::str(stringRef.data, stringRef.length);
- },
- "Returns the dialect namespace for the Opaque type as a string.");
- c.def_prop_ro(
- "data",
- [](PyOpaqueType &self) {
- MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
- return nb::str(stringRef.data, stringRef.length);
- },
- "Returns the data for the Opaque type as a string.");
- }
-};
+ // clang-format on
+ "Gets a FunctionType from a list of input and result types");
+ // Both properties build an nb::list by appending raw MlirType values. The
+ // conversion of each element to a Python object depends on the MlirType
+ // type caster in scope (presumably the NanobindAdaptors one; confirm).
+ c.def_prop_ro(
+ "inputs",
+ [](PyFunctionType &self) {
+ // `t` is the same type as `self`, converted once to MlirType.
+ MlirType t = self;
+ nb::list types;
+ for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
+ ++i) {
+ types.append(mlirFunctionTypeGetInput(t, i));
+ }
+ return types;
+ },
+ "Returns the list of input types in the FunctionType.");
+ c.def_prop_ro(
+ "results",
+ [](PyFunctionType &self) {
+ nb::list types;
+ for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
+ ++i) {
+ types.append(mlirFunctionTypeGetResult(self, i));
+ }
+ return types;
+ },
+ "Returns the list of result types in the FunctionType.");
+}
-} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
-} // namespace python
-} // namespace mlir
+void PyOpaqueType::bindDerived(ClassTy &c) {
+  // `get`: builds an opaque type from a dialect namespace and the raw type
+  // payload (the `buffer` argument), both passed as string references.
+  auto create = [](const std::string &dialectNamespace,
+                   const std::string &typeData,
+                   DefaultingPyMlirContext context) {
+    MlirStringRef namespaceRef = toMlirStringRef(dialectNamespace);
+    MlirStringRef dataRef = toMlirStringRef(typeData);
+    MlirType opaqueType =
+        mlirOpaqueTypeGet(context->get(), namespaceRef, dataRef);
+    return PyOpaqueType(context->getRef(), opaqueType);
+  };
+
+  // Both accessors copy the C API string reference into a new Python str.
+  auto namespaceOf = [](PyOpaqueType &self) {
+    MlirStringRef ref = mlirOpaqueTypeGetDialectNamespace(self);
+    return nb::str(ref.data, ref.length);
+  };
+  auto dataOf = [](PyOpaqueType &self) {
+    MlirStringRef ref = mlirOpaqueTypeGetData(self);
+    return nb::str(ref.data, ref.length);
+  };
+
+  c.def_static("get", create, nb::arg("dialect_namespace"), nb::arg("buffer"),
+               nb::arg("context") = nb::none(),
+               "Create an unregistered (opaque) dialect type.")
+      .def_prop_ro(
+          "dialect_namespace", namespaceOf,
+          "Returns the dialect namespace for the Opaque type as a string.")
+      .def_prop_ro("data", dataOf,
+                   "Returns the data for the Opaque type as a string.");
+}
namespace mlir {
namespace python {
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index b2c9380bc1d73..88f58d45cdd75 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -9,7 +9,9 @@
#include "Pass.h"
#include "Rewrite.h"
#include "mlir/Bindings/Python/Globals.h"
+#include "mlir/Bindings/Python/IRAttributes.h"
#include "mlir/Bindings/Python/IRCore.h"
+#include "mlir/Bindings/Python/IRTypes.h"
#include "mlir/Bindings/Python/Nanobind.h"
namespace nb = nanobind;
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 4a9fb127ee08c..003a06b16daac 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -533,9 +533,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
SOURCES
MainModule.cpp
IRAffine.cpp
- IRAttributes.cpp
IRInterfaces.cpp
- IRTypes.cpp
Pass.cpp
Rewrite.cpp
@@ -846,8 +844,10 @@ declare_mlir_python_extension(MLIRPythonExtension.MLIRPythonSupport
ADD_TO_PARENT MLIRPythonSources.Core
ROOT_DIR "${PYTHON_SOURCE_DIR}"
SOURCES
- IRCore.cpp
Globals.cpp
+ IRAttributes.cpp
+ IRCore.cpp
+ IRTypes.cpp
)
################################################################################
diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
index 43573cbc305fa..b229c02ccf5e6 100644
--- a/mlir/test/python/lib/PythonTestModuleNanobind.cpp
+++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
@@ -15,6 +15,7 @@
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/Diagnostics.h"
#include "mlir/Bindings/Python/IRCore.h"
+#include "mlir/Bindings/Python/IRTypes.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "nanobind/nanobind.h"
@@ -47,6 +48,49 @@ struct PyTestType
}
};
+struct PyTestIntegerRankedTensorType
+    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<
+          PyTestIntegerRankedTensorType,
+          mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyRankedTensorType> {
+  // Python subclass of RankedTensorType. Membership is decided by
+  // mlirTypeIsARankedIntegerTensor, and the TypeID is RankedTensorType's.
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedIntegerTensor;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirRankedTensorTypeGetTypeID;
+  static constexpr const char *pyClassName = "TestIntegerRankedTensorType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    // get(shape, width, context=None): a ranked tensor of `width`-bit
+    // integers (mlirIntegerTypeGet) with a null encoding attribute.
+    auto create =
+        [](std::vector<int64_t> shape, unsigned width,
+           mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
+               ctx) {
+          MlirContext rawContext = ctx.get()->get();
+          MlirType elementType = mlirIntegerTypeGet(rawContext, width);
+          MlirAttribute noEncoding = mlirAttributeGetNull();
+          MlirType tensorType = mlirRankedTensorTypeGet(
+              shape.size(), shape.data(), elementType, noEncoding);
+          return PyTestIntegerRankedTensorType(ctx->getRef(), tensorType);
+        };
+    c.def_static("get", create, nb::arg("shape"), nb::arg("width"),
+                 nb::arg("context").none() = nb::none());
+  }
+};
+
+struct PyTestTensorValue
+    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteValue<
+          PyTestTensorValue> {
+  // Python value subclass "TestTensorValue". A value qualifies when
+  // mlirTypeIsAPythonTestTestTensorValue accepts it. It is associated with
+  // RankedTensorType's TypeID.
+  static constexpr const char *pyClassName = "TestTensorValue";
+  static constexpr IsAFunctionTy isaFunction =
+      mlirTypeIsAPythonTestTestTensorValue;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirRankedTensorTypeGetTypeID;
+  using PyConcreteValue::PyConcreteValue;
+
+  static void bindDerived(ClassTy &c) {
+    // is_null(): whether the wrapped MlirValue is null.
+    auto isNull = [](MlirValue &value) { return mlirValueIsNull(value); };
+    c.def("is_null", isNull);
+  }
+};
+
class PyTestAttr
: public mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteAttribute<
PyTestAttr> {
@@ -73,18 +117,18 @@ class PyTestAttr
NB_MODULE(_mlirPythonTestNanobind, m) {
m.def(
"register_python_test_dialect",
- [](MlirContext context, bool load) {
+ [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
+ context,
+ bool load) {
MlirDialectHandle pythonTestDialect =
mlirGetDialectHandle__python_test__();
- mlirDialectHandleRegisterDialect(pythonTestDialect, context);
+ mlirDialectHandleRegisterDialect(pythonTestDialect,
+ context.get()->get());
if (load) {
- mlirDialectHandleLoadDialect(pythonTestDialect, context);
+ mlirDialectHandleLoadDialect(pythonTestDialect, context.get()->get());
}
},
- nb::arg("context"), nb::arg("load") = true,
- // clang-format off
- nb::sig("def register_python_test_dialect(context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") ", load: bool = True) -> None"));
- // clang-format on
+ nb::arg("context").none() = nb::none(), nb::arg("load") = true); // context=None is accepted and resolved by DefaultingPyMlirContext
m.def(
"register_dialect",
@@ -100,73 +144,16 @@ NB_MODULE(_mlirPythonTestNanobind, m) {
m.def(
"test_diagnostics_with_errors_and_notes",
- [](MlirContext ctx) {
- mlir::python::CollectDiagnosticsToStringScope handler(ctx);
- mlirPythonTestEmitDiagnosticWithNote(ctx);
+ [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
+ ctx) {
+ mlir::python::CollectDiagnosticsToStringScope handler(ctx.get()->get());
+ mlirPythonTestEmitDiagnosticWithNote(ctx.get()->get());
throw nb::value_error(handler.takeMessage().c_str());
},
- // clang-format off
- nb::sig("def test_diagnostics_with_errors_and_notes(arg: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") ", /) -> None"));
- // clang-format on
+ nb::arg("context").none() = nb::none()); // context=None is accepted and resolved by DefaultingPyMlirContext
PyTestAttr::bind(m);
PyTestType::bind(m);
-
- auto typeCls =
- mlir_type_subclass(m, "TestIntegerRankedTensorType",
- mlirTypeIsARankedIntegerTensor,
- nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("RankedTensorType"))
- .def_classmethod(
- "get",
- [](const nb::object &cls, std::vector<int64_t> shape,
- unsigned width, MlirContext ctx) {
- MlirAttribute encoding = mlirAttributeGetNull();
- return cls(mlirRankedTensorTypeGet(
- shape.size(), shape.data(), mlirIntegerTypeGet(ctx, width),
- encoding));
- },
- // clang-format off
- nb::sig("def get(cls: object, shape: collections.abc.Sequence[int], width: int, context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> object"),
- // clang-format on
- nb::arg("cls"), nb::arg("shape"), nb::arg("width"),
- nb::arg("context").none() = nb::none());
-
- assert(nb::hasattr(typeCls.get_class(), "static_typeid") &&
- "TestIntegerRankedTensorType has no static_typeid");
-
- MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID();
-
- nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
- mlirRankedTensorTypeID, nb::arg("replace") = true)(
- nanobind::cpp_function([typeCls](const nb::object &mlirType) {
- return typeCls.get_class()(mlirType);
- }));
-
- auto valueCls = mlir_value_subclass(m, "TestTensorValue",
- mlirTypeIsAPythonTestTestTensorValue)
- .def("is_null", [](MlirValue &self) {
- return mlirValueIsNull(self);
- });
-
- nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)(
- mlirRankedTensorTypeID)(
- nanobind::cpp_function([valueCls](const nb::object &valueObj) {
- std::optional<nb::object> capsule =
- mlirApiObjectToCapsule(valueObj);
- assert(capsule.has_value() && "capsule is not null");
- MlirValue v = mlirPythonCapsuleToValue(capsule.value().ptr());
- MlirType t = mlirValueGetType(v);
- // This is hyper-specific in order to exercise/test registering a
- // value caster from cpp (but only for a single test case; see
- // testTensorValue python_test.py).
- if (mlirShapedTypeHasStaticShape(t) &&
- mlirShapedTypeGetDimSize(t, 0) == 1 &&
- mlirShapedTypeGetDimSize(t, 1) == 2 &&
- mlirShapedTypeGetDimSize(t, 2) == 3)
- return valueCls.get_class()(valueObj);
- return valueObj;
- }));
-}
+ PyTestIntegerRankedTensorType::bind(m);
+ PyTestTensorValue::bind(m); // NOTE(review): the previous caster only cast static 1x2x3 tensor values; confirm the broader match is intended.
+}
More information about the llvm-branch-commits
mailing list