[llvm-branch-commits] [mlir] [mlir][Python] port in-tree dialect extensions to use core PyConcreteType, PyConcreteAttribute (PR #173913)
Maksim Levental via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Dec 29 20:37:47 PST 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/173913
>From d76f6a96eb3b63d614bd74e4adac5f5ff57d7056 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Mon, 29 Dec 2025 16:57:05 -0800
Subject: [PATCH 1/2] [mlir][Python] move IRTypes and IRAttributes to public
headers
---
.../mlir/Bindings/Python/IRAttributes.h | 617 +++++
mlir/include/mlir/Bindings/Python/IRCore.h | 17 +-
mlir/include/mlir/Bindings/Python/IRTypes.h | 465 +++-
mlir/lib/Bindings/Python/IRAttributes.cpp | 2244 +++++------------
mlir/lib/Bindings/Python/IRTypes.cpp | 1587 +++++-------
mlir/lib/Bindings/Python/MainModule.cpp | 612 ++++-
mlir/python/CMakeLists.txt | 4 +-
.../python/lib/PythonTestModuleNanobind.cpp | 129 +-
8 files changed, 2959 insertions(+), 2716 deletions(-)
create mode 100644 mlir/include/mlir/Bindings/Python/IRAttributes.h
diff --git a/mlir/include/mlir/Bindings/Python/IRAttributes.h b/mlir/include/mlir/Bindings/Python/IRAttributes.h
new file mode 100644
index 0000000000000..5362dfc7d64c2
--- /dev/null
+++ b/mlir/include/mlir/Bindings/Python/IRAttributes.h
@@ -0,0 +1,617 @@
+//===- IRAttributes.h - Attribute Interfaces ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_IRATTRIBUTES_H
+#define MLIR_BINDINGS_PYTHON_IRATTRIBUTES_H
+
+#include "IRCore.h"
+#include "mlir-c/BuiltinTypes.h"
+
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+/// Owning description of a Python buffer-protocol view: the raw pointer,
+/// element size/format, shape and strides consumed by the attribute bindings.
+struct nb_buffer_info {
+  void *ptr = nullptr;
+  ssize_t itemsize = 0;
+  ssize_t size = 0;
+  const char *format = nullptr;
+  ssize_t ndim = 0;
+  SmallVector<ssize_t, 4> shape;
+  SmallVector<ssize_t, 4> strides;
+  bool readonly = false;
+
+  /// Describes `ndim`-dimensional data at `ptr`; `size` is the product of
+  /// `shape`. A non-null `owned_view_in` is released through its deleter when
+  /// this object is destroyed.
+  nb_buffer_info(
+      void *ptr, ssize_t itemsize, const char *format, ssize_t ndim,
+      SmallVector<ssize_t, 4> shape_in, SmallVector<ssize_t, 4> strides_in,
+      bool readonly = false,
+      std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view_in =
+          std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(nullptr, nullptr))
+      : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim),
+        shape(std::move(shape_in)), strides(std::move(strides_in)),
+        readonly(readonly), owned_view(std::move(owned_view_in)) {
+    size = 1;
+    for (ssize_t i = 0; i < ndim; ++i) {
+      size *= shape[i];
+    }
+  }
+
+  /// Takes ownership of `view`; it is released with `PyBuffer_Release`.
+  explicit nb_buffer_info(Py_buffer *view)
+      : nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim,
+                       {view->shape, view->shape + view->ndim},
+                       stridesOf(view), view->readonly != 0,
+                       std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(
+                           view, PyBuffer_Release)) {}
+
+  nb_buffer_info(const nb_buffer_info &) = delete;
+  nb_buffer_info(nb_buffer_info &&) = default;
+  nb_buffer_info &operator=(const nb_buffer_info &) = delete;
+  nb_buffer_info &operator=(nb_buffer_info &&) = default;
+
+private:
+  /// Returns the strides of `view`. The buffer protocol leaves `strides` null
+  /// when the exporter was not asked for them, in which case the data is
+  /// C-contiguous and the strides follow from `shape` and `itemsize`.
+  /// Building a range from a null `strides` pointer would be undefined.
+  /// Assumes `shape` is populated, as the constructor above also does.
+  static SmallVector<ssize_t, 4> stridesOf(const Py_buffer *view) {
+    if (view->strides)
+      return SmallVector<ssize_t, 4>(view->strides,
+                                     view->strides + view->ndim);
+    SmallVector<ssize_t, 4> contiguous(view->ndim);
+    ssize_t stride = view->itemsize;
+    // Innermost dimension is densest: walk outward accumulating the extent.
+    for (ssize_t i = view->ndim - 1; i >= 0; --i) {
+      contiguous[i] = stride;
+      stride *= view->shape[i];
+    }
+    return contiguous;
+  }
+
+  std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view;
+};
+
+/// Nanobind handle for any Python object that implements the buffer protocol
+/// (checked with `PyObject_CheckBuffer`).
+class MLIR_PYTHON_API_EXPORTED nb_buffer : public nanobind::object {
+  NB_OBJECT_DEFAULT(nb_buffer, object, "Buffer", PyObject_CheckBuffer);
+
+  /// Requests a strided, formatted view of the underlying buffer. The
+  /// returned `nb_buffer_info` owns the view. Throws `nanobind::python_error`
+  /// when the exporter rejects the request.
+  nb_buffer_info request() const {
+    auto view = std::make_unique<Py_buffer>();
+    if (PyObject_GetBuffer(ptr(), view.get(), PyBUF_STRIDES | PyBUF_FORMAT))
+      throw nanobind::python_error();
+    return nb_buffer_info(view.release());
+  }
+};
+
+/// Maps a C++ element type to its buffer-protocol format string. The primary
+/// template is deliberately empty: `bufferInfo` needs a specialization that
+/// exposes a static `format()` for every element type it is instantiated with.
+template <typename T> struct nb_format_descriptor {};
+
+/// Python binding class for the builtin affine-map attribute (`AffineMapAttr`).
+class MLIR_PYTHON_API_EXPORTED PyAffineMapAttribute
+    : public PyConcreteAttribute<PyAffineMapAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "AffineMapAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirAffineMapAttrGetTypeID;
+
+  /// Class-specific members, invoked from PyConcreteAttribute::bind.
+  static void bindDerived(ClassTy &c);
+};
+
+/// Python binding class for the builtin integer-set attribute
+/// (`IntegerSetAttr`).
+class MLIR_PYTHON_API_EXPORTED PyIntegerSetAttribute
+    : public PyConcreteAttribute<PyIntegerSetAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "IntegerSetAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirIntegerSetAttrGetTypeID;
+
+  /// Class-specific members, invoked from PyConcreteAttribute::bind.
+  static void bindDerived(ClassTy &c);
+};
+
+/// Casts `object` to `T`. A `nanobind::cast_error`, or any other
+/// `std::runtime_error` raised by the cast, is rethrown as a
+/// `std::runtime_error` whose message reports an invalid attribute and embeds
+/// the original error text.
+///
+/// `inline` rather than `static`: this template lives in a header, and
+/// `static` would give every including translation unit its own copy.
+template <typename T>
+inline T pyTryCast(nanobind::handle object) {
+  try {
+    return nanobind::cast<T>(object);
+  } catch (const nanobind::cast_error &err) {
+    throw std::runtime_error(std::string("Invalid attribute when attempting to "
+                                         "create an ArrayAttribute (") +
+                             err.what() + ")");
+  } catch (const std::runtime_error &err) {
+    throw std::runtime_error(std::string("Invalid attribute (None?) when "
+                                         "attempting to create an "
+                                         "ArrayAttribute (") +
+                             err.what() + ")");
+  }
+}
+
+/// A python-wrapped dense array attribute with an element type and a derived
+/// implementation class.
+///
+/// `DerivedT` supplies `getAttribute` (C API constructor), `getElement` (C API
+/// element accessor) and `pyIteratorName`, as the instantiations below do.
+template <typename EltTy, typename DerivedT>
+class MLIR_PYTHON_API_EXPORTED PyDenseArrayAttribute
+    : public PyConcreteAttribute<DerivedT> {
+public:
+  using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;
+
+  /// Iterator over the elements of a dense array.
+  class PyDenseArrayIterator {
+  public:
+    PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
+
+    /// Return a copy of the iterator.
+    PyDenseArrayIterator dunderIter() { return *this; }
+
+    /// Return the next element; raises StopIteration once exhausted.
+    EltTy dunderNext() {
+      if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
+        throw nanobind::stop_iteration();
+      return DerivedT::getElement(attr.get(), nextIndex++);
+    }
+
+    /// Bind the iterator class.
+    static void bind(nanobind::module_ &m) {
+      nanobind::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName)
+          .def("__iter__", &PyDenseArrayIterator::dunderIter)
+          .def("__next__", &PyDenseArrayIterator::dunderNext);
+    }
+
+  private:
+    /// The referenced dense array attribute.
+    PyAttribute attr;
+    /// The next index to read. Same type as the C API element count and
+    /// index, so it cannot narrow or overflow before the count does.
+    intptr_t nextIndex = 0;
+  };
+
+  /// Get the element at the given index. No bounds check is done here.
+  EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
+
+  /// Bind the attribute class.
+  static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
+    // Bind the constructor.
+    if constexpr (std::is_same_v<EltTy, bool>) {
+      c.def_static(
+          "get",
+          [](const nanobind::sequence &py_values, DefaultingPyMlirContext ctx) {
+            std::vector<bool> values;
+            for (nanobind::handle py_value : py_values) {
+              int is_true = PyObject_IsTrue(py_value.ptr());
+              if (is_true < 0) {
+                throw nanobind::python_error();
+              }
+              values.push_back(is_true);
+            }
+            return getAttribute(values, ctx->getRef());
+          },
+          nanobind::arg("values"), nanobind::arg("context") = nanobind::none(),
+          "Gets a uniqued dense array attribute");
+    } else {
+      c.def_static(
+          "get",
+          [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
+            return getAttribute(values, ctx->getRef());
+          },
+          nanobind::arg("values"), nanobind::arg("context") = nanobind::none(),
+          "Gets a uniqued dense array attribute");
+    }
+    // Bind the array methods. Negative indices count from the end, as for
+    // Python sequences. Anything still out of range raises IndexError instead
+    // of reading outside the array (the old check only rejected i >= size).
+    c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
+      intptr_t size = mlirDenseArrayGetNumElements(arr);
+      if (i < 0)
+        i += size;
+      if (i < 0 || i >= size)
+        throw nanobind::index_error("DenseArray index out of range");
+      return arr.getItem(i);
+    });
+    c.def("__len__", [](const DerivedT &arr) {
+      return mlirDenseArrayGetNumElements(arr);
+    });
+    c.def("__iter__",
+          [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
+    c.def("__add__", [](DerivedT &arr, const nanobind::list &extras) {
+      std::vector<EltTy> values;
+      intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
+      values.reserve(numOldElements + nanobind::len(extras));
+      for (intptr_t i = 0; i < numOldElements; ++i)
+        values.push_back(arr.getItem(i));
+      for (nanobind::handle attr : extras)
+        values.push_back(pyTryCast<EltTy>(attr));
+      return getAttribute(values, arr.getContext());
+    });
+  }
+
+private:
+  /// Builds a `DerivedT` from `values` in context `ctx`. Bool arrays go
+  /// through an `int` buffer, since `std::vector<bool>` has no contiguous
+  /// `data()`.
+  static DerivedT getAttribute(const std::vector<EltTy> &values,
+                               PyMlirContextRef ctx) {
+    if constexpr (std::is_same_v<EltTy, bool>) {
+      std::vector<int> intValues(values.begin(), values.end());
+      MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(),
+                                                  intValues.data());
+      return DerivedT(ctx, attr);
+    } else {
+      MlirAttribute attr =
+          DerivedT::getAttribute(ctx->get(), values.size(), values.data());
+      return DerivedT(ctx, attr);
+    }
+  }
+};
+
+/// Instantiate the python dense array classes. Each supplies the C API hooks
+/// PyDenseArrayAttribute expects. Each also carries the same export
+/// annotation as the other public binding classes in this header, including
+/// its PyDenseArrayAttribute base.
+struct MLIR_PYTHON_API_EXPORTED PyDenseBoolArrayAttribute
+    : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
+  static constexpr auto getAttribute = mlirDenseBoolArrayGet;
+  static constexpr auto getElement = mlirDenseBoolArrayGetElement;
+  static constexpr const char *pyClassName = "DenseBoolArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+struct MLIR_PYTHON_API_EXPORTED PyDenseI8ArrayAttribute
+    : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
+  static constexpr auto getAttribute = mlirDenseI8ArrayGet;
+  static constexpr auto getElement = mlirDenseI8ArrayGetElement;
+  static constexpr const char *pyClassName = "DenseI8ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+struct MLIR_PYTHON_API_EXPORTED PyDenseI16ArrayAttribute
+    : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
+  static constexpr auto getAttribute = mlirDenseI16ArrayGet;
+  static constexpr auto getElement = mlirDenseI16ArrayGetElement;
+  static constexpr const char *pyClassName = "DenseI16ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+struct MLIR_PYTHON_API_EXPORTED PyDenseI32ArrayAttribute
+    : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
+  static constexpr auto getAttribute = mlirDenseI32ArrayGet;
+  static constexpr auto getElement = mlirDenseI32ArrayGetElement;
+  static constexpr const char *pyClassName = "DenseI32ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+struct MLIR_PYTHON_API_EXPORTED PyDenseI64ArrayAttribute
+    : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
+  static constexpr auto getAttribute = mlirDenseI64ArrayGet;
+  static constexpr auto getElement = mlirDenseI64ArrayGetElement;
+  static constexpr const char *pyClassName = "DenseI64ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+struct MLIR_PYTHON_API_EXPORTED PyDenseF32ArrayAttribute
+    : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
+  static constexpr auto getAttribute = mlirDenseF32ArrayGet;
+  static constexpr auto getElement = mlirDenseF32ArrayGetElement;
+  static constexpr const char *pyClassName = "DenseF32ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+struct MLIR_PYTHON_API_EXPORTED PyDenseF64ArrayAttribute
+    : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
+  static constexpr auto getAttribute = mlirDenseF64ArrayGet;
+  static constexpr auto getElement = mlirDenseF64ArrayGetElement;
+  static constexpr const char *pyClassName = "DenseF64ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+
+/// Python binding class for the builtin array attribute (`ArrayAttr`),
+/// together with a Python iterator over its elements.
+class MLIR_PYTHON_API_EXPORTED PyArrayAttribute
+    : public PyConcreteAttribute<PyArrayAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "ArrayAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirArrayAttrGetTypeID;
+
+  /// Iterator state: the referenced array and the next index to visit.
+  class PyArrayAttributeIterator {
+  public:
+    PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
+
+    /// `__iter__` returns the iterator itself.
+    PyArrayAttributeIterator &dunderIter() { return *this; }
+
+    /// `__next__`; defined out of line.
+    nanobind::typed<nanobind::object, PyAttribute> dunderNext();
+
+    /// Binding hook for the iterator class; defined out of line.
+    static void bind(nanobind::module_ &m);
+
+  private:
+    PyAttribute attr;
+    int nextIndex = 0;
+  };
+
+  /// Element accessor; defined out of line.
+  MlirAttribute getItem(intptr_t i) const;
+
+  /// Class-specific members, invoked from PyConcreteAttribute::bind.
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating-point attribute subclass, exposed to Python as `FloatAttr`.
+class MLIR_PYTHON_API_EXPORTED PyFloatAttribute
+    : public PyConcreteAttribute<PyFloatAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "FloatAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloatAttrGetTypeID;
+
+  /// Class-specific members, invoked from PyConcreteAttribute::bind.
+  static void bindDerived(ClassTy &c);
+};
+
+/// Integer attribute subclass, exposed to Python as `IntegerAttr`. It has no
+/// TypeID hook of its own here; see `integerOrBoolAttributeCaster`.
+class MLIR_PYTHON_API_EXPORTED PyIntegerAttribute
+    : public PyConcreteAttribute<PyIntegerAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "IntegerAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
+
+  /// Class-specific members, invoked from PyConcreteAttribute::bind.
+  static void bindDerived(ClassTy &c);
+
+private:
+  /// Converts the attribute's value to an int64; defined out of line.
+  static int64_t toPyInt(PyIntegerAttribute &self);
+};
+
+/// Boolean attribute subclass, exposed to Python as `BoolAttr`.
+class MLIR_PYTHON_API_EXPORTED PyBoolAttribute
+    : public PyConcreteAttribute<PyBoolAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "BoolAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
+
+  /// Class-specific members, invoked from PyConcreteAttribute::bind.
+  static void bindDerived(ClassTy &c);
+};
+
+/// Python binding class for symbol references (`SymbolRefAttr`).
+class MLIR_PYTHON_API_EXPORTED PySymbolRefAttribute
+    : public PyConcreteAttribute<PySymbolRefAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "SymbolRefAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
+
+  /// Builds a symbol reference from a list of symbol names in `context`;
+  /// defined out of line.
+  static PySymbolRefAttribute fromList(const std::vector<std::string> &symbols,
+                                       PyMlirContext &context);
+
+  /// Class-specific members, invoked from PyConcreteAttribute::bind.
+  static void bindDerived(ClassTy &c);
+};
+
+/// Python binding class for flat symbol references (`FlatSymbolRefAttr`).
+class MLIR_PYTHON_API_EXPORTED PyFlatSymbolRefAttribute
+    : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "FlatSymbolRefAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
+
+  /// Class-specific members, invoked from PyConcreteAttribute::bind.
+  static void bindDerived(ClassTy &c);
+};
+
+/// Python binding class for opaque attributes (`OpaqueAttr`).
+class MLIR_PYTHON_API_EXPORTED PyOpaqueAttribute
+    : public PyConcreteAttribute<PyOpaqueAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "OpaqueAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirOpaqueAttrGetTypeID;
+
+  /// Class-specific members, invoked from PyConcreteAttribute::bind.
+  static void bindDerived(ClassTy &c);
+};
+
+// TODO: Support construction of string elements.
+/// Python binding class for dense elements attributes (`DenseElementsAttr`).
+/// Supports construction from lists, buffers and splats, and exports its data
+/// through the Python buffer protocol.
+class MLIR_PYTHON_API_EXPORTED PyDenseElementsAttribute
+    : public PyConcreteAttribute<PyDenseElementsAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
+  static constexpr const char *pyClassName = "DenseElementsAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static PyDenseElementsAttribute
+  getFromList(const nanobind::list &attributes,
+              std::optional<PyType> explicitType,
+              DefaultingPyMlirContext contextWrapper);
+
+  static PyDenseElementsAttribute
+  getFromBuffer(const nb_buffer &array, bool signless,
+                const std::optional<PyType> &explicitType,
+                std::optional<std::vector<int64_t>> explicitShape,
+                DefaultingPyMlirContext contextWrapper);
+
+  static PyDenseElementsAttribute getSplat(const PyType &shapedType,
+                                           PyAttribute &elementAttr);
+
+  intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
+
+  std::unique_ptr<nb_buffer_info> accessBuffer();
+
+  static void bindDerived(ClassTy &c);
+
+  static PyType_Slot slots[];
+
+private:
+  static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags);
+  static void bf_releasebuffer(PyObject *, Py_buffer *buffer);
+
+  static bool isUnsignedIntegerFormat(std::string_view format);
+
+  static bool isSignedIntegerFormat(std::string_view format);
+
+  static MlirType
+  getShapedType(std::optional<MlirType> bulkLoadElementType,
+                std::optional<std::vector<int64_t>> explicitShape,
+                Py_buffer &view);
+
+  static MlirAttribute getAttributeFromBuffer(
+      Py_buffer &view, bool signless, std::optional<PyType> explicitType,
+      const std::optional<std::vector<int64_t>> &explicitShape,
+      MlirContext &context);
+
+  // There is a complication for boolean numpy arrays, as numpy represents
+  // them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8
+  // booleans per byte.
+  static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
+      Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
+      MlirContext &context);
+
+  // This does the opposite transformation of
+  // `getBitpackedAttributeFromBooleanBuffer`
+  std::unique_ptr<nb_buffer_info> getBooleanBufferFromBitpackedAttribute();
+
+  /// Describes this attribute's raw storage as a read-only buffer of `Type`
+  /// elements shaped like `shapedType`. `explicitFormat` overrides the
+  /// buffer-protocol format string from `nb_format_descriptor<Type>`.
+  ///
+  /// Shape and strides are built directly as `ssize_t`, the element type
+  /// `nb_buffer_info` stores. The previous `intptr_t` vectors only
+  /// type-checked where the two types coincide.
+  template <typename Type>
+  std::unique_ptr<nb_buffer_info>
+  bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) {
+    intptr_t rank = mlirShapedTypeGetRank(shapedType);
+    // The buffer is exported read-only below, so dropping const is safe.
+    Type *data = static_cast<Type *>(
+        const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
+    SmallVector<ssize_t, 4> shape;
+    shape.reserve(rank);
+    for (intptr_t i = 0; i < rank; ++i)
+      shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
+    // Byte strides. Splats store a single value, so every stride stays 0.
+    SmallVector<ssize_t, 4> strides;
+    strides.assign(rank, 0);
+    if (!mlirDenseElementsAttrIsSplat(*this)) {
+      // Row-major layout in one reverse pass: each stride is the next inner
+      // stride scaled by the inner dimension's extent. Rank 0 yields no
+      // strides, matching `ndim == 0`.
+      ssize_t stride = sizeof(Type);
+      for (intptr_t i = rank - 1; i >= 0; --i) {
+        strides[i] = stride;
+        stride *= shape[i];
+      }
+    }
+    const char *format;
+    if (explicitFormat) {
+      format = explicitFormat;
+    } else {
+      format = nb_format_descriptor<Type>::format();
+    }
+    return std::make_unique<nb_buffer_info>(
+        data, sizeof(Type), format, rank, std::move(shape), std::move(strides),
+        /*readonly=*/true);
+  }
+};
+
+// Buffer-protocol slots for DenseElementsAttr. This definition lives in a
+// header, so it must be `inline` (C++17 inline variable). A plain
+// out-of-class definition would be emitted by every translation unit that
+// includes the header, violating the one-definition rule (duplicate symbols
+// at link time).
+inline PyType_Slot PyDenseElementsAttribute::slots[] = {
+// Python 3.8 doesn't allow setting the buffer protocol slots from a type spec.
+#if PY_VERSION_HEX >= 0x03090000
+    {Py_bf_getbuffer,
+     reinterpret_cast<void *>(PyDenseElementsAttribute::bf_getbuffer)},
+    {Py_bf_releasebuffer,
+     reinterpret_cast<void *>(PyDenseElementsAttribute::bf_releasebuffer)},
+#endif
+    {0, nullptr},
+};
+
+/// DenseElementsAttr specialization for integer (and boolean) payloads,
+/// exposed to Python as `DenseIntElementsAttr`. Adds element access.
+class MLIR_PYTHON_API_EXPORTED PyDenseIntElementsAttribute
+    : public PyConcreteAttribute<PyDenseIntElementsAttribute,
+                                 PyDenseElementsAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "DenseIntElementsAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
+
+  /// Returns the element at linear position `pos`. Asserts if the index is
+  /// out of range.
+  nanobind::int_ dunderGetItem(intptr_t pos);
+
+  /// Class-specific members, invoked from PyConcreteAttribute::bind.
+  static void bindDerived(ClassTy &c);
+};
+
+/// Python binding class for resource-backed dense elements
+/// (`DenseResourceElementsAttr`).
+class MLIR_PYTHON_API_EXPORTED PyDenseResourceElementsAttribute
+    : public PyConcreteAttribute<PyDenseResourceElementsAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "DenseResourceElementsAttr";
+  static constexpr IsAFunctionTy isaFunction =
+      mlirAttributeIsADenseResourceElements;
+
+  /// Creates the attribute from a Python buffer under resource `name`;
+  /// defined out of line.
+  static PyDenseResourceElementsAttribute
+  getFromBuffer(const nb_buffer &buffer, const std::string &name,
+                const PyType &type, std::optional<size_t> alignment,
+                bool isMutable, DefaultingPyMlirContext contextWrapper);
+
+  /// Class-specific members, invoked from PyConcreteAttribute::bind.
+  static void bindDerived(ClassTy &c);
+};
+
+/// Python binding class for dictionary attributes (`DictAttr`).
+class MLIR_PYTHON_API_EXPORTED PyDictAttribute
+    : public PyConcreteAttribute<PyDictAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "DictAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirDictionaryAttrGetTypeID;
+
+  /// `__len__`: number of entries reported by the C API.
+  intptr_t dunderLen() const {
+    return mlirDictionaryAttrGetNumElements(*this);
+  }
+
+  /// `__contains__` for an entry name; defined out of line.
+  bool dunderContains(const std::string &name) const;
+
+  /// Class-specific members, invoked from PyConcreteAttribute::bind.
+  static void bindDerived(ClassTy &c);
+};
+
+/// DenseElementsAttr specialization for floating-point payloads, exposed to
+/// Python as `DenseFPElementsAttr`. Adds element access.
+class MLIR_PYTHON_API_EXPORTED PyDenseFPElementsAttribute
+    : public PyConcreteAttribute<PyDenseFPElementsAttribute,
+                                 PyDenseElementsAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "DenseFPElementsAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
+
+  /// Returns the element at linear position `pos`; defined out of line.
+  nanobind::float_ dunderGetItem(intptr_t pos);
+
+  /// Class-specific members, invoked from PyConcreteAttribute::bind.
+  static void bindDerived(ClassTy &c);
+};
+
+/// Python binding class for type attributes (`TypeAttr`).
+class MLIR_PYTHON_API_EXPORTED PyTypeAttribute
+    : public PyConcreteAttribute<PyTypeAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "TypeAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirTypeAttrGetTypeID;
+
+  /// Class-specific members, invoked from PyConcreteAttribute::bind.
+  static void bindDerived(ClassTy &c);
+};
+
+/// Unit attribute subclass (`UnitAttr`); unit attributes carry no value.
+class MLIR_PYTHON_API_EXPORTED PyUnitAttribute
+    : public PyConcreteAttribute<PyUnitAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "UnitAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirUnitAttrGetTypeID;
+
+  /// Class-specific members, invoked from PyConcreteAttribute::bind.
+  static void bindDerived(ClassTy &c);
+};
+
+/// Strided layout attribute subclass, exposed to Python as
+/// `StridedLayoutAttr`.
+class MLIR_PYTHON_API_EXPORTED PyStridedLayoutAttribute
+    : public PyConcreteAttribute<PyStridedLayoutAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "StridedLayoutAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirStridedLayoutAttrGetTypeID;
+
+  /// Class-specific members, invoked from PyConcreteAttribute::bind.
+  static void bindDerived(ClassTy &c);
+};
+
+// Out-of-line caster functions taking a generic PyAttribute and returning a
+// `nanobind::object`. Judging by their names, each presumably downcasts to the
+// matching concrete class within an attribute family that has no single
+// TypeID-based caster above. Examples are IntegerAttr vs BoolAttr and
+// SymbolRefAttr vs FlatSymbolRefAttr. TODO(review): confirm against the
+// definitions.
+MLIR_PYTHON_API_EXPORTED nanobind::object
+denseArrayAttributeCaster(PyAttribute &pyAttribute);
+MLIR_PYTHON_API_EXPORTED nanobind::object
+denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute);
+MLIR_PYTHON_API_EXPORTED nanobind::object
+integerOrBoolAttributeCaster(PyAttribute &pyAttribute);
+MLIR_PYTHON_API_EXPORTED nanobind::object
+symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute);
+
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
+
+#endif
\ No newline at end of file
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 0f402b4ce15ff..340b16bcdf558 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -979,7 +979,8 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteType : public BaseTy {
PyGlobals::get().registerTypeCaster(
DerivedTy::getTypeIdFunction(),
nanobind::cast<nanobind::callable>(nanobind::cpp_function(
- [](PyType pyType) -> DerivedTy { return pyType; })));
+ [](PyType pyType) -> DerivedTy { return pyType; })),
+ /*replace*/ true);
}
DerivedTy::bindDerived(cls);
@@ -1123,7 +1124,8 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteAttribute : public BaseTy {
nanobind::cast<nanobind::callable>(
nanobind::cpp_function([](PyAttribute pyAttribute) -> DerivedTy {
return pyAttribute;
- })));
+ })),
+ /*replace*/ true);
}
DerivedTy::bindDerived(cls);
@@ -1511,6 +1513,8 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteValue : public PyValue {
// and redefine bindDerived.
using ClassTy = nanobind::class_<DerivedTy, PyValue>;
using IsAFunctionTy = bool (*)(MlirValue);
+ using GetTypeIDFunctionTy = MlirTypeID (*)();
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr;
PyConcreteValue() = default;
PyConcreteValue(PyOperationRef operationRef, MlirValue value)
@@ -1553,6 +1557,15 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteValue : public PyValue {
[](DerivedTy &self) -> nanobind::typed<nanobind::object, DerivedTy> {
return self.maybeDownCast();
});
+
+ if (DerivedTy::getTypeIdFunction) {
+ PyGlobals::get().registerValueCaster(
+ DerivedTy::getTypeIdFunction(),
+ nanobind::cast<nanobind::callable>(nanobind::cpp_function(
+ [](PyValue pyValue) -> DerivedTy { return pyValue; })),
+ /*replace*/ true);
+ }
+
DerivedTy::bindDerived(cls);
}
diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h
index 87e0e10764bd8..db478e8d33f37 100644
--- a/mlir/include/mlir/Bindings/Python/IRTypes.h
+++ b/mlir/include/mlir/Bindings/Python/IRTypes.h
@@ -9,13 +9,14 @@
#ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H
#define MLIR_BINDINGS_PYTHON_IRTYPES_H
+#include "mlir-c/BuiltinTypes.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
/// Shaped Type Interface - ShapedType
-class MLIR_PYTHON_API_EXPORTED PyShapedType
+class MLIR_PYTHON_API_EXPORTED MLIR_PYTHON_API_EXPORTED PyShapedType
: public PyConcreteType<PyShapedType> {
public:
static const IsAFunctionTy isaFunction;
@@ -27,6 +28,468 @@ class MLIR_PYTHON_API_EXPORTED PyShapedType
private:
void requireHasRank();
};
+
+/// Checks whether the given type is an integer, BF16, F16, F32 or F64 type.
+inline int mlirTypeIsAIntegerOrFloat(MlirType type) {
+  return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
+         mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
+}
+
+class MLIR_PYTHON_API_EXPORTED PyIntegerType
+ : public PyConcreteType<PyIntegerType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirIntegerTypeGetTypeID;
+ static constexpr const char *pyClassName = "IntegerType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Index Type subclass - IndexType.
+class MLIR_PYTHON_API_EXPORTED PyIndexType
+ : public PyConcreteType<PyIndexType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirIndexTypeGetTypeID;
+ static constexpr const char *pyClassName = "IndexType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+class MLIR_PYTHON_API_EXPORTED PyFloatType
+ : public PyConcreteType<PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
+ static constexpr const char *pyClassName = "FloatType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float4E2M1FNType.
+class MLIR_PYTHON_API_EXPORTED PyFloat4E2M1FNType
+ : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat4E2M1FNTypeGetTypeID;
+ static constexpr const char *pyClassName = "Float4E2M1FNType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float6E2M3FNType.
+class MLIR_PYTHON_API_EXPORTED PyFloat6E2M3FNType
+ : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat6E2M3FNTypeGetTypeID;
+ static constexpr const char *pyClassName = "Float6E2M3FNType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float6E3M2FNType.
+class MLIR_PYTHON_API_EXPORTED PyFloat6E3M2FNType
+ : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat6E3M2FNTypeGetTypeID;
+ static constexpr const char *pyClassName = "Float6E3M2FNType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E4M3FNType.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3FNType
+ : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat8E4M3FNTypeGetTypeID;
+ static constexpr const char *pyClassName = "Float8E4M3FNType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E5M2Type.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E5M2Type
+ : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat8E5M2TypeGetTypeID;
+ static constexpr const char *pyClassName = "Float8E5M2Type";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E4M3Type.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3Type
+ : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat8E4M3TypeGetTypeID;
+ static constexpr const char *pyClassName = "Float8E4M3Type";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E4M3FNUZType.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3FNUZType
+    : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E4M3FNUZTypeGetTypeID;
+  static constexpr const char *pyClassName = "Float8E4M3FNUZType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E4M3B11FNUZType.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3B11FNUZType
+    : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E4M3B11FNUZTypeGetTypeID;
+  static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E5M2FNUZType.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E5M2FNUZType
+    : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E5M2FNUZTypeGetTypeID;
+  static constexpr const char *pyClassName = "Float8E5M2FNUZType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E3M4Type.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E3M4Type
+ : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat8E3M4TypeGetTypeID;
+ static constexpr const char *pyClassName = "Float8E3M4Type";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E8M0FNUType.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E8M0FNUType
+ : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat8E8M0FNUTypeGetTypeID;
+ static constexpr const char *pyClassName = "Float8E8M0FNUType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - BF16Type.
+class MLIR_PYTHON_API_EXPORTED PyBF16Type
+ : public PyConcreteType<PyBF16Type, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirBFloat16TypeGetTypeID;
+ static constexpr const char *pyClassName = "BF16Type";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - F16Type.
+class MLIR_PYTHON_API_EXPORTED PyF16Type
+ : public PyConcreteType<PyF16Type, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat16TypeGetTypeID;
+ static constexpr const char *pyClassName = "F16Type";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - TF32Type.
+class MLIR_PYTHON_API_EXPORTED PyTF32Type
+ : public PyConcreteType<PyTF32Type, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloatTF32TypeGetTypeID;
+ static constexpr const char *pyClassName = "FloatTF32Type";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - F32Type.
+class MLIR_PYTHON_API_EXPORTED PyF32Type
+ : public PyConcreteType<PyF32Type, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat32TypeGetTypeID;
+ static constexpr const char *pyClassName = "F32Type";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - F64Type.
+class MLIR_PYTHON_API_EXPORTED PyF64Type
+ : public PyConcreteType<PyF64Type, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat64TypeGetTypeID;
+ static constexpr const char *pyClassName = "F64Type";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// None Type subclass - NoneType.
+class MLIR_PYTHON_API_EXPORTED PyNoneType : public PyConcreteType<PyNoneType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirNoneTypeGetTypeID;
+ static constexpr const char *pyClassName = "NoneType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Complex Type subclass - ComplexType.
+class MLIR_PYTHON_API_EXPORTED PyComplexType
+ : public PyConcreteType<PyComplexType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirComplexTypeGetTypeID;
+ static constexpr const char *pyClassName = "ComplexType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Vector Type subclass - VectorType.
+class MLIR_PYTHON_API_EXPORTED PyVectorType
+ : public PyConcreteType<PyVectorType, PyShapedType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirVectorTypeGetTypeID;
+ static constexpr const char *pyClassName = "VectorType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+
+private:
+ static PyVectorType
+ getChecked(std::vector<int64_t> shape, PyType &elementType,
+ std::optional<nanobind::list> scalable,
+ std::optional<std::vector<int64_t>> scalableDims,
+ DefaultingPyLocation loc) {
+ if (scalable && scalableDims) {
+ throw nanobind::value_error("'scalable' and 'scalable_dims' kwargs "
+ "are mutually exclusive.");
+ }
+
+ PyMlirContext::ErrorCapture errors(loc->getContext());
+ MlirType type;
+ if (scalable) {
+ if (scalable->size() != shape.size())
+ throw nanobind::value_error("Expected len(scalable) == len(shape).");
+
+ SmallVector<bool> scalableDimFlags = llvm::to_vector(
+ llvm::map_range(*scalable, [](const nanobind::handle &h) {
+ return nanobind::cast<bool>(h);
+ }));
+ type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
+ scalableDimFlags.data(),
+ elementType);
+ } else if (scalableDims) {
+ SmallVector<bool> scalableDimFlags(shape.size(), false);
+ for (int64_t dim : *scalableDims) {
+ if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
+ throw nanobind::value_error(
+ "Scalable dimension index out of bounds.");
+ scalableDimFlags[dim] = true;
+ }
+ type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
+ scalableDimFlags.data(),
+ elementType);
+ } else {
+ type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
+ elementType);
+ }
+ if (mlirTypeIsNull(type))
+ throw MLIRError("Invalid type", errors.take());
+ return PyVectorType(elementType.getContext(), type);
+ }
+
+ static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
+ std::optional<nanobind::list> scalable,
+ std::optional<std::vector<int64_t>> scalableDims,
+ DefaultingPyMlirContext context) {
+ if (scalable && scalableDims) {
+ throw nanobind::value_error("'scalable' and 'scalable_dims' kwargs "
+ "are mutually exclusive.");
+ }
+
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirType type;
+ if (scalable) {
+ if (scalable->size() != shape.size())
+ throw nanobind::value_error("Expected len(scalable) == len(shape).");
+
+ SmallVector<bool> scalableDimFlags = llvm::to_vector(
+ llvm::map_range(*scalable, [](const nanobind::handle &h) {
+ return nanobind::cast<bool>(h);
+ }));
+ type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
+ scalableDimFlags.data(), elementType);
+ } else if (scalableDims) {
+ SmallVector<bool> scalableDimFlags(shape.size(), false);
+ for (int64_t dim : *scalableDims) {
+ if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
+ throw nanobind::value_error(
+ "Scalable dimension index out of bounds.");
+ scalableDimFlags[dim] = true;
+ }
+ type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
+ scalableDimFlags.data(), elementType);
+ } else {
+ type = mlirVectorTypeGet(shape.size(), shape.data(), elementType);
+ }
+ if (mlirTypeIsNull(type))
+ throw MLIRError("Invalid type", errors.take());
+ return PyVectorType(elementType.getContext(), type);
+ }
+};
+
+/// Ranked Tensor Type subclass - RankedTensorType.
+class MLIR_PYTHON_API_EXPORTED PyRankedTensorType
+ : public PyConcreteType<PyRankedTensorType, PyShapedType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirRankedTensorTypeGetTypeID;
+ static constexpr const char *pyClassName = "RankedTensorType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Unranked Tensor Type subclass - UnrankedTensorType.
+class MLIR_PYTHON_API_EXPORTED PyUnrankedTensorType
+ : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirUnrankedTensorTypeGetTypeID;
+ static constexpr const char *pyClassName = "UnrankedTensorType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Ranked MemRef Type subclass - MemRefType.
+class MLIR_PYTHON_API_EXPORTED PyMemRefType
+ : public PyConcreteType<PyMemRefType, PyShapedType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirMemRefTypeGetTypeID;
+ static constexpr const char *pyClassName = "MemRefType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Unranked MemRef Type subclass - UnrankedMemRefType.
+class MLIR_PYTHON_API_EXPORTED PyUnrankedMemRefType
+ : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirUnrankedMemRefTypeGetTypeID;
+ static constexpr const char *pyClassName = "UnrankedMemRefType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Tuple Type subclass - TupleType.
+class MLIR_PYTHON_API_EXPORTED PyTupleType
+ : public PyConcreteType<PyTupleType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirTupleTypeGetTypeID;
+ static constexpr const char *pyClassName = "TupleType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Function Type subclass - FunctionType.
+class MLIR_PYTHON_API_EXPORTED PyFunctionType
+    : public PyConcreteType<PyFunctionType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFunctionTypeGetTypeID;
+  static constexpr const char *pyClassName = "FunctionType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Opaque Type subclass - OpaqueType.
+class MLIR_PYTHON_API_EXPORTED PyOpaqueType
+ : public PyConcreteType<PyOpaqueType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirOpaqueTypeGetTypeID;
+ static constexpr const char *pyClassName = "OpaqueType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index f0f0ae9ba741e..0591e34e77e47 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -14,6 +14,7 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
+#include "mlir/Bindings/Python/IRAttributes.h"
#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
@@ -32,159 +33,10 @@ using llvm::SmallVector;
// Docstrings (trivial, non-duplicated docstrings are included inline).
//------------------------------------------------------------------------------
-static const char kDenseElementsAttrGetDocstring[] =
- R"(Gets a DenseElementsAttr from a Python buffer or array.
-
-When `type` is not provided, then some limited type inferencing is done based
-on the buffer format. Support presently exists for 8/16/32/64 signed and
-unsigned integers and float16/float32/float64. DenseElementsAttrs of these
-types can also be converted back to a corresponding buffer.
-
-For conversions outside of these types, a `type=` must be explicitly provided
-and the buffer contents must be bit-castable to the MLIR internal
-representation:
-
- * Integer types (except for i1): the buffer must be byte aligned to the
- next byte boundary.
- * Floating point types: Must be bit-castable to the given floating point
- size.
- * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
- row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
- this specification with: `np.packbits(ary, axis=None, bitorder='little')`.
-
-If a single element buffer is passed (or for i1, a single byte with value 0
-or 255), then a splat will be created.
-
-Args:
- array: The array or buffer to convert.
- signless: If inferring an appropriate MLIR type, use signless types for
- integers (defaults True).
- type: Skips inference of the MLIR element type and uses this instead. The
- storage size must be consistent with the actual contents of the buffer.
- shape: Overrides the shape of the buffer when constructing the MLIR
- shaped type. This is needed when the physical and logical shape differ (as
- for i1).
- context: Explicit context, if not from context manager.
-
-Returns:
- DenseElementsAttr on success.
-
-Raises:
- ValueError: If the type of the buffer or array cannot be matched to an MLIR
- type or if the buffer does not meet expectations.
-)";
-
-static const char kDenseElementsAttrGetFromListDocstring[] =
- R"(Gets a DenseElementsAttr from a Python list of attributes.
-
-Note that it can be expensive to construct attributes individually.
-For a large number of elements, consider using a Python buffer or array instead.
-
-Args:
- attrs: A list of attributes.
- type: The desired shape and type of the resulting DenseElementsAttr.
- If not provided, the element type is determined based on the type
- of the 0th attribute and the shape is `[len(attrs)]`.
- context: Explicit context, if not from context manager.
-
-Returns:
- DenseElementsAttr on success.
-
-Raises:
- ValueError: If the type of the attributes does not match the type
- specified by `shaped_type`.
-)";
-
-static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =
- R"(Gets a DenseResourceElementsAttr from a Python buffer or array.
-
-This function does minimal validation or massaging of the data, and it is
-up to the caller to ensure that the buffer meets the characteristics
-implied by the shape.
-
-The backing buffer and any user objects will be retained for the lifetime
-of the resource blob. This is typically bounded to the context but the
-resource can have a shorter lifespan depending on how it is used in
-subsequent processing.
-
-Args:
- buffer: The array or buffer to convert.
- name: Name to provide to the resource (may be changed upon collision).
- type: The explicit ShapedType to construct the attribute with.
- context: Explicit context, if not from context manager.
-
-Returns:
- DenseResourceElementsAttr on success.
-
-Raises:
- ValueError: If the type of the buffer or array cannot be matched to an MLIR
- type or if the buffer does not meet expectations.
-)";
-
namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
-struct nb_buffer_info {
- void *ptr = nullptr;
- ssize_t itemsize = 0;
- ssize_t size = 0;
- const char *format = nullptr;
- ssize_t ndim = 0;
- SmallVector<ssize_t, 4> shape;
- SmallVector<ssize_t, 4> strides;
- bool readonly = false;
-
- nb_buffer_info(
- void *ptr, ssize_t itemsize, const char *format, ssize_t ndim,
- SmallVector<ssize_t, 4> shape_in, SmallVector<ssize_t, 4> strides_in,
- bool readonly = false,
- std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view_in =
- std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(nullptr, nullptr))
- : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim),
- shape(std::move(shape_in)), strides(std::move(strides_in)),
- readonly(readonly), owned_view(std::move(owned_view_in)) {
- size = 1;
- for (ssize_t i = 0; i < ndim; ++i) {
- size *= shape[i];
- }
- }
-
- explicit nb_buffer_info(Py_buffer *view)
- : nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim,
- {view->shape, view->shape + view->ndim},
- // TODO(phawkins): check for null strides
- {view->strides, view->strides + view->ndim},
- view->readonly != 0,
- std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(
- view, PyBuffer_Release)) {}
-
- nb_buffer_info(const nb_buffer_info &) = delete;
- nb_buffer_info(nb_buffer_info &&) = default;
- nb_buffer_info &operator=(const nb_buffer_info &) = delete;
- nb_buffer_info &operator=(nb_buffer_info &&) = default;
-
-private:
- std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view;
-};
-
-class nb_buffer : public nb::object {
- NB_OBJECT_DEFAULT(nb_buffer, object, "Buffer", PyObject_CheckBuffer);
-
- nb_buffer_info request() const {
- int flags = PyBUF_STRIDES | PyBUF_FORMAT;
- auto *view = new Py_buffer();
- if (PyObject_GetBuffer(ptr(), view, flags) != 0) {
- delete view;
- throw nb::python_error();
- }
- return nb_buffer_info(view);
- }
-};
-
-template <typename T>
-struct nb_format_descriptor {};
-
template <>
struct nb_format_descriptor<bool> {
static const char *format() { return "?"; }
@@ -230,1063 +82,602 @@ struct nb_format_descriptor<double> {
static const char *format() { return "d"; }
};
-class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
- static constexpr const char *pyClassName = "AffineMapAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirAffineMapAttrGetTypeID;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyAffineMap &affineMap) {
- MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
- return PyAffineMapAttribute(affineMap.getContext(), attr);
- },
- nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
- c.def_prop_ro(
- "value",
- [](PyAffineMapAttribute &self) {
- return PyAffineMap(self.getContext(),
- mlirAffineMapAttrGetValue(self));
- },
- "Returns the value of the AffineMap attribute");
- }
-};
-
-class PyIntegerSetAttribute
- : public PyConcreteAttribute<PyIntegerSetAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet;
- static constexpr const char *pyClassName = "IntegerSetAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirIntegerSetAttrGetTypeID;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyIntegerSet &integerSet) {
- MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get());
- return PyIntegerSetAttribute(integerSet.getContext(), attr);
- },
- nb::arg("integer_set"), "Gets an attribute wrapping an IntegerSet.");
- }
-};
-
-template <typename T>
-static T pyTryCast(nb::handle object) {
- try {
- return nb::cast<T>(object);
- } catch (nb::cast_error &err) {
- std::string msg = std::string("Invalid attribute when attempting to "
- "create an ArrayAttribute (") +
- err.what() + ")";
- throw std::runtime_error(msg.c_str());
- } catch (std::runtime_error &err) {
- std::string msg = std::string("Invalid attribute (None?) when attempting "
- "to create an ArrayAttribute (") +
- err.what() + ")";
- throw std::runtime_error(msg.c_str());
- }
+nanobind::typed<nanobind::object, PyAttribute>
+PyArrayAttribute::PyArrayAttributeIterator::dunderNext() {
+ // TODO: Throw is an inefficient way to stop iteration.
+ if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
+ throw nanobind::stop_iteration();
+ return PyAttribute(this->attr.getContext(),
+ mlirArrayAttrGetElement(attr.get(), nextIndex++))
+ .maybeDownCast();
}
-/// A python-wrapped dense array attribute with an element type and a derived
-/// implementation class.
-template <typename EltTy, typename DerivedT>
-class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
-public:
- using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;
-
- /// Iterator over the integer elements of a dense array.
- class PyDenseArrayIterator {
- public:
- PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
-
- /// Return a copy of the iterator.
- PyDenseArrayIterator dunderIter() { return *this; }
-
- /// Return the next element.
- EltTy dunderNext() {
- // Throw if the index has reached the end.
- if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
- throw nb::stop_iteration();
- return DerivedT::getElement(attr.get(), nextIndex++);
- }
-
- /// Bind the iterator class.
- static void bind(nb::module_ &m) {
- nb::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName)
- .def("__iter__", &PyDenseArrayIterator::dunderIter)
- .def("__next__", &PyDenseArrayIterator::dunderNext);
- }
-
- private:
- /// The referenced dense array attribute.
- PyAttribute attr;
- /// The next index to read.
- int nextIndex = 0;
- };
-
- /// Get the element at the given index.
- EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
-
- /// Bind the attribute class.
- static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
- // Bind the constructor.
- if constexpr (std::is_same_v<EltTy, bool>) {
- c.def_static(
- "get",
- [](const nb::sequence &py_values, DefaultingPyMlirContext ctx) {
- std::vector<bool> values;
- for (nb::handle py_value : py_values) {
- int is_true = PyObject_IsTrue(py_value.ptr());
- if (is_true < 0) {
- throw nb::python_error();
- }
- values.push_back(is_true);
- }
- return getAttribute(values, ctx->getRef());
- },
- nb::arg("values"), nb::arg("context") = nb::none(),
- "Gets a uniqued dense array attribute");
- } else {
- c.def_static(
- "get",
- [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
- return getAttribute(values, ctx->getRef());
- },
- nb::arg("values"), nb::arg("context") = nb::none(),
- "Gets a uniqued dense array attribute");
- }
- // Bind the array methods.
- c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
- if (i >= mlirDenseArrayGetNumElements(arr))
- throw nb::index_error("DenseArray index out of range");
- return arr.getItem(i);
- });
- c.def("__len__", [](const DerivedT &arr) {
- return mlirDenseArrayGetNumElements(arr);
- });
- c.def("__iter__",
- [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
- c.def("__add__", [](DerivedT &arr, const nb::list &extras) {
- std::vector<EltTy> values;
- intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
- values.reserve(numOldElements + nb::len(extras));
- for (intptr_t i = 0; i < numOldElements; ++i)
- values.push_back(arr.getItem(i));
- for (nb::handle attr : extras)
- values.push_back(pyTryCast<EltTy>(attr));
- return getAttribute(values, arr.getContext());
- });
- }
-
-private:
- static DerivedT getAttribute(const std::vector<EltTy> &values,
- PyMlirContextRef ctx) {
- if constexpr (std::is_same_v<EltTy, bool>) {
- std::vector<int> intValues(values.begin(), values.end());
- MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(),
- intValues.data());
- return DerivedT(ctx, attr);
- } else {
- MlirAttribute attr =
- DerivedT::getAttribute(ctx->get(), values.size(), values.data());
- return DerivedT(ctx, attr);
- }
- }
-};
-
-/// Instantiate the python dense array classes.
-struct PyDenseBoolArrayAttribute
- : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
- static constexpr auto getAttribute = mlirDenseBoolArrayGet;
- static constexpr auto getElement = mlirDenseBoolArrayGetElement;
- static constexpr const char *pyClassName = "DenseBoolArrayAttr";
- static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseI8ArrayAttribute
- : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
- static constexpr auto getAttribute = mlirDenseI8ArrayGet;
- static constexpr auto getElement = mlirDenseI8ArrayGetElement;
- static constexpr const char *pyClassName = "DenseI8ArrayAttr";
- static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseI16ArrayAttribute
- : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
- static constexpr auto getAttribute = mlirDenseI16ArrayGet;
- static constexpr auto getElement = mlirDenseI16ArrayGetElement;
- static constexpr const char *pyClassName = "DenseI16ArrayAttr";
- static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseI32ArrayAttribute
- : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
- static constexpr auto getAttribute = mlirDenseI32ArrayGet;
- static constexpr auto getElement = mlirDenseI32ArrayGetElement;
- static constexpr const char *pyClassName = "DenseI32ArrayAttr";
- static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseI64ArrayAttribute
- : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
- static constexpr auto getAttribute = mlirDenseI64ArrayGet;
- static constexpr auto getElement = mlirDenseI64ArrayGetElement;
- static constexpr const char *pyClassName = "DenseI64ArrayAttr";
- static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseF32ArrayAttribute
- : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
- static constexpr auto getAttribute = mlirDenseF32ArrayGet;
- static constexpr auto getElement = mlirDenseF32ArrayGetElement;
- static constexpr const char *pyClassName = "DenseF32ArrayAttr";
- static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseF64ArrayAttribute
- : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
- static constexpr auto getAttribute = mlirDenseF64ArrayGet;
- static constexpr auto getElement = mlirDenseF64ArrayGetElement;
- static constexpr const char *pyClassName = "DenseF64ArrayAttr";
- static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-
-class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
- static constexpr const char *pyClassName = "ArrayAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirArrayAttrGetTypeID;
-
- class PyArrayAttributeIterator {
- public:
- PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
-
- PyArrayAttributeIterator &dunderIter() { return *this; }
-
- nb::typed<nb::object, PyAttribute> dunderNext() {
- // TODO: Throw is an inefficient way to stop iteration.
- if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
- throw nb::stop_iteration();
- return PyAttribute(this->attr.getContext(),
- mlirArrayAttrGetElement(attr.get(), nextIndex++))
- .maybeDownCast();
- }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
- .def("__iter__", &PyArrayAttributeIterator::dunderIter)
- .def("__next__", &PyArrayAttributeIterator::dunderNext);
- }
-
- private:
- PyAttribute attr;
- int nextIndex = 0;
- };
-
- MlirAttribute getItem(intptr_t i) {
- return mlirArrayAttrGetElement(*this, i);
- }
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](const nb::list &attributes, DefaultingPyMlirContext context) {
- SmallVector<MlirAttribute> mlirAttributes;
- mlirAttributes.reserve(nb::len(attributes));
- for (auto attribute : attributes) {
- mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
- }
- MlirAttribute attr = mlirArrayAttrGet(
- context->get(), mlirAttributes.size(), mlirAttributes.data());
- return PyArrayAttribute(context->getRef(), attr);
- },
- nb::arg("attributes"), nb::arg("context") = nb::none(),
- "Gets a uniqued Array attribute");
- c.def(
- "__getitem__",
- [](PyArrayAttribute &arr,
- intptr_t i) -> nb::typed<nb::object, PyAttribute> {
- if (i >= mlirArrayAttrGetNumElements(arr))
- throw nb::index_error("ArrayAttribute index out of range");
- return PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast();
- })
- .def("__len__",
- [](const PyArrayAttribute &arr) {
- return mlirArrayAttrGetNumElements(arr);
- })
- .def("__iter__", [](const PyArrayAttribute &arr) {
- return PyArrayAttributeIterator(arr);
- });
- c.def("__add__", [](PyArrayAttribute arr, const nb::list &extras) {
- std::vector<MlirAttribute> attributes;
- intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
- attributes.reserve(numOldElements + nb::len(extras));
- for (intptr_t i = 0; i < numOldElements; ++i)
- attributes.push_back(arr.getItem(i));
- for (nb::handle attr : extras)
- attributes.push_back(pyTryCast<PyAttribute>(attr));
- MlirAttribute arrayAttr = mlirArrayAttrGet(
- arr.getContext()->get(), attributes.size(), attributes.data());
- return PyArrayAttribute(arr.getContext(), arrayAttr);
- });
- }
-};
-
-/// Float Point Attribute subclass - FloatAttr.
-class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
- static constexpr const char *pyClassName = "FloatAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloatAttrGetTypeID;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyType &type, double value, DefaultingPyLocation loc) {
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
- if (mlirAttributeIsNull(attr))
- throw MLIRError("Invalid attribute", errors.take());
- return PyFloatAttribute(type.getContext(), attr);
- },
- nb::arg("type"), nb::arg("value"), nb::arg("loc") = nb::none(),
- "Gets an uniqued float point attribute associated to a type");
- c.def_static(
- "get_unchecked",
- [](PyType &type, double value, DefaultingPyMlirContext context) {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirAttribute attr =
- mlirFloatAttrDoubleGet(context.get()->get(), type, value);
- if (mlirAttributeIsNull(attr))
- throw MLIRError("Invalid attribute", errors.take());
- return PyFloatAttribute(type.getContext(), attr);
- },
- nb::arg("type"), nb::arg("value"), nb::arg("context") = nb::none(),
- "Gets an uniqued float point attribute associated to a type");
- c.def_static(
- "get_f32",
- [](double value, DefaultingPyMlirContext context) {
- MlirAttribute attr = mlirFloatAttrDoubleGet(
- context->get(), mlirF32TypeGet(context->get()), value);
- return PyFloatAttribute(context->getRef(), attr);
- },
- nb::arg("value"), nb::arg("context") = nb::none(),
- "Gets an uniqued float point attribute associated to a f32 type");
- c.def_static(
- "get_f64",
- [](double value, DefaultingPyMlirContext context) {
- MlirAttribute attr = mlirFloatAttrDoubleGet(
- context->get(), mlirF64TypeGet(context->get()), value);
- return PyFloatAttribute(context->getRef(), attr);
- },
- nb::arg("value"), nb::arg("context") = nb::none(),
- "Gets an uniqued float point attribute associated to a f64 type");
- c.def_prop_ro("value", mlirFloatAttrGetValueDouble,
- "Returns the value of the float attribute");
- c.def("__float__", mlirFloatAttrGetValueDouble,
- "Converts the value of the float attribute to a Python float");
- }
-};
-
-/// Integer Attribute subclass - IntegerAttr.
-class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
- static constexpr const char *pyClassName = "IntegerAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyType &type, int64_t value) {
- MlirAttribute attr = mlirIntegerAttrGet(type, value);
- return PyIntegerAttribute(type.getContext(), attr);
- },
- nb::arg("type"), nb::arg("value"),
- "Gets an uniqued integer attribute associated to a type");
- c.def_prop_ro("value", toPyInt,
- "Returns the value of the integer attribute");
- c.def("__int__", toPyInt,
- "Converts the value of the integer attribute to a Python int");
- c.def_prop_ro_static(
- "static_typeid",
- [](nb::object & /*class*/) {
- return PyTypeID(mlirIntegerAttrGetTypeID());
- },
- nanobind::sig("def static_typeid(/) -> TypeID"));
- }
-
-private:
- static int64_t toPyInt(PyIntegerAttribute &self) {
- MlirType type = mlirAttributeGetType(self);
- if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
- return mlirIntegerAttrGetValueInt(self);
- if (mlirIntegerTypeIsSigned(type))
- return mlirIntegerAttrGetValueSInt(self);
- return mlirIntegerAttrGetValueUInt(self);
- }
-};
-
-/// Bool Attribute subclass - BoolAttr.
-class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
- static constexpr const char *pyClassName = "BoolAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](bool value, DefaultingPyMlirContext context) {
- MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
- return PyBoolAttribute(context->getRef(), attr);
- },
- nb::arg("value"), nb::arg("context") = nb::none(),
- "Gets an uniqued bool attribute");
- c.def_prop_ro("value", mlirBoolAttrGetValue,
- "Returns the value of the bool attribute");
- c.def("__bool__", mlirBoolAttrGetValue,
- "Converts the value of the bool attribute to a Python bool");
- }
-};
-
-class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
- static constexpr const char *pyClassName = "SymbolRefAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- static PySymbolRefAttribute fromList(const std::vector<std::string> &symbols,
- PyMlirContext &context) {
- if (symbols.empty())
- throw std::runtime_error("SymbolRefAttr must be composed of at least "
- "one symbol.");
- MlirStringRef rootSymbol = toMlirStringRef(symbols[0]);
- SmallVector<MlirAttribute, 3> referenceAttrs;
- for (size_t i = 1; i < symbols.size(); ++i) {
- referenceAttrs.push_back(
- mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
- }
- return PySymbolRefAttribute(context.getRef(),
- mlirSymbolRefAttrGet(context.get(), rootSymbol,
- referenceAttrs.size(),
- referenceAttrs.data()));
- }
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](const std::vector<std::string> &symbols,
- DefaultingPyMlirContext context) {
- return PySymbolRefAttribute::fromList(symbols, context.resolve());
- },
- nb::arg("symbols"), nb::arg("context") = nb::none(),
- "Gets a uniqued SymbolRef attribute from a list of symbol names");
- c.def_prop_ro(
- "value",
- [](PySymbolRefAttribute &self) {
- std::vector<std::string> symbols = {
- unwrap(mlirSymbolRefAttrGetRootReference(self)).str()};
- for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self);
- ++i)
- symbols.push_back(
- unwrap(mlirSymbolRefAttrGetRootReference(
- mlirSymbolRefAttrGetNestedReference(self, i)))
- .str());
- return symbols;
- },
- "Returns the value of the SymbolRef attribute as a list[str]");
- }
-};
-
-class PyFlatSymbolRefAttribute
- : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
- static constexpr const char *pyClassName = "FlatSymbolRefAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
+/// Returns element `i` of this ArrayAttr as a raw MlirAttribute.
+/// This wrapper performs no range check of its own; `i` is forwarded
+/// unchanged to mlirArrayAttrGetElement.
+MlirAttribute PyArrayAttribute::getItem(intptr_t i) const {
+  const MlirAttribute element = mlirArrayAttrGetElement(*this, i);
+  return element;
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](const std::string &value, DefaultingPyMlirContext context) {
- MlirAttribute attr =
- mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
- return PyFlatSymbolRefAttribute(context->getRef(), attr);
- },
- nb::arg("value"), nb::arg("context") = nb::none(),
- "Gets a uniqued FlatSymbolRef attribute");
- c.def_prop_ro(
- "value",
- [](PyFlatSymbolRefAttribute &self) {
- MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
- return nb::str(stringRef.data, stringRef.length);
- },
- "Returns the value of the FlatSymbolRef attribute as a string");
- }
-};
+/// Converts an IntegerAttr to an int64_t, choosing the accessor that matches
+/// the signedness of the attribute's type:
+///   - index and signless types use mlirIntegerAttrGetValueInt,
+///   - signed types use mlirIntegerAttrGetValueSInt,
+///   - all remaining (unsigned) types use mlirIntegerAttrGetValueUInt.
+/// NOTE(review): the unsigned accessor's result is converted to the int64_t
+/// return type, so unsigned values above INT64_MAX presumably come back
+/// negative. Confirm whether Python callers need a wider return type.
+int64_t PyIntegerAttribute::toPyInt(PyIntegerAttribute &self) {
+  const MlirType attrType = mlirAttributeGetType(self);
+  // The index check comes first; thanks to short-circuiting, the integer
+  // signedness query is only evaluated for non-index types.
+  const bool usePlainAccessor =
+      mlirTypeIsAIndex(attrType) || mlirIntegerTypeIsSignless(attrType);
+  if (usePlainAccessor)
+    return mlirIntegerAttrGetValueInt(self);
+  if (!mlirIntegerTypeIsSigned(attrType))
+    return mlirIntegerAttrGetValueUInt(self);
+  return mlirIntegerAttrGetValueSInt(self);
+}
-class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
- static constexpr const char *pyClassName = "OpaqueAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirOpaqueAttrGetTypeID;
+/// Builds a SymbolRefAttr in `context` from a list of symbol names.
+///
+/// The first name becomes the root reference. Every following name is
+/// appended, in order, as a nested FlatSymbolRefAttr.
+/// Throws std::runtime_error when `symbols` is empty.
+PySymbolRefAttribute
+PySymbolRefAttribute::fromList(const std::vector<std::string> &symbols,
+                               PyMlirContext &context) {
+  if (symbols.empty())
+    throw std::runtime_error("SymbolRefAttr must be composed of at least "
+                             "one symbol.");
+
+  MlirContext ctx = context.get();
+  SmallVector<MlirAttribute, 3> nestedRefs;
+  for (auto it = symbols.begin() + 1, end = symbols.end(); it != end; ++it)
+    nestedRefs.push_back(mlirFlatSymbolRefAttrGet(ctx, toMlirStringRef(*it)));
+
+  MlirAttribute symbolRef =
+      mlirSymbolRefAttrGet(ctx, toMlirStringRef(symbols.front()),
+                           nestedRefs.size(), nestedRefs.data());
+  return PySymbolRefAttribute(context.getRef(), symbolRef);
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](const std::string &dialectNamespace, const nb_buffer &buffer,
- PyType &type, DefaultingPyMlirContext context) {
- const nb_buffer_info bufferInfo = buffer.request();
- intptr_t bufferSize = bufferInfo.size;
- MlirAttribute attr = mlirOpaqueAttrGet(
- context->get(), toMlirStringRef(dialectNamespace), bufferSize,
- static_cast<char *>(bufferInfo.ptr), type);
- return PyOpaqueAttribute(context->getRef(), attr);
- },
- nb::arg("dialect_namespace"), nb::arg("buffer"), nb::arg("type"),
- nb::arg("context") = nb::none(),
- // clang-format off
- nb::sig("def get(dialect_namespace: str, buffer: typing_extensions.Buffer, type: Type, context: Context | None = None) -> OpaqueAttr"),
- // clang-format on
- "Gets an Opaque attribute.");
- c.def_prop_ro(
- "dialect_namespace",
- [](PyOpaqueAttribute &self) {
- MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self);
- return nb::str(stringRef.data, stringRef.length);
- },
- "Returns the dialect namespace for the Opaque attribute as a string");
- c.def_prop_ro(
- "data",
- [](PyOpaqueAttribute &self) {
- MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
- return nb::bytes(stringRef.data, stringRef.length);
- },
- "Returns the data for the Opaqued attributes as `bytes`");
+/// Builds a DenseElementsAttr from a Python list of attributes.
+///
+/// When `explicitType` is given, it must be a statically shaped ShapedType.
+/// Otherwise a 1-D ranked tensor type of length len(attributes) is inferred,
+/// using the type of the first attribute as the element type. Every
+/// attribute's type must equal the shaped type's element type.
+///
+/// Throws nanobind::value_error when:
+///   - the list is empty,
+///   - the explicit type is not a static ShapedType, or
+///   - an attribute's type does not match the element type.
+///
+/// The result is wrapped in the context that owns the created attribute,
+/// mirroring getSplat, which derives its context from the shaped type.
+/// `contextWrapper` is the caller-supplied (or defaulted) context and may
+/// differ from the context of `explicitType` or of the attributes. Wrapping
+/// the attribute in the wrong context would report the wrong `.context` and
+/// would not keep the owning context alive.
+PyDenseElementsAttribute
+PyDenseElementsAttribute::getFromList(
+    const nanobind::list &attributes, std::optional<PyType> explicitType,
+    [[maybe_unused]] DefaultingPyMlirContext contextWrapper) {
+  const size_t numAttributes = nanobind::len(attributes);
+  if (numAttributes == 0)
+    throw nanobind::value_error("Attributes list must be non-empty.");
+
+  MlirType shapedType;
+  if (explicitType) {
+    if (!mlirTypeIsAShaped(*explicitType) ||
+        !mlirShapedTypeHasStaticShape(*explicitType)) {
+      std::string message;
+      llvm::raw_string_ostream os(message);
+      os << "Expected a static ShapedType for the shaped_type parameter: "
+         << nanobind::cast<std::string>(
+                nanobind::repr(nanobind::cast(*explicitType)));
+      throw nanobind::value_error(message.c_str());
+    }
+    shapedType = *explicitType;
+  } else {
+    // No explicit type: infer a 1-D tensor typed after the first attribute.
+    SmallVector<int64_t> shape = {static_cast<int64_t>(numAttributes)};
+    shapedType = mlirRankedTensorTypeGet(
+        shape.size(), shape.data(),
+        mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
+        mlirAttributeGetNull());
+  }
+
+  // Loop-invariant: every attribute is compared against this element type.
+  const MlirType expectedElementType = mlirShapedTypeGetElementType(shapedType);
+
+  SmallVector<MlirAttribute> mlirAttributes;
+  mlirAttributes.reserve(numAttributes);
+  for (const nanobind::handle &attribute : attributes) {
+    MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
+    MlirType attrType = mlirAttributeGetType(mlirAttribute);
+    mlirAttributes.push_back(mlirAttribute);
+
+    if (!mlirTypeEqual(expectedElementType, attrType)) {
+      std::string message;
+      llvm::raw_string_ostream os(message);
+      os << "All attributes must be of the same type and match "
+         << "the type parameter: expected="
+         << nanobind::cast<std::string>(
+                nanobind::repr(nanobind::cast(shapedType)))
+         << ", but got="
+         << nanobind::cast<std::string>(
+                nanobind::repr(nanobind::cast(attrType)));
+      throw nanobind::value_error(message.c_str());
+    }
+  }
+
+  MlirAttribute elements = mlirDenseElementsAttrGet(
+      shapedType, mlirAttributes.size(), mlirAttributes.data());
+  return PyDenseElementsAttribute(
+      PyMlirContext::forContext(mlirAttributeGetContext(elements)), elements);
+}
+/// Builds a DenseElementsAttr from an object implementing the Python buffer
+/// protocol.
+///
+/// A contiguous view (PyBUF_ND) is requested. PyBUF_FORMAT is added only when
+/// no explicit type is given, because only then is the format string needed
+/// to infer the element type. The view is released on every exit path by the
+/// scope-exit guard.
+///
+/// Throws nanobind::python_error if the buffer cannot be acquired, and
+/// std::invalid_argument if getAttributeFromBuffer returns a null attribute.
+PyDenseElementsAttribute PyDenseElementsAttribute::getFromBuffer(
+ const nb_buffer &array, bool signless,
+ const std::optional<PyType> &explicitType,
+ std::optional<std::vector<int64_t>> explicitShape,
+ DefaultingPyMlirContext contextWrapper) {
+ // Request a contiguous view. In exotic cases, this will cause a copy.
+ int flags = PyBUF_ND;
+ if (!explicitType) {
+ flags |= PyBUF_FORMAT;
+ }
+ Py_buffer view;
+ if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
+ throw nanobind::python_error();
+ }
+ // Release the buffer on all paths, including when the throw below fires.
+ auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
+
+ MlirContext context = contextWrapper->get();
+ MlirAttribute attr = getAttributeFromBuffer(
+ view, signless, explicitType, std::move(explicitShape), context);
+ if (mlirAttributeIsNull(attr)) {
+ throw std::invalid_argument(
+ "DenseElementsAttr could not be constructed from the given buffer. "
+ "This may mean that the Python buffer layout does not match that "
+ "MLIR expected layout and is a bug.");
 }
-};
-
-// TODO: Support construction of string elements.
-class PyDenseElementsAttribute
- : public PyConcreteAttribute<PyDenseElementsAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
- static constexpr const char *pyClassName = "DenseElementsAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- static PyDenseElementsAttribute
- getFromList(const nb::list &attributes, std::optional<PyType> explicitType,
- DefaultingPyMlirContext contextWrapper) {
- const size_t numAttributes = nb::len(attributes);
- if (numAttributes == 0)
- throw nb::value_error("Attributes list must be non-empty.");
-
- MlirType shapedType;
- if (explicitType) {
- if ((!mlirTypeIsAShaped(*explicitType) ||
- !mlirShapedTypeHasStaticShape(*explicitType))) {
-
- std::string message;
- llvm::raw_string_ostream os(message);
- os << "Expected a static ShapedType for the shaped_type parameter: "
- << nb::cast<std::string>(nb::repr(nb::cast(*explicitType)));
- throw nb::value_error(message.c_str());
- }
- shapedType = *explicitType;
- } else {
- SmallVector<int64_t> shape = {static_cast<int64_t>(numAttributes)};
- shapedType = mlirRankedTensorTypeGet(
- shape.size(), shape.data(),
- mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
- mlirAttributeGetNull());
- }
+ // NOTE(review): `attr` is wrapped in `contextWrapper`. If `explicitType`
+ // belongs to a different context, this may not be the attribute's owning
+ // context. Confirm against getAttributeFromBuffer.
+ return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
+}
- SmallVector<MlirAttribute> mlirAttributes;
- mlirAttributes.reserve(numAttributes);
- for (const nb::handle &attribute : attributes) {
- MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
- MlirType attrType = mlirAttributeGetType(mlirAttribute);
- mlirAttributes.push_back(mlirAttribute);
+/// Builds a splat DenseElementsAttr in which every element of the statically
+/// shaped `shapedType` equals `elementAttr`. The result is wrapped in the
+/// context that owns `shapedType`.
+///
+/// Throws nanobind::value_error when:
+///   - `elementAttr` is neither an integer nor a float attribute,
+///   - `shapedType` is not a static ShapedType, or
+///   - the shaped element type differs from the attribute's type.
+PyDenseElementsAttribute
+PyDenseElementsAttribute::getSplat(const PyType &shapedType,
+                                   PyAttribute &elementAttr) {
+  auto contextWrapper =
+      PyMlirContext::forContext(mlirTypeGetContext(shapedType));
+  // Renders a bound object through Python's repr() for error messages.
+  auto pyRepr = [](auto &object) {
+    return nanobind::cast<std::string>(nanobind::repr(nanobind::cast(object)));
+  };
+
+  const bool isSplattable = mlirAttributeIsAInteger(elementAttr) ||
+                            mlirAttributeIsAFloat(elementAttr);
+  if (!isSplattable) {
+    std::string message = "Illegal element type for DenseElementsAttr: ";
+    message += pyRepr(elementAttr);
+    throw nanobind::value_error(message.c_str());
+  }
+
+  const bool isStaticShaped = mlirTypeIsAShaped(shapedType) &&
+                              mlirShapedTypeHasStaticShape(shapedType);
+  if (!isStaticShaped) {
+    std::string message =
+        "Expected a static ShapedType for the shaped_type parameter: ";
+    message += pyRepr(shapedType);
+    throw nanobind::value_error(message.c_str());
+  }
+
+  MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
+  MlirType attrType = mlirAttributeGetType(elementAttr);
+  if (!mlirTypeEqual(shapedElementType, attrType)) {
+    std::string message =
+        "Shaped element type and attribute type must be equal: shaped=";
+    message += pyRepr(shapedType);
+    message += ", element=";
+    message += pyRepr(elementAttr);
+    throw nanobind::value_error(message.c_str());
+  }
+
+  MlirAttribute splat = mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
+  return PyDenseElementsAttribute(contextWrapper->getRef(), splat);
+}
- if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
- std::string message;
- llvm::raw_string_ostream os(message);
- os << "All attributes must be of the same type and match "
- << "the type parameter: expected="
- << nb::cast<std::string>(nb::repr(nb::cast(shapedType)))
- << ", but got="
- << nb::cast<std::string>(nb::repr(nb::cast(attrType)));
- throw nb::value_error(message.c_str());
- }
+/// Exposes the storage of this DenseElementsAttr as a Python buffer.
+///
+/// Supported element types:
+///   - f32 and f64.
+///   - f16, exported as 16-bit words with the "e" format code.
+///   - index, exported as int64.
+///   - 8/16/32/64-bit integers. Signless and signed types map to the signed
+///     C++ type; unsigned types map to the unsigned one.
+///   - i1, which is bit-packed inside MLIR and so is unpacked into a
+///     separate buffer.
+/// Any other element type throws std::invalid_argument.
+std::unique_ptr<nb_buffer_info> PyDenseElementsAttribute::accessBuffer() {
+  MlirType shapedType = mlirAttributeGetType(*this);
+  MlirType elementType = mlirShapedTypeGetElementType(shapedType);
+
+  if (mlirTypeIsAF32(elementType))
+    return bufferInfo<float>(shapedType);
+  if (mlirTypeIsAF64(elementType))
+    return bufferInfo<double>(shapedType);
+  if (mlirTypeIsAF16(elementType))
+    return bufferInfo<uint16_t>(shapedType, "e");
+  if (mlirTypeIsAIndex(elementType)) {
+    // Same as IndexType::kInternalStorageBitWidth
+    return bufferInfo<int64_t>(shapedType);
+  }
+
+  if (mlirTypeIsAInteger(elementType)) {
+    const unsigned width = mlirIntegerTypeGetWidth(elementType);
+    if (width == 1) {
+      // i1 / bool
+      // We can not send the buffer directly back to Python, because the i1
+      // values are bitpacked within MLIR. We call numpy's unpackbits function
+      // to convert the bytes.
+      return getBooleanBufferFromBitpackedAttribute();
+    }
+    // Signless integers are exposed through the signed C++ type.
+    const bool asSigned = mlirIntegerTypeIsSignless(elementType) ||
+                          mlirIntegerTypeIsSigned(elementType);
+    if (asSigned || mlirIntegerTypeIsUnsigned(elementType)) {
+      switch (width) {
+      case 8:
+        return asSigned ? bufferInfo<int8_t>(shapedType)
+                        : bufferInfo<uint8_t>(shapedType);
+      case 16:
+        return asSigned ? bufferInfo<int16_t>(shapedType)
+                        : bufferInfo<uint16_t>(shapedType);
+      case 32:
+        return asSigned ? bufferInfo<int32_t>(shapedType)
+                        : bufferInfo<uint32_t>(shapedType);
+      case 64:
+        return asSigned ? bufferInfo<int64_t>(shapedType)
+                        : bufferInfo<uint64_t>(shapedType);
+      default:
+        break;
+      }
+    }
+  }
+
+  // TODO: Currently crashes the program.
+  // Reported as https://github.com/pybind/pybind11/issues/3336
+  throw std::invalid_argument(
+      "unsupported data type for conversion to Python buffer");
+}
- MlirAttribute elements = mlirDenseElementsAttrGet(
- shapedType, mlirAttributes.size(), mlirAttributes.data());
+// Py_IsFinalizing() is public API, and part of the stable ABI, only since
+// Python 3.13 (PY_VERSION_HEX 0x030d0000). Older interpreters expose the same
+// check as the private _Py_IsFinalizing(). On those versions, alias the new
+// name to the old one so this file can use Py_IsFinalizing unconditionally.
+#if PY_VERSION_HEX < 0x030d0000
+#define Py_IsFinalizing _Py_IsFinalizing
+#endif
- return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
- }
+/// Returns true when the buffer-protocol format string `format` starts with
+/// one of the unsigned integer codes B, H, I, L or Q. An empty format yields
+/// false.
+/// NOTE(review): only the first character is inspected. A format beginning
+/// with a byte-order/size prefix such as '<' or '=' is therefore not
+/// recognized. Confirm that callers never receive such prefixes.
+bool PyDenseElementsAttribute::isUnsignedIntegerFormat(
+    std::string_view format) {
+  if (!format.empty()) {
+    switch (format.front()) {
+    case 'B':
+    case 'H':
+    case 'I':
+    case 'L':
+    case 'Q':
+      return true;
+    default:
+      break;
+    }
+  }
+  return false;
+}
- static PyDenseElementsAttribute
- getFromBuffer(const nb_buffer &array, bool signless,
- const std::optional<PyType> &explicitType,
- std::optional<std::vector<int64_t>> explicitShape,
- DefaultingPyMlirContext contextWrapper) {
- // Request a contiguous view. In exotic cases, this will cause a copy.
- int flags = PyBUF_ND;
- if (!explicitType) {
- flags |= PyBUF_FORMAT;
- }
- Py_buffer view;
- if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
- throw nb::python_error();
- }
- auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
+/// Returns true when the buffer-protocol format string `format` starts with
+/// one of the signed integer codes b, h, i, l or q. An empty format yields
+/// false.
+/// NOTE(review): like isUnsignedIntegerFormat, only the first character is
+/// inspected, so byte-order/size prefixes are not skipped. Confirm this is
+/// intended.
+bool PyDenseElementsAttribute::isSignedIntegerFormat(std::string_view format) {
+  if (!format.empty()) {
+    switch (format.front()) {
+    case 'b':
+    case 'h':
+    case 'i':
+    case 'l':
+    case 'q':
+      return true;
+    default:
+      break;
+    }
+  }
+  return false;
+}
- MlirContext context = contextWrapper->get();
- MlirAttribute attr = getAttributeFromBuffer(
- view, signless, explicitType, std::move(explicitShape), context);
- if (mlirAttributeIsNull(attr)) {
- throw std::invalid_argument(
- "DenseElementsAttr could not be constructed from the given buffer. "
- "This may mean that the Python buffer layout does not match that "
- "MLIR expected layout and is a bug.");
- }
- return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
+/// Computes the type of the attribute being built from a buffer.
+///
+/// If `*bulkLoadElementType` is itself a ShapedType, it is returned as-is,
+/// and passing `explicitShape` as well throws std::invalid_argument.
+/// Otherwise a ranked tensor type with no encoding is built. Its element type
+/// is `*bulkLoadElementType`, and its shape is `explicitShape` when given or
+/// else the shape reported by `view`.
+/// NOTE(review): `bulkLoadElementType` is dereferenced without a check, so
+/// callers presumably always pass an engaged optional. Confirm this.
+MlirType PyDenseElementsAttribute::getShapedType(
+ std::optional<MlirType> bulkLoadElementType,
+ std::optional<std::vector<int64_t>> explicitShape, Py_buffer &view) {
+ SmallVector<int64_t> shape;
+ if (explicitShape) {
+ shape.append(explicitShape->begin(), explicitShape->end());
+ } else {
+ shape.append(view.shape, view.shape + view.ndim);
 }
- static PyDenseElementsAttribute getSplat(const PyType &shapedType,
- PyAttribute &elementAttr) {
- auto contextWrapper =
- PyMlirContext::forContext(mlirTypeGetContext(shapedType));
- if (!mlirAttributeIsAInteger(elementAttr) &&
- !mlirAttributeIsAFloat(elementAttr)) {
- std::string message = "Illegal element type for DenseElementsAttr: ";
- message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
- throw nb::value_error(message.c_str());
- }
- if (!mlirTypeIsAShaped(shapedType) ||
- !mlirShapedTypeHasStaticShape(shapedType)) {
- std::string message =
- "Expected a static ShapedType for the shaped_type parameter: ";
- message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
- throw nb::value_error(message.c_str());
- }
- MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
- MlirType attrType = mlirAttributeGetType(elementAttr);
- if (!mlirTypeEqual(shapedElementType, attrType)) {
- std::string message =
- "Shaped element type and attribute type must be equal: shaped=";
- message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
- message.append(", element=");
- message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
- throw nb::value_error(message.c_str());
+ // A shaped explicit type already fixes the shape, so an explicit shape
+ // alongside it is rejected.
+ if (mlirTypeIsAShaped(*bulkLoadElementType)) {
+ if (explicitShape) {
+ throw std::invalid_argument("Shape can only be specified explicitly "
+ "when the type is not a shaped type.");
 }
-
- MlirAttribute elements =
- mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
- return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
+ return *bulkLoadElementType;
 }
+ // Scalar element type: wrap it in a ranked tensor of the computed shape.
+ MlirAttribute encodingAttr = mlirAttributeGetNull();
+ return mlirRankedTensorTypeGet(shape.size(), shape.data(),
+ *bulkLoadElementType, encodingAttr);
+}
- intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
-
- std::unique_ptr<nb_buffer_info> accessBuffer() {
- MlirType shapedType = mlirAttributeGetType(*this);
- MlirType elementType = mlirShapedTypeGetElementType(shapedType);
- std::string format;
-
- if (mlirTypeIsAF32(elementType)) {
+MlirAttribute PyDenseElementsAttribute::getAttributeFromBuffer(
+ Py_buffer &view, bool signless, std::optional<PyType> explicitType,
+ const std::optional<std::vector<int64_t>> &explicitShape,
+ MlirContext &context) {
+ // Detect format codes that are suitable for bulk loading. This includes
+ // all byte aligned integer and floating point types up to 8 bytes.
+ // Notably, this excludes exotics types which do not have a direct
+ // representation in the buffer protocol (i.e. complex, etc).
+ std::optional<MlirType> bulkLoadElementType;
+ if (explicitType) {
+ bulkLoadElementType = *explicitType;
+ } else {
+ std::string_view format(view.format);
+ if (format == "f") {
// f32
- return bufferInfo<float>(shapedType);
- }
- if (mlirTypeIsAF64(elementType)) {
+ assert(view.itemsize == 4 && "mismatched array itemsize");
+ bulkLoadElementType = mlirF32TypeGet(context);
+ } else if (format == "d") {
// f64
- return bufferInfo<double>(shapedType);
- }
- if (mlirTypeIsAF16(elementType)) {
+ assert(view.itemsize == 8 && "mismatched array itemsize");
+ bulkLoadElementType = mlirF64TypeGet(context);
+ } else if (format == "e") {
// f16
- return bufferInfo<uint16_t>(shapedType, "e");
- }
- if (mlirTypeIsAIndex(elementType)) {
- // Same as IndexType::kInternalStorageBitWidth
- return bufferInfo<int64_t>(shapedType);
- }
- if (mlirTypeIsAInteger(elementType) &&
- mlirIntegerTypeGetWidth(elementType) == 32) {
- if (mlirIntegerTypeIsSignless(elementType) ||
- mlirIntegerTypeIsSigned(elementType)) {
+ assert(view.itemsize == 2 && "mismatched array itemsize");
+ bulkLoadElementType = mlirF16TypeGet(context);
+ } else if (format == "?") {
+ // i1
+ // The i1 type needs to be bit-packed, so we will handle it separately
+ return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
+ context);
+ } else if (isSignedIntegerFormat(format)) {
+ if (view.itemsize == 4) {
// i32
- return bufferInfo<int32_t>(shapedType);
- }
- if (mlirIntegerTypeIsUnsigned(elementType)) {
- // unsigned i32
- return bufferInfo<uint32_t>(shapedType);
- }
- } else if (mlirTypeIsAInteger(elementType) &&
- mlirIntegerTypeGetWidth(elementType) == 64) {
- if (mlirIntegerTypeIsSignless(elementType) ||
- mlirIntegerTypeIsSigned(elementType)) {
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
+ : mlirIntegerTypeSignedGet(context, 32);
+ } else if (view.itemsize == 8) {
// i64
- return bufferInfo<int64_t>(shapedType);
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
+ : mlirIntegerTypeSignedGet(context, 64);
+ } else if (view.itemsize == 1) {
+ // i8
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
+ : mlirIntegerTypeSignedGet(context, 8);
+ } else if (view.itemsize == 2) {
+ // i16
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
+ : mlirIntegerTypeSignedGet(context, 16);
}
- if (mlirIntegerTypeIsUnsigned(elementType)) {
+ } else if (isUnsignedIntegerFormat(format)) {
+ if (view.itemsize == 4) {
+ // unsigned i32
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 32)
+ : mlirIntegerTypeUnsignedGet(context, 32);
+ } else if (view.itemsize == 8) {
// unsigned i64
- return bufferInfo<uint64_t>(shapedType);
- }
- } else if (mlirTypeIsAInteger(elementType) &&
- mlirIntegerTypeGetWidth(elementType) == 8) {
- if (mlirIntegerTypeIsSignless(elementType) ||
- mlirIntegerTypeIsSigned(elementType)) {
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 64)
+ : mlirIntegerTypeUnsignedGet(context, 64);
+ } else if (view.itemsize == 1) {
// i8
- return bufferInfo<int8_t>(shapedType);
- }
- if (mlirIntegerTypeIsUnsigned(elementType)) {
- // unsigned i8
- return bufferInfo<uint8_t>(shapedType);
- }
- } else if (mlirTypeIsAInteger(elementType) &&
- mlirIntegerTypeGetWidth(elementType) == 16) {
- if (mlirIntegerTypeIsSignless(elementType) ||
- mlirIntegerTypeIsSigned(elementType)) {
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
+ : mlirIntegerTypeUnsignedGet(context, 8);
+ } else if (view.itemsize == 2) {
// i16
- return bufferInfo<int16_t>(shapedType);
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 16)
+ : mlirIntegerTypeUnsignedGet(context, 16);
}
- if (mlirIntegerTypeIsUnsigned(elementType)) {
- // unsigned i16
- return bufferInfo<uint16_t>(shapedType);
- }
- } else if (mlirTypeIsAInteger(elementType) &&
- mlirIntegerTypeGetWidth(elementType) == 1) {
- // i1 / bool
- // We can not send the buffer directly back to Python, because the i1
- // values are bitpacked within MLIR. We call numpy's unpackbits function
- // to convert the bytes.
- return getBooleanBufferFromBitpackedAttribute();
}
-
- // TODO: Currently crashes the program.
- // Reported as https://github.com/pybind/pybind11/issues/3336
- throw std::invalid_argument(
- "unsupported data type for conversion to Python buffer");
+ if (!bulkLoadElementType) {
+ throw std::invalid_argument(
+ std::string("unimplemented array format conversion from format: ") +
+ std::string(format));
+ }
}
- static void bindDerived(ClassTy &c) {
-#if PY_VERSION_HEX < 0x03090000
- PyTypeObject *tp = reinterpret_cast<PyTypeObject *>(c.ptr());
- tp->tp_as_buffer->bf_getbuffer = PyDenseElementsAttribute::bf_getbuffer;
- tp->tp_as_buffer->bf_releasebuffer =
- PyDenseElementsAttribute::bf_releasebuffer;
-#endif
- c.def("__len__", &PyDenseElementsAttribute::dunderLen)
- .def_static(
- "get", PyDenseElementsAttribute::getFromBuffer, nb::arg("array"),
- nb::arg("signless") = true, nb::arg("type") = nb::none(),
- nb::arg("shape") = nb::none(), nb::arg("context") = nb::none(),
- // clang-format off
- nb::sig("def get(array: typing_extensions.Buffer, signless: bool = True, type: Type | None = None, shape: Sequence[int] | None = None, context: Context | None = None) -> DenseElementsAttr"),
- // clang-format on
- kDenseElementsAttrGetDocstring)
- .def_static("get", PyDenseElementsAttribute::getFromList,
- nb::arg("attrs"), nb::arg("type") = nb::none(),
- nb::arg("context") = nb::none(),
- kDenseElementsAttrGetFromListDocstring)
- .def_static("get_splat", PyDenseElementsAttribute::getSplat,
- nb::arg("shaped_type"), nb::arg("element_attr"),
- "Gets a DenseElementsAttr where all values are the same")
- .def_prop_ro("is_splat",
- [](PyDenseElementsAttribute &self) -> bool {
- return mlirDenseElementsAttrIsSplat(self);
- })
- .def("get_splat_value",
- [](PyDenseElementsAttribute &self)
- -> nb::typed<nb::object, PyAttribute> {
- if (!mlirDenseElementsAttrIsSplat(self))
- throw nb::value_error(
- "get_splat_value called on a non-splat attribute");
- return PyAttribute(self.getContext(),
- mlirDenseElementsAttrGetSplatValue(self))
- .maybeDownCast();
- });
- }
+ MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
+ return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
+}
- static PyType_Slot slots[];
+MlirAttribute PyDenseElementsAttribute::getBitpackedAttributeFromBooleanBuffer(
+ Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
+ MlirContext &context) {
+ if (llvm::endianness::native != llvm::endianness::little) {
+ // Given we have no good way of testing the behavior on big-endian
+ // systems we will throw
+ throw nanobind::type_error("Constructing a bit-packed MLIR attribute is "
+ "unsupported on big-endian systems");
+ }
+ nanobind::ndarray<uint8_t, nanobind::numpy, nanobind::ndim<1>,
+ nanobind::c_contig>
+ unpackedArray(
+ /*data=*/static_cast<uint8_t *>(view.buf),
+ /*shape=*/{static_cast<size_t>(view.len)});
+
+ nanobind::module_ numpy = nanobind::module_::import_("numpy");
+ nanobind::object packbitsFunc = numpy.attr("packbits");
+ nanobind::object packedBooleans =
+ packbitsFunc(nanobind::cast(unpackedArray), "bitorder"_a = "little");
+ nb_buffer_info pythonBuffer =
+ nanobind::cast<nb_buffer>(packedBooleans).request();
+
+ MlirType bitpackedType = getShapedType(mlirIntegerTypeGet(context, 1),
+ std::move(explicitShape), view);
+ assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8");
+ // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of
+ // packedBooleans, hence the MlirAttribute will remain valid even when
+ // packedBooleans get reclaimed by the end of the function.
+ return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
+ pythonBuffer.ptr);
+}
-private:
- static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags);
- static void bf_releasebuffer(PyObject *, Py_buffer *buffer);
+std::unique_ptr<nb_buffer_info>
+PyDenseElementsAttribute::getBooleanBufferFromBitpackedAttribute() {
+ if (llvm::endianness::native != llvm::endianness::little) {
+ // Given we have no good way of testing the behavior on big-endian
+ // systems we will throw
+ throw nanobind::type_error(
+ "Constructing a numpy array from a MLIR attribute "
+ "is unsupported on big-endian systems");
+ }
+
+ int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
+ int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8);
+ uint8_t *bitpackedData = static_cast<uint8_t *>(
+ const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
+ nanobind::ndarray<uint8_t, nanobind::numpy, nanobind::ndim<1>,
+ nanobind::c_contig>
+ packedArray(
+ /*data=*/bitpackedData,
+ /*shape=*/{static_cast<size_t>(numBitpackedBytes)});
+
+ nanobind::module_ numpy = nanobind::module_::import_("numpy");
+ nanobind::object unpackbitsFunc = numpy.attr("unpackbits");
+ nanobind::object equalFunc = numpy.attr("equal");
+ nanobind::object reshapeFunc = numpy.attr("reshape");
+ nanobind::object unpackedBooleans =
+ unpackbitsFunc(nanobind::cast(packedArray), "bitorder"_a = "little");
+
+ // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array.
+ // We need to:
+ // 1. Slice away the padded bits
+ // 2. Make the boolean array have the correct shape
+ // 3. Convert the array to a boolean array
+ unpackedBooleans = unpackedBooleans[nanobind::slice(
+ nanobind::int_(0), nanobind::int_(numBooleans), nanobind::int_(1))];
+ unpackedBooleans = equalFunc(unpackedBooleans, 1);
+
+ MlirType shapedType = mlirAttributeGetType(*this);
+ intptr_t rank = mlirShapedTypeGetRank(shapedType);
+ std::vector<intptr_t> shape(rank);
+ for (intptr_t i = 0; i < rank; ++i) {
+ shape[i] = mlirShapedTypeGetDimSize(shapedType, i);
+ }
+ unpackedBooleans = reshapeFunc(unpackedBooleans, shape);
+
+ // Make sure the returned nanobind::buffer_view claims ownership of the data
+ // in `pythonBuffer` so it remains valid when Python reads it
+ nb_buffer pythonBuffer = nanobind::cast<nb_buffer>(unpackedBooleans);
+ return std::make_unique<nb_buffer_info>(pythonBuffer.request());
+}
- static bool isUnsignedIntegerFormat(std::string_view format) {
- if (format.empty())
- return false;
- char code = format[0];
- return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
- code == 'Q';
+nanobind::int_ PyDenseIntElementsAttribute::dunderGetItem(intptr_t pos) {
+ if (pos < 0 || pos >= dunderLen()) {
+ throw nanobind::index_error("attempt to access out of bounds element");
}
- static bool isSignedIntegerFormat(std::string_view format) {
- if (format.empty())
- return false;
- char code = format[0];
- return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
- code == 'q';
+ MlirType type = mlirAttributeGetType(*this);
+ type = mlirShapedTypeGetElementType(type);
+ // Index type can also appear as a DenseIntElementsAttr and therefore can be
+ // casted to integer.
+ assert(mlirTypeIsAInteger(type) ||
+ mlirTypeIsAIndex(type) && "expected integer/index element type in "
+ "dense int elements attribute");
+ // Dispatch element extraction to an appropriate C function based on the
+ // elemental type of the attribute. nanobind::int_ is implicitly
+ // constructible from any C++ integral type and handles bitwidth correctly.
+ // TODO: consider caching the type properties in the constructor to avoid
+ // querying them on each element access.
+ if (mlirTypeIsAIndex(type)) {
+ return nanobind::int_(mlirDenseElementsAttrGetIndexValue(*this, pos));
}
-
- static MlirType
- getShapedType(std::optional<MlirType> bulkLoadElementType,
- std::optional<std::vector<int64_t>> explicitShape,
- Py_buffer &view) {
- SmallVector<int64_t> shape;
- if (explicitShape) {
- shape.append(explicitShape->begin(), explicitShape->end());
- } else {
- shape.append(view.shape, view.shape + view.ndim);
+ unsigned width = mlirIntegerTypeGetWidth(type);
+ bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
+ if (isUnsigned) {
+ if (width == 1) {
+ return nanobind::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
}
-
- if (mlirTypeIsAShaped(*bulkLoadElementType)) {
- if (explicitShape) {
- throw std::invalid_argument("Shape can only be specified explicitly "
- "when the type is not a shaped type.");
- }
- return *bulkLoadElementType;
+ if (width == 8) {
+ return nanobind::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos));
}
- MlirAttribute encodingAttr = mlirAttributeGetNull();
- return mlirRankedTensorTypeGet(shape.size(), shape.data(),
- *bulkLoadElementType, encodingAttr);
- }
-
- static MlirAttribute getAttributeFromBuffer(
- Py_buffer &view, bool signless, std::optional<PyType> explicitType,
- const std::optional<std::vector<int64_t>> &explicitShape,
- MlirContext &context) {
- // Detect format codes that are suitable for bulk loading. This includes
- // all byte aligned integer and floating point types up to 8 bytes.
- // Notably, this excludes exotics types which do not have a direct
- // representation in the buffer protocol (i.e. complex, etc).
- std::optional<MlirType> bulkLoadElementType;
- if (explicitType) {
- bulkLoadElementType = *explicitType;
- } else {
- std::string_view format(view.format);
- if (format == "f") {
- // f32
- assert(view.itemsize == 4 && "mismatched array itemsize");
- bulkLoadElementType = mlirF32TypeGet(context);
- } else if (format == "d") {
- // f64
- assert(view.itemsize == 8 && "mismatched array itemsize");
- bulkLoadElementType = mlirF64TypeGet(context);
- } else if (format == "e") {
- // f16
- assert(view.itemsize == 2 && "mismatched array itemsize");
- bulkLoadElementType = mlirF16TypeGet(context);
- } else if (format == "?") {
- // i1
- // The i1 type needs to be bit-packed, so we will handle it separately
- return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
- context);
- } else if (isSignedIntegerFormat(format)) {
- if (view.itemsize == 4) {
- // i32
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 32)
- : mlirIntegerTypeSignedGet(context, 32);
- } else if (view.itemsize == 8) {
- // i64
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 64)
- : mlirIntegerTypeSignedGet(context, 64);
- } else if (view.itemsize == 1) {
- // i8
- bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
- : mlirIntegerTypeSignedGet(context, 8);
- } else if (view.itemsize == 2) {
- // i16
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 16)
- : mlirIntegerTypeSignedGet(context, 16);
- }
- } else if (isUnsignedIntegerFormat(format)) {
- if (view.itemsize == 4) {
- // unsigned i32
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 32)
- : mlirIntegerTypeUnsignedGet(context, 32);
- } else if (view.itemsize == 8) {
- // unsigned i64
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 64)
- : mlirIntegerTypeUnsignedGet(context, 64);
- } else if (view.itemsize == 1) {
- // i8
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 8)
- : mlirIntegerTypeUnsignedGet(context, 8);
- } else if (view.itemsize == 2) {
- // i16
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 16)
- : mlirIntegerTypeUnsignedGet(context, 16);
- }
- }
- if (!bulkLoadElementType) {
- throw std::invalid_argument(
- std::string("unimplemented array format conversion from format: ") +
- std::string(format));
- }
+ if (width == 16) {
+ return nanobind::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos));
}
-
- MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
- return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
- }
-
- // There is a complication for boolean numpy arrays, as numpy represents
- // them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8
- // booleans per byte.
- static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
- Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
- MlirContext &context) {
- if (llvm::endianness::native != llvm::endianness::little) {
- // Given we have no good way of testing the behavior on big-endian
- // systems we will throw
- throw nb::type_error("Constructing a bit-packed MLIR attribute is "
- "unsupported on big-endian systems");
+ if (width == 32) {
+ return nanobind::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos));
}
- nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> unpackedArray(
- /*data=*/static_cast<uint8_t *>(view.buf),
- /*shape=*/{static_cast<size_t>(view.len)});
-
- nb::module_ numpy = nb::module_::import_("numpy");
- nb::object packbitsFunc = numpy.attr("packbits");
- nb::object packedBooleans =
- packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little");
- nb_buffer_info pythonBuffer = nb::cast<nb_buffer>(packedBooleans).request();
-
- MlirType bitpackedType = getShapedType(mlirIntegerTypeGet(context, 1),
- std::move(explicitShape), view);
- assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8");
- // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of
- // packedBooleans, hence the MlirAttribute will remain valid even when
- // packedBooleans get reclaimed by the end of the function.
- return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
- pythonBuffer.ptr);
- }
-
- // This does the opposite transformation of
- // `getBitpackedAttributeFromBooleanBuffer`
- std::unique_ptr<nb_buffer_info> getBooleanBufferFromBitpackedAttribute() {
- if (llvm::endianness::native != llvm::endianness::little) {
- // Given we have no good way of testing the behavior on big-endian
- // systems we will throw
- throw nb::type_error("Constructing a numpy array from a MLIR attribute "
- "is unsupported on big-endian systems");
+ if (width == 64) {
+ return nanobind::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos));
}
+ } else {
+ if (width == 1) {
+ return nanobind::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
+ }
+ if (width == 8) {
+ return nanobind::int_(mlirDenseElementsAttrGetInt8Value(*this, pos));
+ }
+ if (width == 16) {
+ return nanobind::int_(mlirDenseElementsAttrGetInt16Value(*this, pos));
+ }
+ if (width == 32) {
+ return nanobind::int_(mlirDenseElementsAttrGetInt32Value(*this, pos));
+ }
+ if (width == 64) {
+ return nanobind::int_(mlirDenseElementsAttrGetInt64Value(*this, pos));
+ }
+ }
+ throw nanobind::type_error("Unsupported integer type");
+}
- int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
- int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8);
- uint8_t *bitpackedData = static_cast<uint8_t *>(
- const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
- nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> packedArray(
- /*data=*/bitpackedData,
- /*shape=*/{static_cast<size_t>(numBitpackedBytes)});
-
- nb::module_ numpy = nb::module_::import_("numpy");
- nb::object unpackbitsFunc = numpy.attr("unpackbits");
- nb::object equalFunc = numpy.attr("equal");
- nb::object reshapeFunc = numpy.attr("reshape");
- nb::object unpackedBooleans =
- unpackbitsFunc(nb::cast(packedArray), "bitorder"_a = "little");
+PyDenseResourceElementsAttribute
+PyDenseResourceElementsAttribute::getFromBuffer(
+ const nb_buffer &buffer, const std::string &name, const PyType &type,
+ std::optional<size_t> alignment, bool isMutable,
+ DefaultingPyMlirContext contextWrapper) {
+ if (!mlirTypeIsAShaped(type)) {
+ throw std::invalid_argument(
+ "Constructing a DenseResourceElementsAttr requires a ShapedType.");
+ }
+
+ // Do not request any conversions as we must ensure to use caller
+ // managed memory.
+ int flags = PyBUF_STRIDES;
+ std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>();
+ if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
+ throw nanobind::python_error();
+ }
+
+ // This scope releaser will only release if we haven't yet transferred
+ // ownership.
+ auto freeBuffer = llvm::make_scope_exit([&]() {
+ if (view)
+ PyBuffer_Release(view.get());
+ });
+
+ if (!PyBuffer_IsContiguous(view.get(), 'A')) {
+ throw std::invalid_argument("Contiguous buffer is required.");
+ }
+
+ // Infer alignment to be the stride of one element if not explicit.
+ size_t inferredAlignment;
+ if (alignment)
+ inferredAlignment = *alignment;
+ else
+ inferredAlignment = view->strides[view->ndim - 1];
+
+ // The userData is a Py_buffer* that the deleter owns.
+ auto deleter = [](void *userData, const void *data, size_t size,
+ size_t align) {
+ if (Py_IsFinalizing())
+ return;
+ assert(Py_IsInitialized() && "expected interpreter to be initialized");
+ Py_buffer *ownedView = static_cast<Py_buffer *>(userData);
+ nanobind::gil_scoped_acquire gil;
+ PyBuffer_Release(ownedView);
+ delete ownedView;
+ };
- // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array.
- // We need to:
- // 1. Slice away the padded bits
- // 2. Make the boolean array have the correct shape
- // 3. Convert the array to a boolean array
- unpackedBooleans = unpackedBooleans[nb::slice(
- nb::int_(0), nb::int_(numBooleans), nb::int_(1))];
- unpackedBooleans = equalFunc(unpackedBooleans, 1);
+ size_t rawBufferSize = view->len;
+ MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
+ type, toMlirStringRef(name), view->buf, rawBufferSize, inferredAlignment,
+ isMutable, deleter, static_cast<void *>(view.get()));
+ if (mlirAttributeIsNull(attr)) {
+ throw std::invalid_argument(
+ "DenseResourceElementsAttr could not be constructed from the given "
+ "buffer. "
+ "This may mean that the Python buffer layout does not match that "
+ "MLIR expected layout and is a bug.");
+ }
+ view.release();
+ return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
+}
- MlirType shapedType = mlirAttributeGetType(*this);
- intptr_t rank = mlirShapedTypeGetRank(shapedType);
- std::vector<intptr_t> shape(rank);
- for (intptr_t i = 0; i < rank; ++i) {
- shape[i] = mlirShapedTypeGetDimSize(shapedType, i);
- }
- unpackedBooleans = reshapeFunc(unpackedBooleans, shape);
+bool PyDictAttribute::dunderContains(const std::string &name) const {
+ return !mlirAttributeIsNull(
+ mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
+}
- // Make sure the returned nb::buffer_view claims ownership of the data in
- // `pythonBuffer` so it remains valid when Python reads it
- nb_buffer pythonBuffer = nb::cast<nb_buffer>(unpackedBooleans);
- return std::make_unique<nb_buffer_info>(pythonBuffer.request());
+nanobind::float_ PyDenseFPElementsAttribute::dunderGetItem(intptr_t pos) {
+ if (pos < 0 || pos >= dunderLen()) {
+ throw nanobind::index_error("attempt to access out of bounds element");
}
- template <typename Type>
- std::unique_ptr<nb_buffer_info>
- bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) {
- intptr_t rank = mlirShapedTypeGetRank(shapedType);
- // Prepare the data for the buffer_info.
- // Buffer is configured for read-only access below.
- Type *data = static_cast<Type *>(
- const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
- // Prepare the shape for the buffer_info.
- SmallVector<intptr_t, 4> shape;
- for (intptr_t i = 0; i < rank; ++i)
- shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
- // Prepare the strides for the buffer_info.
- SmallVector<intptr_t, 4> strides;
- if (mlirDenseElementsAttrIsSplat(*this)) {
- // Splats are special, only the single value is stored.
- strides.assign(rank, 0);
- } else {
- for (intptr_t i = 1; i < rank; ++i) {
- intptr_t strideFactor = 1;
- for (intptr_t j = i; j < rank; ++j)
- strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
- strides.push_back(sizeof(Type) * strideFactor);
- }
- strides.push_back(sizeof(Type));
- }
- const char *format;
- if (explicitFormat) {
- format = explicitFormat;
- } else {
- format = nb_format_descriptor<Type>::format();
- }
- return std::make_unique<nb_buffer_info>(
- data, sizeof(Type), format, rank, std::move(shape), std::move(strides),
- /*readonly=*/true);
+ MlirType type = mlirAttributeGetType(*this);
+ type = mlirShapedTypeGetElementType(type);
+ // Dispatch element extraction to an appropriate C function based on the
+ // elemental type of the attribute. nanobind::float_ is implicitly
+ // constructible from float and double.
+ // TODO: consider caching the type properties in the constructor to avoid
+ // querying them on each element access.
+ if (mlirTypeIsAF32(type)) {
+ return nanobind::float_(mlirDenseElementsAttrGetFloatValue(*this, pos));
}
-}; // namespace
-
-PyType_Slot PyDenseElementsAttribute::slots[] = {
-// Python 3.8 doesn't allow setting the buffer protocol slots from a type spec.
-#if PY_VERSION_HEX >= 0x03090000
- {Py_bf_getbuffer,
- reinterpret_cast<void *>(PyDenseElementsAttribute::bf_getbuffer)},
- {Py_bf_releasebuffer,
- reinterpret_cast<void *>(PyDenseElementsAttribute::bf_releasebuffer)},
-#endif
- {0, nullptr},
-};
+ if (mlirTypeIsAF64(type)) {
+ return nanobind::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos));
+ }
+ throw nanobind::type_error("Unsupported floating-point type");
+}
/*static*/ int PyDenseElementsAttribute::bf_getbuffer(PyObject *obj,
Py_buffer *view,
@@ -1294,15 +685,17 @@ PyType_Slot PyDenseElementsAttribute::slots[] = {
view->obj = nullptr;
std::unique_ptr<nb_buffer_info> info;
try {
- auto *attr = nb::cast<PyDenseElementsAttribute *>(nb::handle(obj));
+ auto *attr =
+ nanobind::cast<PyDenseElementsAttribute *>(nanobind::handle(obj));
info = attr->accessBuffer();
- } catch (nb::python_error &e) {
+ } catch (nanobind::python_error &e) {
e.restore();
- nb::chain_error(PyExc_BufferError, "Error converting attribute to buffer");
+ nanobind::chain_error(PyExc_BufferError,
+ "Error converting attribute to buffer");
return -1;
} catch (std::exception &e) {
- nb::chain_error(PyExc_BufferError,
- "Error converting attribute to buffer: %s", e.what());
+ nanobind::chain_error(PyExc_BufferError,
+ "Error converting attribute to buffer: %s", e.what());
return -1;
}
view->obj = obj;
@@ -1333,523 +726,64 @@ PyType_Slot PyDenseElementsAttribute::slots[] = {
delete reinterpret_cast<nb_buffer_info *>(view->internal);
}
-/// Refinement of the PyDenseElementsAttribute for attributes containing
-/// integer (and boolean) values. Supports element access.
-class PyDenseIntElementsAttribute
- : public PyConcreteAttribute<PyDenseIntElementsAttribute,
- PyDenseElementsAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
- static constexpr const char *pyClassName = "DenseIntElementsAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- /// Returns the element at the given linear position. Asserts if the index
- /// is out of range.
- nb::int_ dunderGetItem(intptr_t pos) {
- if (pos < 0 || pos >= dunderLen()) {
- throw nb::index_error("attempt to access out of bounds element");
- }
-
- MlirType type = mlirAttributeGetType(*this);
- type = mlirShapedTypeGetElementType(type);
- // Index type can also appear as a DenseIntElementsAttr and therefore can be
- // casted to integer.
- assert(mlirTypeIsAInteger(type) ||
- mlirTypeIsAIndex(type) && "expected integer/index element type in "
- "dense int elements attribute");
- // Dispatch element extraction to an appropriate C function based on the
- // elemental type of the attribute. nb::int_ is implicitly constructible
- // from any C++ integral type and handles bitwidth correctly.
- // TODO: consider caching the type properties in the constructor to avoid
- // querying them on each element access.
- if (mlirTypeIsAIndex(type)) {
- return nb::int_(mlirDenseElementsAttrGetIndexValue(*this, pos));
- }
- unsigned width = mlirIntegerTypeGetWidth(type);
- bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
- if (isUnsigned) {
- if (width == 1) {
- return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
- }
- if (width == 8) {
- return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos));
- }
- if (width == 16) {
- return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos));
- }
- if (width == 32) {
- return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos));
- }
- if (width == 64) {
- return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos));
- }
- } else {
- if (width == 1) {
- return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
- }
- if (width == 8) {
- return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos));
- }
- if (width == 16) {
- return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos));
- }
- if (width == 32) {
- return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos));
- }
- if (width == 64) {
- return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos));
- }
- }
- throw nb::type_error("Unsupported integer type");
- }
-
- static void bindDerived(ClassTy &c) {
- c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
- }
-};
-
-// Check if the python version is less than 3.13. Py_IsFinalizing is a part
-// of stable ABI since 3.13 and before it was available as _Py_IsFinalizing.
-#if PY_VERSION_HEX < 0x030d0000
-#define Py_IsFinalizing _Py_IsFinalizing
-#endif
-
-class PyDenseResourceElementsAttribute
- : public PyConcreteAttribute<PyDenseResourceElementsAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction =
- mlirAttributeIsADenseResourceElements;
- static constexpr const char *pyClassName = "DenseResourceElementsAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- static PyDenseResourceElementsAttribute
- getFromBuffer(const nb_buffer &buffer, const std::string &name,
- const PyType &type, std::optional<size_t> alignment,
- bool isMutable, DefaultingPyMlirContext contextWrapper) {
- if (!mlirTypeIsAShaped(type)) {
- throw std::invalid_argument(
- "Constructing a DenseResourceElementsAttr requires a ShapedType.");
- }
-
- // Do not request any conversions as we must ensure to use caller
- // managed memory.
- int flags = PyBUF_STRIDES;
- std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>();
- if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
- throw nb::python_error();
- }
-
- // This scope releaser will only release if we haven't yet transferred
- // ownership.
- auto freeBuffer = llvm::make_scope_exit([&]() {
- if (view)
- PyBuffer_Release(view.get());
- });
-
- if (!PyBuffer_IsContiguous(view.get(), 'A')) {
- throw std::invalid_argument("Contiguous buffer is required.");
- }
-
- // Infer alignment to be the stride of one element if not explicit.
- size_t inferredAlignment;
- if (alignment)
- inferredAlignment = *alignment;
- else
- inferredAlignment = view->strides[view->ndim - 1];
-
- // The userData is a Py_buffer* that the deleter owns.
- auto deleter = [](void *userData, const void *data, size_t size,
- size_t align) {
- if (Py_IsFinalizing())
- return;
- assert(Py_IsInitialized() && "expected interpreter to be initialized");
- Py_buffer *ownedView = static_cast<Py_buffer *>(userData);
- nb::gil_scoped_acquire gil;
- PyBuffer_Release(ownedView);
- delete ownedView;
- };
-
- size_t rawBufferSize = view->len;
- MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
- type, toMlirStringRef(name), view->buf, rawBufferSize,
- inferredAlignment, isMutable, deleter, static_cast<void *>(view.get()));
- if (mlirAttributeIsNull(attr)) {
- throw std::invalid_argument(
- "DenseResourceElementsAttr could not be constructed from the given "
- "buffer. "
- "This may mean that the Python buffer layout does not match that "
- "MLIR expected layout and is a bug.");
- }
- view.release();
- return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
- }
-
- static void bindDerived(ClassTy &c) {
- c.def_static("get_from_buffer",
- PyDenseResourceElementsAttribute::getFromBuffer,
- nb::arg("array"), nb::arg("name"), nb::arg("type"),
- nb::arg("alignment") = nb::none(),
- nb::arg("is_mutable") = false, nb::arg("context") = nb::none(),
- // clang-format off
- nb::sig("def get_from_buffer(array: typing_extensions.Buffer, name: str, type: Type, alignment: int | None = None, is_mutable: bool = False, context: Context | None = None) -> DenseResourceElementsAttr"),
- // clang-format on
- kDenseResourceElementsAttrGetFromBufferDocstring);
- }
-};
-
-class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
- static constexpr const char *pyClassName = "DictAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirDictionaryAttrGetTypeID;
-
- intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
-
- bool dunderContains(const std::string &name) {
- return !mlirAttributeIsNull(
- mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
- }
-
- static void bindDerived(ClassTy &c) {
- c.def("__contains__", &PyDictAttribute::dunderContains);
- c.def("__len__", &PyDictAttribute::dunderLen);
- c.def_static(
- "get",
- [](const nb::dict &attributes, DefaultingPyMlirContext context) {
- SmallVector<MlirNamedAttribute> mlirNamedAttributes;
- mlirNamedAttributes.reserve(attributes.size());
- for (std::pair<nb::handle, nb::handle> it : attributes) {
- auto &mlirAttr = nb::cast<PyAttribute &>(it.second);
- auto name = nb::cast<std::string>(it.first);
- mlirNamedAttributes.push_back(mlirNamedAttributeGet(
- mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
- toMlirStringRef(name)),
- mlirAttr));
- }
- MlirAttribute attr =
- mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
- mlirNamedAttributes.data());
- return PyDictAttribute(context->getRef(), attr);
- },
- nb::arg("value") = nb::dict(), nb::arg("context") = nb::none(),
- "Gets an uniqued dict attribute");
- c.def("__getitem__",
- [](PyDictAttribute &self,
- const std::string &name) -> nb::typed<nb::object, PyAttribute> {
- MlirAttribute attr =
- mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
- if (mlirAttributeIsNull(attr))
- throw nb::key_error("attempt to access a non-existent attribute");
- return PyAttribute(self.getContext(), attr).maybeDownCast();
- });
- c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
- if (index < 0 || index >= self.dunderLen()) {
- throw nb::index_error("attempt to access out of bounds attribute");
- }
- MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
- return PyNamedAttribute(
- namedAttr.attribute,
- std::string(mlirIdentifierStr(namedAttr.name).data));
- });
- }
-};
-
-/// Refinement of PyDenseElementsAttribute for attributes containing
-/// floating-point values. Supports element access.
-class PyDenseFPElementsAttribute
- : public PyConcreteAttribute<PyDenseFPElementsAttribute,
- PyDenseElementsAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
- static constexpr const char *pyClassName = "DenseFPElementsAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- nb::float_ dunderGetItem(intptr_t pos) {
- if (pos < 0 || pos >= dunderLen()) {
- throw nb::index_error("attempt to access out of bounds element");
- }
-
- MlirType type = mlirAttributeGetType(*this);
- type = mlirShapedTypeGetElementType(type);
- // Dispatch element extraction to an appropriate C function based on the
- // elemental type of the attribute. nb::float_ is implicitly constructible
- // from float and double.
- // TODO: consider caching the type properties in the constructor to avoid
- // querying them on each element access.
- if (mlirTypeIsAF32(type)) {
- return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos));
- }
- if (mlirTypeIsAF64(type)) {
- return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos));
- }
- throw nb::type_error("Unsupported floating-point type");
- }
-
- static void bindDerived(ClassTy &c) {
- c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
- }
-};
-
-class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
- static constexpr const char *pyClassName = "TypeAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirTypeAttrGetTypeID;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](const PyType &value, DefaultingPyMlirContext context) {
- MlirAttribute attr = mlirTypeAttrGet(value.get());
- return PyTypeAttribute(context->getRef(), attr);
- },
- nb::arg("value"), nb::arg("context") = nb::none(),
- "Gets a uniqued Type attribute");
- c.def_prop_ro(
- "value", [](PyTypeAttribute &self) -> nb::typed<nb::object, PyType> {
- return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
- .maybeDownCast();
- });
- }
-};
-
-/// Unit Attribute subclass. Unit attributes don't have values.
-class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
- static constexpr const char *pyClassName = "UnitAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirUnitAttrGetTypeID;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- return PyUnitAttribute(context->getRef(),
- mlirUnitAttrGet(context->get()));
- },
- nb::arg("context") = nb::none(), "Create a Unit attribute.");
- }
-};
-
-/// Strided layout attribute subclass.
-class PyStridedLayoutAttribute
- : public PyConcreteAttribute<PyStridedLayoutAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
- static constexpr const char *pyClassName = "StridedLayoutAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirStridedLayoutAttrGetTypeID;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](int64_t offset, const std::vector<int64_t> &strides,
- DefaultingPyMlirContext ctx) {
- MlirAttribute attr = mlirStridedLayoutAttrGet(
- ctx->get(), offset, strides.size(), strides.data());
- return PyStridedLayoutAttribute(ctx->getRef(), attr);
- },
- nb::arg("offset"), nb::arg("strides"), nb::arg("context") = nb::none(),
- "Gets a strided layout attribute.");
- c.def_static(
- "get_fully_dynamic",
- [](int64_t rank, DefaultingPyMlirContext ctx) {
- auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
- std::vector<int64_t> strides(rank);
- llvm::fill(strides, dynamic);
- MlirAttribute attr = mlirStridedLayoutAttrGet(
- ctx->get(), dynamic, strides.size(), strides.data());
- return PyStridedLayoutAttribute(ctx->getRef(), attr);
- },
- nb::arg("rank"), nb::arg("context") = nb::none(),
- "Gets a strided layout attribute with dynamic offset and strides of "
- "a "
- "given rank.");
- c.def_prop_ro(
- "offset",
- [](PyStridedLayoutAttribute &self) {
- return mlirStridedLayoutAttrGetOffset(self);
- },
- "Returns the value of the float point attribute");
- c.def_prop_ro(
- "strides",
- [](PyStridedLayoutAttribute &self) {
- intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
- std::vector<int64_t> strides(size);
- for (intptr_t i = 0; i < size; i++) {
- strides[i] = mlirStridedLayoutAttrGetStride(self, i);
- }
- return strides;
- },
- "Returns the value of the float point attribute");
- }
-};
-
-nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
+nanobind::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute))
- return nb::cast(PyDenseBoolArrayAttribute(pyAttribute));
+ return nanobind::cast(PyDenseBoolArrayAttribute(pyAttribute));
if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute))
- return nb::cast(PyDenseI8ArrayAttribute(pyAttribute));
+ return nanobind::cast(PyDenseI8ArrayAttribute(pyAttribute));
if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute))
- return nb::cast(PyDenseI16ArrayAttribute(pyAttribute));
+ return nanobind::cast(PyDenseI16ArrayAttribute(pyAttribute));
if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute))
- return nb::cast(PyDenseI32ArrayAttribute(pyAttribute));
+ return nanobind::cast(PyDenseI32ArrayAttribute(pyAttribute));
if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute))
- return nb::cast(PyDenseI64ArrayAttribute(pyAttribute));
+ return nanobind::cast(PyDenseI64ArrayAttribute(pyAttribute));
if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute))
- return nb::cast(PyDenseF32ArrayAttribute(pyAttribute));
+ return nanobind::cast(PyDenseF32ArrayAttribute(pyAttribute));
if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute))
- return nb::cast(PyDenseF64ArrayAttribute(pyAttribute));
+ return nanobind::cast(PyDenseF64ArrayAttribute(pyAttribute));
std::string msg =
std::string("Can't cast unknown element type DenseArrayAttr (") +
- nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")";
- throw nb::type_error(msg.c_str());
+ nanobind::cast<std::string>(nanobind::repr(nanobind::cast(pyAttribute))) +
+ ")";
+ throw nanobind::type_error(msg.c_str());
}
-nb::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
+nanobind::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
if (PyDenseFPElementsAttribute::isaFunction(pyAttribute))
- return nb::cast(PyDenseFPElementsAttribute(pyAttribute));
+ return nanobind::cast(PyDenseFPElementsAttribute(pyAttribute));
if (PyDenseIntElementsAttribute::isaFunction(pyAttribute))
- return nb::cast(PyDenseIntElementsAttribute(pyAttribute));
+ return nanobind::cast(PyDenseIntElementsAttribute(pyAttribute));
std::string msg =
std::string(
"Can't cast unknown element type DenseIntOrFPElementsAttr (") +
- nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")";
- throw nb::type_error(msg.c_str());
+ nanobind::cast<std::string>(nanobind::repr(nanobind::cast(pyAttribute))) +
+ ")";
+ throw nanobind::type_error(msg.c_str());
}
-nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
+nanobind::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
if (PyBoolAttribute::isaFunction(pyAttribute))
- return nb::cast(PyBoolAttribute(pyAttribute));
+ return nanobind::cast(PyBoolAttribute(pyAttribute));
if (PyIntegerAttribute::isaFunction(pyAttribute))
- return nb::cast(PyIntegerAttribute(pyAttribute));
- std::string msg = std::string("Can't cast unknown attribute type Attr (") +
- nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) +
- ")";
- throw nb::type_error(msg.c_str());
+ return nanobind::cast(PyIntegerAttribute(pyAttribute));
+ std::string msg =
+ std::string("Can't cast unknown attribute type Attr (") +
+ nanobind::cast<std::string>(nanobind::repr(nanobind::cast(pyAttribute))) +
+ ")";
+ throw nanobind::type_error(msg.c_str());
}
-nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
+nanobind::object
+symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute))
- return nb::cast(PyFlatSymbolRefAttribute(pyAttribute));
+ return nanobind::cast(PyFlatSymbolRefAttribute(pyAttribute));
if (PySymbolRefAttribute::isaFunction(pyAttribute))
- return nb::cast(PySymbolRefAttribute(pyAttribute));
- std::string msg = std::string("Can't cast unknown SymbolRef attribute (") +
- nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) +
- ")";
- throw nb::type_error(msg.c_str());
-}
-
-} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
-} // namespace python
-} // namespace mlir
-
-void PyStringAttribute::bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](const std::string &value, DefaultingPyMlirContext context) {
- MlirAttribute attr =
- mlirStringAttrGet(context->get(), toMlirStringRef(value));
- return PyStringAttribute(context->getRef(), attr);
- },
- nb::arg("value"), nb::arg("context") = nb::none(),
- "Gets a uniqued string attribute");
- c.def_static(
- "get",
- [](const nb::bytes &value, DefaultingPyMlirContext context) {
- MlirAttribute attr =
- mlirStringAttrGet(context->get(), toMlirStringRef(value));
- return PyStringAttribute(context->getRef(), attr);
- },
- nb::arg("value"), nb::arg("context") = nb::none(),
- "Gets a uniqued string attribute");
- c.def_static(
- "get_typed",
- [](PyType &type, const std::string &value) {
- MlirAttribute attr =
- mlirStringAttrTypedGet(type, toMlirStringRef(value));
- return PyStringAttribute(type.getContext(), attr);
- },
- nb::arg("type"), nb::arg("value"),
- "Gets a uniqued string attribute associated to a type");
- c.def_prop_ro(
- "value",
- [](PyStringAttribute &self) {
- MlirStringRef stringRef = mlirStringAttrGetValue(self);
- return nb::str(stringRef.data, stringRef.length);
- },
- "Returns the value of the string attribute");
- c.def_prop_ro(
- "value_bytes",
- [](PyStringAttribute &self) {
- MlirStringRef stringRef = mlirStringAttrGetValue(self);
- return nb::bytes(stringRef.data, stringRef.length);
- },
- "Returns the value of the string attribute as `bytes`");
-}
-
-namespace mlir {
-namespace python {
-namespace MLIR_BINDINGS_PYTHON_DOMAIN {
-void populateIRAttributes(nb::module_ &m) {
- PyAffineMapAttribute::bind(m);
- PyDenseBoolArrayAttribute::bind(m);
- PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
- PyDenseI8ArrayAttribute::bind(m);
- PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m);
- PyDenseI16ArrayAttribute::bind(m);
- PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m);
- PyDenseI32ArrayAttribute::bind(m);
- PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m);
- PyDenseI64ArrayAttribute::bind(m);
- PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m);
- PyDenseF32ArrayAttribute::bind(m);
- PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
- PyDenseF64ArrayAttribute::bind(m);
- PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
- PyGlobals::get().registerTypeCaster(
- mlirDenseArrayAttrGetTypeID(),
- nb::cast<nb::callable>(nb::cpp_function(denseArrayAttributeCaster)));
-
- PyArrayAttribute::bind(m);
- PyArrayAttribute::PyArrayAttributeIterator::bind(m);
- PyBoolAttribute::bind(m);
- PyDenseElementsAttribute::bind(m, PyDenseElementsAttribute::slots);
- PyDenseFPElementsAttribute::bind(m);
- PyDenseIntElementsAttribute::bind(m);
- PyGlobals::get().registerTypeCaster(
- mlirDenseIntOrFPElementsAttrGetTypeID(),
- nb::cast<nb::callable>(
- nb::cpp_function(denseIntOrFPElementsAttributeCaster)));
- PyDenseResourceElementsAttribute::bind(m);
-
- PyDictAttribute::bind(m);
- PySymbolRefAttribute::bind(m);
- PyGlobals::get().registerTypeCaster(
- mlirSymbolRefAttrGetTypeID(),
- nb::cast<nb::callable>(
- nb::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster)));
-
- PyFlatSymbolRefAttribute::bind(m);
- PyOpaqueAttribute::bind(m);
- PyFloatAttribute::bind(m);
- PyIntegerAttribute::bind(m);
- PyIntegerSetAttribute::bind(m);
- PyStringAttribute::bind(m);
- PyTypeAttribute::bind(m);
- PyGlobals::get().registerTypeCaster(
- mlirIntegerAttrGetTypeID(),
- nb::cast<nb::callable>(nb::cpp_function(integerOrBoolAttributeCaster)));
- PyUnitAttribute::bind(m);
-
- PyStridedLayoutAttribute::bind(m);
+ return nanobind::cast(PySymbolRefAttribute(pyAttribute));
+ std::string msg =
+ std::string("Can't cast unknown SymbolRef attribute (") +
+ nanobind::cast<std::string>(nanobind::repr(nanobind::cast(pyAttribute))) +
+ ")";
+ throw nanobind::type_error(msg.c_str());
}
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 7350046f428c7..e4179011170d7 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -28,493 +28,14 @@ using llvm::Twine;
namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
-
-/// Checks whether the given type is an integer or float type.
-static int mlirTypeIsAIntegerOrFloat(MlirType type) {
- return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
- mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
-}
-
-class PyIntegerType : public PyConcreteType<PyIntegerType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirIntegerTypeGetTypeID;
- static constexpr const char *pyClassName = "IntegerType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get_signless",
- [](unsigned width, DefaultingPyMlirContext context) {
- MlirType t = mlirIntegerTypeGet(context->get(), width);
- return PyIntegerType(context->getRef(), t);
- },
- nb::arg("width"), nb::arg("context") = nb::none(),
- "Create a signless integer type");
- c.def_static(
- "get_signed",
- [](unsigned width, DefaultingPyMlirContext context) {
- MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
- return PyIntegerType(context->getRef(), t);
- },
- nb::arg("width"), nb::arg("context") = nb::none(),
- "Create a signed integer type");
- c.def_static(
- "get_unsigned",
- [](unsigned width, DefaultingPyMlirContext context) {
- MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
- return PyIntegerType(context->getRef(), t);
- },
- nb::arg("width"), nb::arg("context") = nb::none(),
- "Create an unsigned integer type");
- c.def_prop_ro(
- "width",
- [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
- "Returns the width of the integer type");
- c.def_prop_ro(
- "is_signless",
- [](PyIntegerType &self) -> bool {
- return mlirIntegerTypeIsSignless(self);
- },
- "Returns whether this is a signless integer");
- c.def_prop_ro(
- "is_signed",
- [](PyIntegerType &self) -> bool {
- return mlirIntegerTypeIsSigned(self);
- },
- "Returns whether this is a signed integer");
- c.def_prop_ro(
- "is_unsigned",
- [](PyIntegerType &self) -> bool {
- return mlirIntegerTypeIsUnsigned(self);
- },
- "Returns whether this is an unsigned integer");
- }
-};
-
-/// Index Type subclass - IndexType.
-class PyIndexType : public PyConcreteType<PyIndexType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirIndexTypeGetTypeID;
- static constexpr const char *pyClassName = "IndexType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirIndexTypeGet(context->get());
- return PyIndexType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a index type.");
- }
-};
-
-class PyFloatType : public PyConcreteType<PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
- static constexpr const char *pyClassName = "FloatType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
- "Returns the width of the floating-point type");
- }
-};
-
-/// Floating Point Type subclass - Float4E2M1FNType.
-class PyFloat4E2M1FNType
- : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat4E2M1FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float4E2M1FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat4E2M1FNTypeGet(context->get());
- return PyFloat4E2M1FNType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float4_e2m1fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float6E2M3FNType.
-class PyFloat6E2M3FNType
- : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat6E2M3FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float6E2M3FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
- return PyFloat6E2M3FNType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float6_e2m3fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float6E3M2FNType.
-class PyFloat6E3M2FNType
- : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat6E3M2FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float6E3M2FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
- return PyFloat6E3M2FNType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float6_e3m2fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3FNType.
-class PyFloat8E4M3FNType
- : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
- return PyFloat8E4M3FNType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e4m3fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E5M2Type.
-class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E5M2TypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E5M2Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E5M2TypeGet(context->get());
- return PyFloat8E5M2Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e5m2 type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3Type.
-class PyFloat8E4M3Type : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3TypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3TypeGet(context->get());
- return PyFloat8E4M3Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e4m3 type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3FNUZ.
-class PyFloat8E4M3FNUZType
- : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3FNUZTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3FNUZType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
- return PyFloat8E4M3FNUZType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e4m3fnuz type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3B11FNUZ.
-class PyFloat8E4M3B11FNUZType
- : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3B11FNUZTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
- return PyFloat8E4M3B11FNUZType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e4m3b11fnuz type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E5M2FNUZ.
-class PyFloat8E5M2FNUZType
- : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E5M2FNUZTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E5M2FNUZType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
- return PyFloat8E5M2FNUZType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e5m2fnuz type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E3M4Type.
-class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E3M4TypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E3M4Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E3M4TypeGet(context->get());
- return PyFloat8E3M4Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e3m4 type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E8M0FNUType.
-class PyFloat8E8M0FNUType
- : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E8M0FNUTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E8M0FNUType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
- return PyFloat8E8M0FNUType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e8m0fnu type.");
- }
-};
-
-/// Floating Point Type subclass - BF16Type.
-class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirBFloat16TypeGetTypeID;
- static constexpr const char *pyClassName = "BF16Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirBF16TypeGet(context->get());
- return PyBF16Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a bf16 type.");
- }
-};
-
-/// Floating Point Type subclass - F16Type.
-class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat16TypeGetTypeID;
- static constexpr const char *pyClassName = "F16Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirF16TypeGet(context->get());
- return PyF16Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a f16 type.");
- }
-};
-
-/// Floating Point Type subclass - TF32Type.
-class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloatTF32TypeGetTypeID;
- static constexpr const char *pyClassName = "FloatTF32Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirTF32TypeGet(context->get());
- return PyTF32Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a tf32 type.");
- }
-};
-
-/// Floating Point Type subclass - F32Type.
-class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat32TypeGetTypeID;
- static constexpr const char *pyClassName = "F32Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirF32TypeGet(context->get());
- return PyF32Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a f32 type.");
- }
-};
-
-/// Floating Point Type subclass - F64Type.
-class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat64TypeGetTypeID;
- static constexpr const char *pyClassName = "F64Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirF64TypeGet(context->get());
- return PyF64Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a f64 type.");
- }
-};
-
-/// None Type subclass - NoneType.
-class PyNoneType : public PyConcreteType<PyNoneType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirNoneTypeGetTypeID;
- static constexpr const char *pyClassName = "NoneType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirNoneTypeGet(context->get());
- return PyNoneType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a none type.");
- }
-};
-
-/// Complex Type subclass - ComplexType.
-class PyComplexType : public PyConcreteType<PyComplexType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirComplexTypeGetTypeID;
- static constexpr const char *pyClassName = "ComplexType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyType &elementType) {
- // The element must be a floating point or integer scalar type.
- if (mlirTypeIsAIntegerOrFloat(elementType)) {
- MlirType t = mlirComplexTypeGet(elementType);
- return PyComplexType(elementType.getContext(), t);
- }
- throw nb::value_error(
- (Twine("invalid '") +
- nb::cast<std::string>(nb::repr(nb::cast(elementType))) +
- "' and expected floating point or integer type.")
- .str()
- .c_str());
- },
- "Create a complex type");
- c.def_prop_ro(
- "element_type",
- [](PyComplexType &self) -> nb::typed<nb::object, PyType> {
- return PyType(self.getContext(), mlirComplexTypeGetElementType(self))
- .maybeDownCast();
- },
- "Returns element type.");
+// Shaped Type Interface - ShapedType
+void PyShapedType::requireHasRank() {
+ if (!mlirShapedTypeHasRank(*this)) {
+ throw nb::value_error(
+ "calling this method requires that the type has a rank.");
}
-};
-
-} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
-} // namespace python
-} // namespace mlir
+}
-// Shaped Type Interface - ShapedType
void PyShapedType::bindDerived(ClassTy &c) {
c.def_prop_ro(
"element_type",
@@ -620,535 +141,633 @@ void PyShapedType::bindDerived(ClassTy &c) {
"shaped types.");
}
-void PyShapedType::requireHasRank() {
- if (!mlirShapedTypeHasRank(*this)) {
- throw nb::value_error(
- "calling this method requires that the type has a rank.");
- }
+void PyIntegerType::bindDerived(ClassTy &c) { // Signless/signed/unsigned factories + queries.
+ c.def_static(
+ "get_signless",
+ [](unsigned width, DefaultingPyMlirContext context) {
+ MlirType t = mlirIntegerTypeGet(context->get(), width);
+ return PyIntegerType(context->getRef(), t);
+ },
+ nanobind::arg("width"), nanobind::arg("context") = nanobind::none(),
+ "Create a signless integer type");
+ c.def_static(
+ "get_signed",
+ [](unsigned width, DefaultingPyMlirContext context) {
+ MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
+ return PyIntegerType(context->getRef(), t);
+ },
+ nanobind::arg("width"), nanobind::arg("context") = nanobind::none(),
+ "Create a signed integer type");
+ c.def_static(
+ "get_unsigned",
+ [](unsigned width, DefaultingPyMlirContext context) {
+ MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
+ return PyIntegerType(context->getRef(), t);
+ },
+ nanobind::arg("width"), nanobind::arg("context") = nanobind::none(),
+ "Create an unsigned integer type");
+ c.def_prop_ro( // Read-only queries, each forwarding to an mlirIntegerType* getter.
+ "width",
+ [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
+ "Returns the width of the integer type");
+ c.def_prop_ro(
+ "is_signless",
+ [](PyIntegerType &self) -> bool {
+ return mlirIntegerTypeIsSignless(self);
+ },
+ "Returns whether this is a signless integer");
+ c.def_prop_ro(
+ "is_signed",
+ [](PyIntegerType &self) -> bool { return mlirIntegerTypeIsSigned(self); },
+ "Returns whether this is a signed integer");
+ c.def_prop_ro(
+ "is_unsigned",
+ [](PyIntegerType &self) -> bool {
+ return mlirIntegerTypeIsUnsigned(self);
+ },
+ "Returns whether this is an unsigned integer");
}
-const PyShapedType::IsAFunctionTy PyShapedType::isaFunction = mlirTypeIsAShaped;
+void PyIndexType::bindDerived(ClassTy &c) { // Binds a static `get(context=None)` factory.
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirIndexTypeGet(context->get());
+ return PyIndexType(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(), "Create an index type.");
+}
-namespace mlir {
-namespace python {
-namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+void PyFloatType::bindDerived(ClassTy &c) { // Binds only the read-only `width` property.
+ c.def_prop_ro(
+ "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
+ "Returns the width of the floating-point type");
+}
-/// Vector Type subclass - VectorType.
-class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirVectorTypeGetTypeID;
- static constexpr const char *pyClassName = "VectorType";
- using PyConcreteType::PyConcreteType;
+void PyFloat4E2M1FNType::bindDerived(ClassTy &c) { // Binds a static `get(context=None)` factory.
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat4E2M1FNTypeGet(context->get());
+ return PyFloat4E2M1FNType(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(),
+ "Create a float4_e2m1fn type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static("get", &PyVectorType::getChecked, nb::arg("shape"),
- nb::arg("element_type"), nb::kw_only(),
- nb::arg("scalable") = nb::none(),
- nb::arg("scalable_dims") = nb::none(),
- nb::arg("loc") = nb::none(), "Create a vector type")
- .def_static("get_unchecked", &PyVectorType::get, nb::arg("shape"),
- nb::arg("element_type"), nb::kw_only(),
- nb::arg("scalable") = nb::none(),
- nb::arg("scalable_dims") = nb::none(),
- nb::arg("context") = nb::none(), "Create a vector type")
- .def_prop_ro(
- "scalable",
- [](MlirType self) { return mlirVectorTypeIsScalable(self); })
- .def_prop_ro("scalable_dims", [](MlirType self) {
- std::vector<bool> scalableDims;
- size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
- scalableDims.reserve(rank);
- for (size_t i = 0; i < rank; ++i)
- scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
- return scalableDims;
- });
- }
+void PyFloat6E2M3FNType::bindDerived(ClassTy &c) { // Binds a static `get(context=None)` factory.
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
+ return PyFloat6E2M3FNType(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(),
+ "Create a float6_e2m3fn type.");
+}
-private:
- static PyVectorType
- getChecked(std::vector<int64_t> shape, PyType &elementType,
- std::optional<nb::list> scalable,
- std::optional<std::vector<int64_t>> scalableDims,
- DefaultingPyLocation loc) {
- if (scalable && scalableDims) {
- throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
- "are mutually exclusive.");
- }
+void PyFloat6E3M2FNType::bindDerived(ClassTy &c) { // Binds a static `get(context=None)` factory.
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
+ return PyFloat6E3M2FNType(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(),
+ "Create a float6_e3m2fn type.");
+}
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirType type;
- if (scalable) {
- if (scalable->size() != shape.size())
- throw nb::value_error("Expected len(scalable) == len(shape).");
+void PyFloat8E4M3FNType::bindDerived(ClassTy &c) { // Binds a static `get(context=None)` factory.
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
+ return PyFloat8E4M3FNType(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(),
+ "Create a float8_e4m3fn type.");
+}
- SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
- *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
- type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
- scalableDimFlags.data(),
- elementType);
- } else if (scalableDims) {
- SmallVector<bool> scalableDimFlags(shape.size(), false);
- for (int64_t dim : *scalableDims) {
- if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
- throw nb::value_error("Scalable dimension index out of bounds.");
- scalableDimFlags[dim] = true;
- }
- type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
- scalableDimFlags.data(),
- elementType);
- } else {
- type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
- elementType);
- }
- if (mlirTypeIsNull(type))
- throw MLIRError("Invalid type", errors.take());
- return PyVectorType(elementType.getContext(), type);
- }
+void PyFloat8E5M2Type::bindDerived(ClassTy &c) { // Binds a static `get(context=None)` factory.
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E5M2TypeGet(context->get());
+ return PyFloat8E5M2Type(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(),
+ "Create a float8_e5m2 type.");
+}
- static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
- std::optional<nb::list> scalable,
- std::optional<std::vector<int64_t>> scalableDims,
- DefaultingPyMlirContext context) {
- if (scalable && scalableDims) {
- throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
- "are mutually exclusive.");
- }
+void PyFloat8E4M3Type::bindDerived(ClassTy &c) { // Binds a static `get(context=None)` factory.
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E4M3TypeGet(context->get());
+ return PyFloat8E4M3Type(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(),
+ "Create a float8_e4m3 type.");
+}
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirType type;
- if (scalable) {
- if (scalable->size() != shape.size())
- throw nb::value_error("Expected len(scalable) == len(shape).");
+void PyFloat8E4M3FNUZType::bindDerived(ClassTy &c) { // Binds a static `get(context=None)` factory.
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
+ return PyFloat8E4M3FNUZType(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(),
+ "Create a float8_e4m3fnuz type.");
+}
- SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
- *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
- type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
- scalableDimFlags.data(), elementType);
- } else if (scalableDims) {
- SmallVector<bool> scalableDimFlags(shape.size(), false);
- for (int64_t dim : *scalableDims) {
- if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
- throw nb::value_error("Scalable dimension index out of bounds.");
- scalableDimFlags[dim] = true;
- }
- type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
- scalableDimFlags.data(), elementType);
- } else {
- type = mlirVectorTypeGet(shape.size(), shape.data(), elementType);
- }
- if (mlirTypeIsNull(type))
- throw MLIRError("Invalid type", errors.take());
- return PyVectorType(elementType.getContext(), type);
- }
-};
+void PyFloat8E4M3B11FNUZType::bindDerived(ClassTy &c) { // Binds a static `get(context=None)` factory.
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
+ return PyFloat8E4M3B11FNUZType(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(),
+ "Create a float8_e4m3b11fnuz type.");
+}
-/// Ranked Tensor Type subclass - RankedTensorType.
-class PyRankedTensorType
- : public PyConcreteType<PyRankedTensorType, PyShapedType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirRankedTensorTypeGetTypeID;
- static constexpr const char *pyClassName = "RankedTensorType";
- using PyConcreteType::PyConcreteType;
+void PyFloat8E5M2FNUZType::bindDerived(ClassTy &c) { // Binds a static `get(context=None)` factory.
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
+ return PyFloat8E5M2FNUZType(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(),
+ "Create a float8_e5m2fnuz type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](std::vector<int64_t> shape, PyType &elementType,
- std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirType t = mlirRankedTensorTypeGetChecked(
- loc, shape.size(), shape.data(), elementType,
- encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyRankedTensorType(elementType.getContext(), t);
- },
- nb::arg("shape"), nb::arg("element_type"),
- nb::arg("encoding") = nb::none(), nb::arg("loc") = nb::none(),
- "Create a ranked tensor type");
- c.def_static(
- "get_unchecked",
- [](std::vector<int64_t> shape, PyType &elementType,
- std::optional<PyAttribute> &encodingAttr,
- DefaultingPyMlirContext context) {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirType t = mlirRankedTensorTypeGet(
- shape.size(), shape.data(), elementType,
- encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyRankedTensorType(elementType.getContext(), t);
- },
- nb::arg("shape"), nb::arg("element_type"),
- nb::arg("encoding") = nb::none(), nb::arg("context") = nb::none(),
- "Create a ranked tensor type");
- c.def_prop_ro(
- "encoding",
- [](PyRankedTensorType &self)
- -> std::optional<nb::typed<nb::object, PyAttribute>> {
- MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
- if (mlirAttributeIsNull(encoding))
- return std::nullopt;
- return PyAttribute(self.getContext(), encoding).maybeDownCast();
- });
- }
-};
+void PyFloat8E3M4Type::bindDerived(ClassTy &c) { // Binds a static `get(context=None)` factory.
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E3M4TypeGet(context->get());
+ return PyFloat8E3M4Type(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(),
+ "Create a float8_e3m4 type.");
+}
-/// Unranked Tensor Type subclass - UnrankedTensorType.
-class PyUnrankedTensorType
- : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirUnrankedTensorTypeGetTypeID;
- static constexpr const char *pyClassName = "UnrankedTensorType";
- using PyConcreteType::PyConcreteType;
+void PyFloat8E8M0FNUType::bindDerived(ClassTy &c) { // Binds a static `get(context=None)` factory.
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
+ return PyFloat8E8M0FNUType(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(),
+ "Create a float8_e8m0fnu type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyType &elementType, DefaultingPyLocation loc) {
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyUnrankedTensorType(elementType.getContext(), t);
- },
- nb::arg("element_type"), nb::arg("loc") = nb::none(),
- "Create a unranked tensor type");
- c.def_static(
- "get_unchecked",
- [](PyType &elementType, DefaultingPyMlirContext context) {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirType t = mlirUnrankedTensorTypeGet(elementType);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyUnrankedTensorType(elementType.getContext(), t);
- },
- nb::arg("element_type"), nb::arg("context") = nb::none(),
- "Create a unranked tensor type");
- }
-};
+void PyBF16Type::bindDerived(ClassTy &c) { // Binds a static `get(context=None)` factory.
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirBF16TypeGet(context->get());
+ return PyBF16Type(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(), "Create a bf16 type.");
+}
-/// Ranked MemRef Type subclass - MemRefType.
-class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirMemRefTypeGetTypeID;
- static constexpr const char *pyClassName = "MemRefType";
- using PyConcreteType::PyConcreteType;
+void PyF16Type::bindDerived(ClassTy &c) { // Binds a static `get(context=None)` factory.
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirF16TypeGet(context->get());
+ return PyF16Type(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(), "Create a f16 type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](std::vector<int64_t> shape, PyType &elementType,
- PyAttribute *layout, PyAttribute *memorySpace,
- DefaultingPyLocation loc) {
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull();
- MlirAttribute memSpaceAttr =
- memorySpace ? *memorySpace : mlirAttributeGetNull();
- MlirType t =
- mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
- shape.data(), layoutAttr, memSpaceAttr);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyMemRefType(elementType.getContext(), t);
- },
- nb::arg("shape"), nb::arg("element_type"),
- nb::arg("layout") = nb::none(), nb::arg("memory_space") = nb::none(),
- nb::arg("loc") = nb::none(), "Create a memref type")
- .def_static(
- "get_unchecked",
- [](std::vector<int64_t> shape, PyType &elementType,
- PyAttribute *layout, PyAttribute *memorySpace,
- DefaultingPyMlirContext context) {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirAttribute layoutAttr =
- layout ? *layout : mlirAttributeGetNull();
- MlirAttribute memSpaceAttr =
- memorySpace ? *memorySpace : mlirAttributeGetNull();
- MlirType t =
- mlirMemRefTypeGet(elementType, shape.size(), shape.data(),
- layoutAttr, memSpaceAttr);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyMemRefType(elementType.getContext(), t);
- },
- nb::arg("shape"), nb::arg("element_type"),
- nb::arg("layout") = nb::none(),
- nb::arg("memory_space") = nb::none(),
- nb::arg("context") = nb::none(), "Create a memref type")
- .def_prop_ro(
- "layout",
- [](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
- return PyAttribute(self.getContext(),
- mlirMemRefTypeGetLayout(self))
- .maybeDownCast();
- },
- "The layout of the MemRef type.")
- .def(
- "get_strides_and_offset",
- [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
- std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
- int64_t offset;
- if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset(
- self, strides.data(), &offset)))
- throw std::runtime_error(
- "Failed to extract strides and offset from memref.");
- return {strides, offset};
- },
- "The strides and offset of the MemRef type.")
- .def_prop_ro(
- "affine_map",
- [](PyMemRefType &self) -> PyAffineMap {
- MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
- return PyAffineMap(self.getContext(), map);
- },
- "The layout of the MemRef type as an affine map.")
- .def_prop_ro(
- "memory_space",
- [](PyMemRefType &self)
- -> std::optional<nb::typed<nb::object, PyAttribute>> {
- MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
- if (mlirAttributeIsNull(a))
- return std::nullopt;
- return PyAttribute(self.getContext(), a).maybeDownCast();
- },
- "Returns the memory space of the given MemRef type.");
- }
-};
+void PyTF32Type::bindDerived(ClassTy &c) { // Binds a static `get(context=None)` factory.
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirTF32TypeGet(context->get());
+ return PyTF32Type(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(), "Create a tf32 type.");
+}
-/// Unranked MemRef Type subclass - UnrankedMemRefType.
-class PyUnrankedMemRefType
- : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirUnrankedMemRefTypeGetTypeID;
- static constexpr const char *pyClassName = "UnrankedMemRefType";
- using PyConcreteType::PyConcreteType;
+void PyF32Type::bindDerived(ClassTy &c) { // Binds a static `get(context=None)` factory.
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirF32TypeGet(context->get());
+ return PyF32Type(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(), "Create a f32 type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyType &elementType, PyAttribute *memorySpace,
- DefaultingPyLocation loc) {
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirAttribute memSpaceAttr = {};
- if (memorySpace)
- memSpaceAttr = *memorySpace;
+void PyF64Type::bindDerived(ClassTy &c) { // Binds a static `get(context=None)` factory.
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirF64TypeGet(context->get());
+ return PyF64Type(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(), "Create a f64 type.");
+}
- MlirType t =
- mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyUnrankedMemRefType(elementType.getContext(), t);
- },
- nb::arg("element_type"), nb::arg("memory_space").none(),
- nb::arg("loc") = nb::none(), "Create a unranked memref type")
- .def_static(
- "get_unchecked",
- [](PyType &elementType, PyAttribute *memorySpace,
- DefaultingPyMlirContext context) {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirAttribute memSpaceAttr = {};
- if (memorySpace)
- memSpaceAttr = *memorySpace;
+void PyNoneType::bindDerived(ClassTy &c) { // Binds a static `get(context=None)` factory.
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirNoneTypeGet(context->get());
+ return PyNoneType(context->getRef(), t);
+ },
+ nanobind::arg("context") = nanobind::none(), "Create a none type.");
+}
- MlirType t = mlirUnrankedMemRefTypeGet(elementType, memSpaceAttr);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyUnrankedMemRefType(elementType.getContext(), t);
- },
- nb::arg("element_type"), nb::arg("memory_space").none(),
- nb::arg("context") = nb::none(), "Create a unranked memref type")
- .def_prop_ro(
- "memory_space",
- [](PyUnrankedMemRefType &self)
- -> std::optional<nb::typed<nb::object, PyAttribute>> {
- MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
- if (mlirAttributeIsNull(a))
- return std::nullopt;
- return PyAttribute(self.getContext(), a).maybeDownCast();
- },
- "Returns the memory space of the given Unranked MemRef type.");
- }
-};
+void PyComplexType::bindDerived(ClassTy &c) { // Binds `get(element_type)` and `element_type`.
+ c.def_static(
+ "get",
+ [](PyType &elementType) { // No context arg: the result uses elementType's context.
+ // The element must be a floating point or integer scalar type.
+ if (mlirTypeIsAIntegerOrFloat(elementType)) {
+ MlirType t = mlirComplexTypeGet(elementType);
+ return PyComplexType(elementType.getContext(), t);
+ }
+ throw nanobind::value_error( // Reached only for non-integer, non-float elements.
+ (Twine("invalid '") +
+ nanobind::cast<std::string>(
+ nanobind::repr(nanobind::cast(elementType))) +
+ "' and expected floating point or integer type.")
+ .str()
+ .c_str()); // The temporary std::string lives until the end of this full-expression.
+ },
+ "Create a complex type");
+ c.def_prop_ro(
+ "element_type",
+ [](PyComplexType &self) -> nanobind::typed<nanobind::object, PyType> {
+ return PyType(self.getContext(), mlirComplexTypeGetElementType(self))
+ .maybeDownCast();
+ },
+ "Returns element type.");
+}
-/// Tuple Type subclass - TupleType.
-class PyTupleType : public PyConcreteType<PyTupleType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirTupleTypeGetTypeID;
- static constexpr const char *pyClassName = "TupleType";
- using PyConcreteType::PyConcreteType;
+void PyVectorType::bindDerived(ClassTy &c) { // Factories plus scalability queries.
+ c.def_static("get", &PyVectorType::getChecked, nanobind::arg("shape"), // Checked variant; takes `loc`.
+ nanobind::arg("element_type"), nanobind::kw_only(),
+ nanobind::arg("scalable") = nanobind::none(),
+ nanobind::arg("scalable_dims") = nanobind::none(),
+ nanobind::arg("loc") = nanobind::none(), "Create a vector type")
+ .def_static("get_unchecked", &PyVectorType::get, nanobind::arg("shape"), // Takes `context`.
+ nanobind::arg("element_type"), nanobind::kw_only(),
+ nanobind::arg("scalable") = nanobind::none(),
+ nanobind::arg("scalable_dims") = nanobind::none(),
+ nanobind::arg("context") = nanobind::none(),
+ "Create a vector type")
+ .def_prop_ro("scalable",
+ [](MlirType self) { return mlirVectorTypeIsScalable(self); })
+ .def_prop_ro("scalable_dims", [](MlirType self) { // One flag per dimension, in order.
+ std::vector<bool> scalableDims;
+ size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
+ scalableDims.reserve(rank);
+ for (size_t i = 0; i < rank; ++i)
+ scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
+ return scalableDims;
+ });
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get_tuple",
- [](const std::vector<PyType> &elements,
- DefaultingPyMlirContext context) {
- std::vector<MlirType> mlirElements;
- mlirElements.reserve(elements.size());
- for (const auto &element : elements)
- mlirElements.push_back(element.get());
- MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
- mlirElements.data());
- return PyTupleType(context->getRef(), t);
- },
- nb::arg("elements"), nb::arg("context") = nb::none(),
- "Create a tuple type");
- c.def_static(
- "get_tuple",
- [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
- MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
- elements.data());
- return PyTupleType(context->getRef(), t);
- },
- nb::arg("elements"), nb::arg("context") = nb::none(),
- // clang-format off
- nb::sig("def get_tuple(elements: Sequence[Type], context: Context | None = None) -> TupleType"),
- // clang-format on
- "Create a tuple type");
- c.def(
- "get_type",
- [](PyTupleType &self, intptr_t pos) -> nb::typed<nb::object, PyType> {
- return PyType(self.getContext(), mlirTupleTypeGetType(self, pos))
- .maybeDownCast();
- },
- nb::arg("pos"), "Returns the pos-th type in the tuple type.");
- c.def_prop_ro(
- "num_types",
- [](PyTupleType &self) -> intptr_t {
- return mlirTupleTypeGetNumTypes(self);
- },
- "Returns the number of types contained in a tuple.");
- }
-};
+void PyRankedTensorType::bindDerived(ClassTy &c) { // get / get_unchecked / encoding.
+ c.def_static(
+ "get",
+ [](std::vector<int64_t> shape, PyType &elementType,
+ std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
+ PyMlirContext::ErrorCapture errors(loc->getContext()); // Feeds MLIRError below.
+ MlirType t = mlirRankedTensorTypeGetChecked(
+ loc, shape.size(), shape.data(), elementType,
+ encodingAttr ? encodingAttr->get() : mlirAttributeGetNull()); // None -> null encoding.
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
+ return PyRankedTensorType(elementType.getContext(), t);
+ },
+ nanobind::arg("shape"), nanobind::arg("element_type"),
+ nanobind::arg("encoding") = nanobind::none(),
+ nanobind::arg("loc") = nanobind::none(), "Create a ranked tensor type");
+ c.def_static(
+ "get_unchecked",
+ [](std::vector<int64_t> shape, PyType &elementType,
+ std::optional<PyAttribute> &encodingAttr,
+ DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef()); // Feeds MLIRError below.
+ MlirType t = mlirRankedTensorTypeGet(
+ shape.size(), shape.data(), elementType,
+ encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
+ return PyRankedTensorType(elementType.getContext(), t);
+ },
+ nanobind::arg("shape"), nanobind::arg("element_type"),
+ nanobind::arg("encoding") = nanobind::none(),
+ nanobind::arg("context") = nanobind::none(),
+ "Create a ranked tensor type");
+ c.def_prop_ro(
+ "encoding",
+ [](PyRankedTensorType &self)
+ -> std::optional<nanobind::typed<nanobind::object, PyAttribute>> {
+ MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
+ if (mlirAttributeIsNull(encoding))
+ return std::nullopt; // No encoding -> Python None.
+ return PyAttribute(self.getContext(), encoding).maybeDownCast();
+ });
+}
-/// Function type.
-class PyFunctionType : public PyConcreteType<PyFunctionType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFunctionTypeGetTypeID;
- static constexpr const char *pyClassName = "FunctionType";
- using PyConcreteType::PyConcreteType;
+void PyUnrankedTensorType::bindDerived(ClassTy &c) { // Binds `get` (uses loc) and `get_unchecked`.
+ c.def_static(
+ "get",
+ [](PyType &elementType, DefaultingPyLocation loc) {
+ PyMlirContext::ErrorCapture errors(loc->getContext());
+ MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
+ return PyUnrankedTensorType(elementType.getContext(), t);
+ },
+ nanobind::arg("element_type"), nanobind::arg("loc") = nanobind::none(),
+ "Create an unranked tensor type");
+ c.def_static(
+ "get_unchecked",
+ [](PyType &elementType, DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef()); // NOTE(review): `context` only scopes capture; result uses elementType's context.
+ MlirType t = mlirUnrankedTensorTypeGet(elementType);
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
+ return PyUnrankedTensorType(elementType.getContext(), t);
+ },
+ nanobind::arg("element_type"),
+ nanobind::arg("context") = nanobind::none(),
+ "Create an unranked tensor type");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](std::vector<PyType> inputs, std::vector<PyType> results,
- DefaultingPyMlirContext context) {
- std::vector<MlirType> mlirInputs;
- mlirInputs.reserve(inputs.size());
- for (const auto &input : inputs)
- mlirInputs.push_back(input.get());
- std::vector<MlirType> mlirResults;
- mlirResults.reserve(results.size());
- for (const auto &result : results)
- mlirResults.push_back(result.get());
+void PyMemRefType::bindDerived(ClassTy &c) { // Factories plus layout/stride/memory-space queries.
+ c.def_static(
+ "get",
+ [](std::vector<int64_t> shape, PyType &elementType, PyAttribute *layout,
+ PyAttribute *memorySpace, DefaultingPyLocation loc) {
+ PyMlirContext::ErrorCapture errors(loc->getContext());
+ MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull(); // None -> null attribute.
+ MlirAttribute memSpaceAttr =
+ memorySpace ? *memorySpace : mlirAttributeGetNull();
+ MlirType t =
+ mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
+ shape.data(), layoutAttr, memSpaceAttr);
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
+ return PyMemRefType(elementType.getContext(), t);
+ },
+ nanobind::arg("shape"), nanobind::arg("element_type"),
+ nanobind::arg("layout") = nanobind::none(),
+ nanobind::arg("memory_space") = nanobind::none(),
+ nanobind::arg("loc") = nanobind::none(), "Create a memref type")
+ .def_static(
+ "get_unchecked",
+ [](std::vector<int64_t> shape, PyType &elementType,
+ PyAttribute *layout, PyAttribute *memorySpace,
+ DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirAttribute layoutAttr =
+ layout ? *layout : mlirAttributeGetNull();
+ MlirAttribute memSpaceAttr =
+ memorySpace ? *memorySpace : mlirAttributeGetNull();
+ MlirType t =
+ mlirMemRefTypeGet(elementType, shape.size(), shape.data(),
+ layoutAttr, memSpaceAttr);
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
+ return PyMemRefType(elementType.getContext(), t);
+ },
+ nanobind::arg("shape"), nanobind::arg("element_type"),
+ nanobind::arg("layout") = nanobind::none(),
+ nanobind::arg("memory_space") = nanobind::none(),
+ nanobind::arg("context") = nanobind::none(), "Create a memref type")
+ .def_prop_ro(
+ "layout",
+ [](PyMemRefType &self)
+ -> nanobind::typed<nanobind::object, PyAttribute> {
+ return PyAttribute(self.getContext(), mlirMemRefTypeGetLayout(self))
+ .maybeDownCast();
+ },
+ "The layout of the MemRef type.")
+ .def(
+ "get_strides_and_offset",
+ [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
+ std::vector<int64_t> strides(mlirShapedTypeGetRank(self)); // One slot per dimension.
+ int64_t offset; // Filled in by mlirMemRefTypeGetStridesAndOffset.
+ if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset(
+ self, strides.data(), &offset)))
+ throw std::runtime_error(
+ "Failed to extract strides and offset from memref.");
+ return {strides, offset};
+ },
+ "The strides and offset of the MemRef type.")
+ .def_prop_ro(
+ "affine_map",
+ [](PyMemRefType &self) -> PyAffineMap {
+ MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
+ return PyAffineMap(self.getContext(), map);
+ },
+ "The layout of the MemRef type as an affine map.")
+ .def_prop_ro(
+ "memory_space",
+ [](PyMemRefType &self)
+ -> std::optional<nanobind::typed<nanobind::object, PyAttribute>> {
+ MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
+ if (mlirAttributeIsNull(a))
+ return std::nullopt; // No memory space -> Python None.
+ return PyAttribute(self.getContext(), a).maybeDownCast();
+ },
+ "Returns the memory space of the given MemRef type.");
+}
- MlirType t = mlirFunctionTypeGet(context->get(), inputs.size(),
- mlirInputs.data(), results.size(),
- mlirResults.data());
- return PyFunctionType(context->getRef(), t);
- },
- nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
- "Gets a FunctionType from a list of input and result types");
- c.def_static(
- "get",
- [](std::vector<MlirType> inputs, std::vector<MlirType> results,
- DefaultingPyMlirContext context) {
- MlirType t =
- mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
- results.size(), results.data());
- return PyFunctionType(context->getRef(), t);
- },
- nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
- // clang-format off
- nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: Context | None = None) -> FunctionType"),
- // clang-format on
- "Gets a FunctionType from a list of input and result types");
- c.def_prop_ro(
- "inputs",
- [](PyFunctionType &self) {
- MlirType t = self;
- nb::list types;
- for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
- ++i) {
- types.append(mlirFunctionTypeGetInput(t, i));
- }
- return types;
- },
- "Returns the list of input types in the FunctionType.");
- c.def_prop_ro(
- "results",
- [](PyFunctionType &self) {
- nb::list types;
- for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
- ++i) {
- types.append(mlirFunctionTypeGetResult(self, i));
- }
- return types;
- },
- "Returns the list of result types in the FunctionType.");
- }
-};
+/// Binds UnrankedMemRefType: a checked `get` (location-based, raising
+/// MLIRError on failure), an unchecked `get_unchecked` (context-based) and the
+/// read-only `memory_space` property.
+void PyUnrankedMemRefType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](PyType &elementType, PyAttribute *memorySpace,
+         DefaultingPyLocation loc) {
+        PyMlirContext::ErrorCapture errors(loc->getContext());
+        // A missing memory space is forwarded as a null attribute.
+        MlirAttribute space = memorySpace
+                                  ? static_cast<MlirAttribute>(*memorySpace)
+                                  : MlirAttribute{};
+        MlirType unranked =
+            mlirUnrankedMemRefTypeGetChecked(loc, elementType, space);
+        if (!mlirTypeIsNull(unranked))
+          return PyUnrankedMemRefType(elementType.getContext(), unranked);
+        throw MLIRError("Invalid type", errors.take());
+      },
+      nanobind::arg("element_type"), nanobind::arg("memory_space").none(),
+      nanobind::arg("loc") = nanobind::none(), "Create a unranked memref type");
+
+  c.def_static(
+      "get_unchecked",
+      [](PyType &elementType, PyAttribute *memorySpace,
+         DefaultingPyMlirContext context) {
+        PyMlirContext::ErrorCapture errors(context->getRef());
+        MlirAttribute space = memorySpace
+                                  ? static_cast<MlirAttribute>(*memorySpace)
+                                  : MlirAttribute{};
+        MlirType unranked = mlirUnrankedMemRefTypeGet(elementType, space);
+        if (!mlirTypeIsNull(unranked))
+          return PyUnrankedMemRefType(elementType.getContext(), unranked);
+        throw MLIRError("Invalid type", errors.take());
+      },
+      nanobind::arg("element_type"), nanobind::arg("memory_space").none(),
+      nanobind::arg("context") = nanobind::none(),
+      "Create a unranked memref type");
+
+  // Returns None when the type carries no memory space attribute.
+  c.def_prop_ro(
+      "memory_space",
+      [](PyUnrankedMemRefType &self)
+          -> std::optional<nanobind::typed<nanobind::object, PyAttribute>> {
+        MlirAttribute space = mlirUnrankedMemrefGetMemorySpace(self);
+        if (mlirAttributeIsNull(space))
+          return std::nullopt;
+        return PyAttribute(self.getContext(), space).maybeDownCast();
+      },
+      "Returns the memory space of the given Unranked MemRef type.");
+}
+
+/// Binds TupleType: two `get_tuple` overloads (wrapped `Type` objects or raw
+/// MlirType handles), indexed element access and the element count.
+void PyTupleType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get_tuple",
+      [](const std::vector<PyType> &elements, DefaultingPyMlirContext context) {
+        std::vector<MlirType> mlirElements;
+        mlirElements.reserve(elements.size());
+        for (const auto &element : elements)
+          mlirElements.push_back(element.get());
+        MlirType t = mlirTupleTypeGet(context->get(), mlirElements.size(),
+                                      mlirElements.data());
+        return PyTupleType(context->getRef(), t);
+      },
+      nanobind::arg("elements"), nanobind::arg("context") = nanobind::none(),
+      "Create a tuple type");
+  c.def_static(
+      "get_tuple",
+      [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
+        MlirType t =
+            mlirTupleTypeGet(context->get(), elements.size(), elements.data());
+        return PyTupleType(context->getRef(), t);
+      },
+      nanobind::arg("elements"), nanobind::arg("context") = nanobind::none(),
+      // clang-format off
+      nanobind::sig("def get_tuple(elements: Sequence[Type], context: Context | None = None) -> TupleType"),
+      // clang-format on
+      "Create a tuple type");
+  c.def(
+      "get_type",
+      [](PyTupleType &self,
+         intptr_t pos) -> nanobind::typed<nanobind::object, PyType> {
+        // Validate `pos` here: it is passed straight to the C API, which is
+        // not expected to range-check it, so a negative or too-large value
+        // would read past the element list instead of raising.
+        if (pos < 0 || pos >= mlirTupleTypeGetNumTypes(self))
+          throw nanobind::index_error("TupleType index out of range");
+        return PyType(self.getContext(), mlirTupleTypeGetType(self, pos))
+            .maybeDownCast();
+      },
+      nanobind::arg("pos"), "Returns the pos-th type in the tuple type.");
+  c.def_prop_ro(
+      "num_types",
+      [](PyTupleType &self) -> intptr_t {
+        return mlirTupleTypeGetNumTypes(self);
+      },
+      "Returns the number of types contained in a tuple.");
+}
-/// Opaque Type subclass - OpaqueType.
-class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirOpaqueTypeGetTypeID;
- static constexpr const char *pyClassName = "OpaqueType";
- using PyConcreteType::PyConcreteType;
+/// Binds FunctionType: two `get` overloads (wrapped `Type` objects or raw
+/// MlirType handles) plus the `inputs` / `results` list properties.
+void PyFunctionType::bindDerived(ClassTy &c) {
+  // Overload 1: wrapped Python `Type` objects, unwrapped element by element.
+  c.def_static(
+      "get",
+      [](std::vector<PyType> inputs, std::vector<PyType> results,
+         DefaultingPyMlirContext context) {
+        std::vector<MlirType> rawInputs(inputs.size());
+        for (size_t idx = 0; idx < inputs.size(); ++idx)
+          rawInputs[idx] = inputs[idx].get();
+        std::vector<MlirType> rawResults(results.size());
+        for (size_t idx = 0; idx < results.size(); ++idx)
+          rawResults[idx] = results[idx].get();
+        MlirType fnType = mlirFunctionTypeGet(
+            context->get(), rawInputs.size(), rawInputs.data(),
+            rawResults.size(), rawResults.data());
+        return PyFunctionType(context->getRef(), fnType);
+      },
+      nanobind::arg("inputs"), nanobind::arg("results"),
+      nanobind::arg("context") = nanobind::none(),
+      "Gets a FunctionType from a list of input and result types");
+
+  // Overload 2: raw MlirType handles, forwarded to the C API unchanged.
+  c.def_static(
+      "get",
+      [](std::vector<MlirType> inputs, std::vector<MlirType> results,
+         DefaultingPyMlirContext context) {
+        MlirType fnType = mlirFunctionTypeGet(context->get(), inputs.size(),
+                                              inputs.data(), results.size(),
+                                              results.data());
+        return PyFunctionType(context->getRef(), fnType);
+      },
+      nanobind::arg("inputs"), nanobind::arg("results"),
+      nanobind::arg("context") = nanobind::none(),
+      // clang-format off
+      nanobind::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: Context | None = None) -> FunctionType"),
+      // clang-format on
+      "Gets a FunctionType from a list of input and result types");
+
+  c.def_prop_ro(
+      "inputs",
+      [](PyFunctionType &self) {
+        nanobind::list inputTypes;
+        const intptr_t count = mlirFunctionTypeGetNumInputs(self);
+        for (intptr_t idx = 0; idx < count; ++idx)
+          inputTypes.append(mlirFunctionTypeGetInput(self, idx));
+        return inputTypes;
+      },
+      "Returns the list of input types in the FunctionType.");
+  c.def_prop_ro(
+      "results",
+      [](PyFunctionType &self) {
+        nanobind::list resultTypes;
+        const intptr_t count = mlirFunctionTypeGetNumResults(self);
+        for (intptr_t idx = 0; idx < count; ++idx)
+          resultTypes.append(mlirFunctionTypeGetResult(self, idx));
+        return resultTypes;
+      },
+      "Returns the list of result types in the FunctionType.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](const std::string &dialectNamespace, const std::string &typeData,
- DefaultingPyMlirContext context) {
- MlirType type = mlirOpaqueTypeGet(context->get(),
- toMlirStringRef(dialectNamespace),
- toMlirStringRef(typeData));
- return PyOpaqueType(context->getRef(), type);
- },
- nb::arg("dialect_namespace"), nb::arg("buffer"),
- nb::arg("context") = nb::none(),
- "Create an unregistered (opaque) dialect type.");
- c.def_prop_ro(
- "dialect_namespace",
- [](PyOpaqueType &self) {
- MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self);
- return nb::str(stringRef.data, stringRef.length);
- },
- "Returns the dialect namespace for the Opaque type as a string.");
- c.def_prop_ro(
- "data",
- [](PyOpaqueType &self) {
- MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
- return nb::str(stringRef.data, stringRef.length);
- },
- "Returns the data for the Opaque type as a string.");
- }
-};
+/// Binds OpaqueType: construction from a dialect namespace plus opaque type
+/// data, and read-only access to both strings.
+void PyOpaqueType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](const std::string &dialectNamespace, const std::string &typeData,
+         DefaultingPyMlirContext context) {
+        MlirStringRef ns = toMlirStringRef(dialectNamespace);
+        MlirStringRef data = toMlirStringRef(typeData);
+        return PyOpaqueType(context->getRef(),
+                            mlirOpaqueTypeGet(context->get(), ns, data));
+      },
+      nanobind::arg("dialect_namespace"), nanobind::arg("buffer"),
+      nanobind::arg("context") = nanobind::none(),
+      "Create an unregistered (opaque) dialect type.");
+  c.def_prop_ro(
+      "dialect_namespace",
+      [](PyOpaqueType &self) {
+        MlirStringRef ns = mlirOpaqueTypeGetDialectNamespace(self);
+        return nanobind::str(ns.data, ns.length);
+      },
+      "Returns the dialect namespace for the Opaque type as a string.");
+  c.def_prop_ro(
+      "data",
+      [](PyOpaqueType &self) {
+        MlirStringRef data = mlirOpaqueTypeGetData(self);
+        return nanobind::str(data.data, data.length);
+      },
+      "Returns the data for the Opaque type as a string.");
+}
-} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
-} // namespace python
-} // namespace mlir
+// Out-of-line definition of PyShapedType's static `isaFunction` member,
+// backed by the C API predicate mlirTypeIsAShaped.
+const PyShapedType::IsAFunctionTy PyShapedType::isaFunction{mlirTypeIsAShaped};
-namespace mlir {
-namespace python {
-namespace MLIR_BINDINGS_PYTHON_DOMAIN {
void populateIRTypes(nb::module_ &m) {
PyIntegerType::bind(m);
PyFloatType::bind(m);
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 9790a8feb8d03..ac488c3494008 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -9,7 +9,9 @@
#include "Pass.h"
#include "Rewrite.h"
#include "mlir/Bindings/Python/Globals.h"
+#include "mlir/Bindings/Python/IRAttributes.h"
#include "mlir/Bindings/Python/IRCore.h"
+#include "mlir/Bindings/Python/IRTypes.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindUtils.h"
@@ -506,8 +508,616 @@ void PyOpAttributeMap::bind(nanobind::module_ &m) {
"Returns a list of `(name, attribute)` tuples.");
}
+//------------------------------------------------------------------------------
+// Populates the core attributes of the 'ir' submodule.
+//------------------------------------------------------------------------------
+
+/// Binds AffineMapAttr: wrapping an AffineMap and reading it back.
+void PyAffineMapAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](PyAffineMap &affineMap) {
+        return PyAffineMapAttribute(affineMap.getContext(),
+                                    mlirAffineMapAttrGet(affineMap.get()));
+      },
+      nanobind::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
+  c.def_prop_ro(
+      "value",
+      [](PyAffineMapAttribute &self) {
+        MlirAffineMap wrapped = mlirAffineMapAttrGetValue(self);
+        return PyAffineMap(self.getContext(), wrapped);
+      },
+      "Returns the value of the AffineMap attribute");
+}
+
+/// Binds IntegerSetAttr: only a static constructor wrapping an IntegerSet.
+void PyIntegerSetAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](PyIntegerSet &integerSet) {
+        return PyIntegerSetAttribute(integerSet.getContext(),
+                                     mlirIntegerSetAttrGet(integerSet.get()));
+      },
+      nanobind::arg("integer_set"),
+      "Gets an attribute wrapping an IntegerSet.");
+}
+
+/// Registers the Python iterator class returned by `ArrayAttr.__iter__`.
+void PyArrayAttribute::PyArrayAttributeIterator::bind(nanobind::module_ &m) {
+  auto iterClass =
+      nanobind::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator");
+  iterClass.def("__iter__", &PyArrayAttributeIterator::dunderIter);
+  iterClass.def("__next__", &PyArrayAttributeIterator::dunderNext);
+}
+
+/// Binds ArrayAttr: construction from a list of attributes, sequence access
+/// (`__getitem__`, `__len__`, `__iter__`) and `__add__` concatenation with a
+/// Python list, which returns a new uniqued ArrayAttr.
+void PyArrayAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](const nanobind::list &attributes, DefaultingPyMlirContext context) {
+        SmallVector<MlirAttribute> mlirAttributes;
+        mlirAttributes.reserve(nanobind::len(attributes));
+        for (auto attribute : attributes)
+          mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
+        MlirAttribute attr = mlirArrayAttrGet(
+            context->get(), mlirAttributes.size(), mlirAttributes.data());
+        return PyArrayAttribute(context->getRef(), attr);
+      },
+      nanobind::arg("attributes"), nanobind::arg("context") = nanobind::none(),
+      "Gets a uniqued Array attribute");
+  c.def("__getitem__",
+        [](PyArrayAttribute &arr,
+           intptr_t i) -> nanobind::typed<nanobind::object, PyAttribute> {
+          // Negative indices must be rejected as well: they would otherwise be
+          // forwarded to getItem() as an out-of-range element position.
+          if (i < 0 || i >= mlirArrayAttrGetNumElements(arr))
+            throw nanobind::index_error("ArrayAttribute index out of range");
+          return PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast();
+        })
+      .def("__len__",
+           [](const PyArrayAttribute &arr) {
+             return mlirArrayAttrGetNumElements(arr);
+           })
+      .def("__iter__", [](const PyArrayAttribute &arr) {
+        return PyArrayAttributeIterator(arr);
+      });
+  c.def("__add__", [](PyArrayAttribute arr, const nanobind::list &extras) {
+    std::vector<MlirAttribute> attributes;
+    intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
+    attributes.reserve(numOldElements + nanobind::len(extras));
+    for (intptr_t i = 0; i < numOldElements; ++i)
+      attributes.push_back(arr.getItem(i));
+    for (nanobind::handle attr : extras)
+      attributes.push_back(pyTryCast<PyAttribute>(attr));
+    MlirAttribute arrayAttr = mlirArrayAttrGet(
+        arr.getContext()->get(), attributes.size(), attributes.data());
+    return PyArrayAttribute(arr.getContext(), arrayAttr);
+  });
+}
+
+/// Binds FloatAttr: checked (`get`) and unchecked (`get_unchecked`)
+/// constructors, f32/f64 shorthands, and conversion of the stored value to a
+/// Python float.
+void PyFloatAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](PyType &type, double value, DefaultingPyLocation loc) {
+        PyMlirContext::ErrorCapture errors(loc->getContext());
+        MlirAttribute floatAttr =
+            mlirFloatAttrDoubleGetChecked(loc, type, value);
+        if (!mlirAttributeIsNull(floatAttr))
+          return PyFloatAttribute(type.getContext(), floatAttr);
+        throw MLIRError("Invalid attribute", errors.take());
+      },
+      nanobind::arg("type"), nanobind::arg("value"),
+      nanobind::arg("loc") = nanobind::none(),
+      "Gets an uniqued float point attribute associated to a type");
+  c.def_static(
+      "get_unchecked",
+      [](PyType &type, double value, DefaultingPyMlirContext context) {
+        PyMlirContext::ErrorCapture errors(context->getRef());
+        MlirAttribute floatAttr =
+            mlirFloatAttrDoubleGet(context->get(), type, value);
+        if (!mlirAttributeIsNull(floatAttr))
+          return PyFloatAttribute(type.getContext(), floatAttr);
+        throw MLIRError("Invalid attribute", errors.take());
+      },
+      nanobind::arg("type"), nanobind::arg("value"),
+      nanobind::arg("context") = nanobind::none(),
+      "Gets an uniqued float point attribute associated to a type");
+  c.def_static(
+      "get_f32",
+      [](double value, DefaultingPyMlirContext context) {
+        MlirContext ctx = context->get();
+        MlirAttribute floatAttr =
+            mlirFloatAttrDoubleGet(ctx, mlirF32TypeGet(ctx), value);
+        return PyFloatAttribute(context->getRef(), floatAttr);
+      },
+      nanobind::arg("value"), nanobind::arg("context") = nanobind::none(),
+      "Gets an uniqued float point attribute associated to a f32 type");
+  c.def_static(
+      "get_f64",
+      [](double value, DefaultingPyMlirContext context) {
+        MlirContext ctx = context->get();
+        MlirAttribute floatAttr =
+            mlirFloatAttrDoubleGet(ctx, mlirF64TypeGet(ctx), value);
+        return PyFloatAttribute(context->getRef(), floatAttr);
+      },
+      nanobind::arg("value"), nanobind::arg("context") = nanobind::none(),
+      "Gets an uniqued float point attribute associated to a f64 type");
+  // `value` and `__float__` both read the stored double via the C API.
+  c.def_prop_ro("value", mlirFloatAttrGetValueDouble,
+                "Returns the value of the float attribute");
+  c.def("__float__", mlirFloatAttrGetValueDouble,
+        "Converts the value of the float attribute to a Python float");
+}
+
+/// Binds IntegerAttr: construction from a type and an int64 value, integer
+/// conversion, and the class-level `static_typeid`.
+void PyIntegerAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](PyType &type, int64_t value) {
+        return PyIntegerAttribute(type.getContext(),
+                                  mlirIntegerAttrGet(type, value));
+      },
+      nanobind::arg("type"), nanobind::arg("value"),
+      "Gets an uniqued integer attribute associated to a type");
+  // `value` and `__int__` share the same converter.
+  c.def_prop_ro("value", toPyInt, "Returns the value of the integer attribute");
+  c.def("__int__", toPyInt,
+        "Converts the value of the integer attribute to a Python int");
+  c.def_prop_ro_static(
+      "static_typeid",
+      [](nanobind::object & /*class*/) {
+        MlirTypeID typeID = mlirIntegerAttrGetTypeID();
+        return PyTypeID(typeID);
+      },
+      nanobind::sig("def static_typeid(/) -> TypeID"));
+}
+
+/// Binds BoolAttr: construction from a Python bool and conversion back.
+void PyBoolAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](bool value, DefaultingPyMlirContext context) {
+        return PyBoolAttribute(context->getRef(),
+                               mlirBoolAttrGet(context->get(), value));
+      },
+      nanobind::arg("value"), nanobind::arg("context") = nanobind::none(),
+      "Gets an uniqued bool attribute");
+  // `value` and `__bool__` both read the flag via the C API.
+  c.def_prop_ro("value", mlirBoolAttrGetValue,
+                "Returns the value of the bool attribute");
+  c.def("__bool__", mlirBoolAttrGetValue,
+        "Converts the value of the bool attribute to a Python bool");
+}
+
+/// Binds SymbolRefAttr: construction from a list of symbol names and the
+/// `value` property, which returns the root reference followed by each nested
+/// reference as a list[str].
+void PySymbolRefAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](const std::vector<std::string> &symbols,
+         DefaultingPyMlirContext context) {
+        return PySymbolRefAttribute::fromList(symbols, context.resolve());
+      },
+      nanobind::arg("symbols"), nanobind::arg("context") = nanobind::none(),
+      "Gets a uniqued SymbolRef attribute from a list of symbol names");
+  c.def_prop_ro(
+      "value",
+      [](PySymbolRefAttribute &self) {
+        // Query the nested count once, and use the C API's intptr_t type
+        // instead of `int` for the index.
+        const intptr_t numNested =
+            mlirSymbolRefAttrGetNumNestedReferences(self);
+        std::vector<std::string> symbols;
+        symbols.reserve(static_cast<size_t>(numNested) + 1);
+        symbols.push_back(
+            unwrap(mlirSymbolRefAttrGetRootReference(self)).str());
+        for (intptr_t i = 0; i < numNested; ++i)
+          symbols.push_back(
+              unwrap(mlirSymbolRefAttrGetRootReference(
+                         mlirSymbolRefAttrGetNestedReference(self, i)))
+                  .str());
+        return symbols;
+      },
+      "Returns the value of the SymbolRef attribute as a list[str]");
+}
+
+/// Binds FlatSymbolRefAttr: construction from a single symbol name and the
+/// `value` property returning it as a Python str.
+void PyFlatSymbolRefAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](const std::string &value, DefaultingPyMlirContext context) {
+        MlirStringRef symbol = toMlirStringRef(value);
+        return PyFlatSymbolRefAttribute(
+            context->getRef(), mlirFlatSymbolRefAttrGet(context->get(), symbol));
+      },
+      nanobind::arg("value"), nanobind::arg("context") = nanobind::none(),
+      "Gets a uniqued FlatSymbolRef attribute");
+  c.def_prop_ro(
+      "value",
+      [](PyFlatSymbolRefAttribute &self) {
+        MlirStringRef symbol = mlirFlatSymbolRefAttrGetValue(self);
+        return nanobind::str(symbol.data, symbol.length);
+      },
+      "Returns the value of the FlatSymbolRef attribute as a string");
+}
+
+/// Binds OpaqueAttr: construction from a dialect namespace, a Python buffer
+/// holding the raw attribute data and a type, plus read-only access to the
+/// namespace (str) and data (bytes).
+void PyOpaqueAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](const std::string &dialectNamespace, const nb_buffer &buffer,
+         PyType &type, DefaultingPyMlirContext context) {
+        const nb_buffer_info bufferInfo = buffer.request();
+        // `size` is the element count (product of the shape), not a byte
+        // count. The C API expects a byte length, so scale by the item size.
+        // Otherwise data from buffers with multi-byte items would be
+        // truncated. For `bytes` (itemsize 1) the result is unchanged.
+        // NOTE(review): assumes nb_buffer_info follows pybind11's buffer_info
+        // convention for `size`/`itemsize`, and that the buffer is contiguous.
+        intptr_t bufferSize = bufferInfo.size * bufferInfo.itemsize;
+        MlirAttribute attr = mlirOpaqueAttrGet(
+            context->get(), toMlirStringRef(dialectNamespace), bufferSize,
+            static_cast<char *>(bufferInfo.ptr), type);
+        return PyOpaqueAttribute(context->getRef(), attr);
+      },
+      nanobind::arg("dialect_namespace"), nanobind::arg("buffer"),
+      nanobind::arg("type"), nanobind::arg("context") = nanobind::none(),
+      // clang-format off
+      nanobind::sig("def get(dialect_namespace: str, buffer: typing_extensions.Buffer, type: Type, context: Context | None = None) -> OpaqueAttr"),
+      // clang-format on
+      "Gets an Opaque attribute.");
+  c.def_prop_ro(
+      "dialect_namespace",
+      [](PyOpaqueAttribute &self) {
+        MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self);
+        return nanobind::str(stringRef.data, stringRef.length);
+      },
+      "Returns the dialect namespace for the Opaque attribute as a string");
+  c.def_prop_ro(
+      "data",
+      [](PyOpaqueAttribute &self) {
+        MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
+        return nanobind::bytes(stringRef.data, stringRef.length);
+      },
+      "Returns the data for the Opaqued attributes as `bytes`");
+}
+
+// Python docstring for the buffer-based `DenseElementsAttr.get` overload
+// (PyDenseElementsAttribute::getFromBuffer). Its Args section must stay in
+// sync with the nanobind::arg list used when binding that overload.
+static const char kDenseElementsAttrGetDocstring[] =
+ R"(Gets a DenseElementsAttr from a Python buffer or array.
+
+When `type` is not provided, then some limited type inferencing is done based
+on the buffer format. Support presently exists for 8/16/32/64 signed and
+unsigned integers and float16/float32/float64. DenseElementsAttrs of these
+types can also be converted back to a corresponding buffer.
+
+For conversions outside of these types, a `type=` must be explicitly provided
+and the buffer contents must be bit-castable to the MLIR internal
+representation:
+
+ * Integer types (except for i1): the buffer must be byte aligned to the
+ next byte boundary.
+ * Floating point types: Must be bit-castable to the given floating point
+ size.
+ * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
+ row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
+ this specification with: `np.packbits(ary, axis=None, bitorder='little')`.
+
+If a single element buffer is passed (or for i1, a single byte with value 0
+or 255), then a splat will be created.
+
+Args:
+ array: The array or buffer to convert.
+ signless: If inferring an appropriate MLIR type, use signless types for
+ integers (defaults True).
+ type: Skips inference of the MLIR element type and uses this instead. The
+ storage size must be consistent with the actual contents of the buffer.
+ shape: Overrides the shape of the buffer when constructing the MLIR
+ shaped type. This is needed when the physical and logical shape differ (as
+ for i1).
+ context: Explicit context, if not from context manager.
+
+Returns:
+ DenseElementsAttr on success.
+
+Raises:
+ ValueError: If the type of the buffer or array cannot be matched to an MLIR
+ type or if the buffer does not meet expectations.
+)";
+
+// Python docstring for the list-based `DenseElementsAttr.get` overload
+// (PyDenseElementsAttribute::getFromList). The bound arguments are
+// `attrs`, `type` and `context`, so the Raises section refers to `type`.
+static const char kDenseElementsAttrGetFromListDocstring[] =
+ R"(Gets a DenseElementsAttr from a Python list of attributes.
+
+Note that it can be expensive to construct attributes individually.
+For a large number of elements, consider using a Python buffer or array instead.
+
+Args:
+ attrs: A list of attributes.
+ type: The desired shape and type of the resulting DenseElementsAttr.
+ If not provided, the element type is determined based on the type
+ of the 0th attribute and the shape is `[len(attrs)]`.
+ context: Explicit context, if not from context manager.
+
+Returns:
+ DenseElementsAttr on success.
+
+Raises:
+ ValueError: If the type of the attributes does not match the type
+ specified by `type`.
+)";
+
+// Python docstring for `DenseResourceElementsAttr.get_from_buffer`. The Args
+// section mirrors the bound argument names: array, name, type, alignment,
+// is_mutable, context.
+static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =
+ R"(Gets a DenseResourceElementsAttr from a Python buffer or array.
+
+This function does minimal validation or massaging of the data, and it is
+up to the caller to ensure that the buffer meets the characteristics
+implied by the shape.
+
+The backing buffer and any user objects will be retained for the lifetime
+of the resource blob. This is typically bounded to the context but the
+resource can have a shorter lifespan depending on how it is used in
+subsequent processing.
+
+Args:
+ array: The array or buffer to convert.
+ name: Name to provide to the resource (may be changed upon collision).
+ type: The explicit ShapedType to construct the attribute with.
+ alignment: Optional alignment of the resource blob data.
+ is_mutable: Whether the resource blob is mutable (defaults False).
+ context: Explicit context, if not from context manager.
+
+Returns:
+ DenseResourceElementsAttr on success.
+
+Raises:
+ ValueError: If the type of the buffer or array cannot be matched to an MLIR
+ type or if the buffer does not meet expectations.
+)";
+
+/// Binds DenseElementsAttr: `__len__`, the buffer- and list-based `get`
+/// overloads, `get_splat`, and the `is_splat` / `get_splat_value` queries.
+void PyDenseElementsAttribute::bindDerived(ClassTy &c) {
+// On Python < 3.9 the buffer-protocol hooks are patched directly into the
+// type object's tp_as_buffer slots.
+// NOTE(review): assumes tp_as_buffer is non-null here (presumably provided by
+// the `slots` passed to bind()); confirm before relying on this path.
+#if PY_VERSION_HEX < 0x03090000
+ PyTypeObject *tp = reinterpret_cast<PyTypeObject *>(c.ptr());
+ tp->tp_as_buffer->bf_getbuffer = PyDenseElementsAttribute::bf_getbuffer;
+ tp->tp_as_buffer->bf_releasebuffer =
+ PyDenseElementsAttribute::bf_releasebuffer;
+#endif
+ // The buffer overload of `get` is registered before the list overload;
+ // keep this order, since overload resolution depends on registration order.
+ c.def("__len__", &PyDenseElementsAttribute::dunderLen)
+ .def_static(
+ "get", PyDenseElementsAttribute::getFromBuffer,
+ nanobind::arg("array"), nanobind::arg("signless") = true,
+ nanobind::arg("type") = nanobind::none(),
+ nanobind::arg("shape") = nanobind::none(),
+ nanobind::arg("context") = nanobind::none(),
+ // clang-format off
+ nanobind::sig("def get(array: typing_extensions.Buffer, signless: bool = True, type: Type | None = None, shape: Sequence[int] | None = None, context: Context | None = None) -> DenseElementsAttr"),
+ // clang-format on
+ kDenseElementsAttrGetDocstring)
+ .def_static("get", PyDenseElementsAttribute::getFromList,
+ nanobind::arg("attrs"),
+ nanobind::arg("type") = nanobind::none(),
+ nanobind::arg("context") = nanobind::none(),
+ kDenseElementsAttrGetFromListDocstring)
+ .def_static("get_splat", PyDenseElementsAttribute::getSplat,
+ nanobind::arg("shaped_type"), nanobind::arg("element_attr"),
+ "Gets a DenseElementsAttr where all values are the same")
+ .def_prop_ro("is_splat",
+ [](PyDenseElementsAttribute &self) -> bool {
+ return mlirDenseElementsAttrIsSplat(self);
+ })
+ // Raises ValueError instead of querying a splat value on a non-splat
+ // attribute.
+ .def("get_splat_value",
+ [](PyDenseElementsAttribute &self)
+ -> nanobind::typed<nanobind::object, PyAttribute> {
+ if (!mlirDenseElementsAttrIsSplat(self))
+ throw nanobind::value_error(
+ "get_splat_value called on a non-splat attribute");
+ return PyAttribute(self.getContext(),
+ mlirDenseElementsAttrGetSplatValue(self))
+ .maybeDownCast();
+ });
+}
+
+/// Binds DenseIntElementsAttr: only element access via `__getitem__`, which is
+/// implemented by PyDenseIntElementsAttribute::dunderGetItem.
+void PyDenseIntElementsAttribute::bindDerived(ClassTy &c) {
+ c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
+}
+
+/// Binds DenseResourceElementsAttr, which exposes a single static constructor,
+/// `get_from_buffer`, implemented by getFromBuffer.
+void PyDenseResourceElementsAttribute::bindDerived(ClassTy &c) {
+  c.def_static("get_from_buffer",
+               PyDenseResourceElementsAttribute::getFromBuffer,
+               nanobind::arg("array"),
+               nanobind::arg("name"),
+               nanobind::arg("type"),
+               nanobind::arg("alignment") = nanobind::none(),
+               nanobind::arg("is_mutable") = false,
+               nanobind::arg("context") = nanobind::none(),
+               // clang-format off
+               nanobind::sig("def get_from_buffer(array: typing_extensions.Buffer, name: str, type: Type, alignment: int | None = None, is_mutable: bool = False, context: Context | None = None) -> DenseResourceElementsAttr"),
+               // clang-format on
+               kDenseResourceElementsAttrGetFromBufferDocstring);
+}
+
+/// Binds DictionaryAttr: membership and length, construction from a Python
+/// dict of name -> Attribute, and two `__getitem__` overloads (by name,
+/// raising KeyError; by integer index, raising IndexError).
+void PyDictAttribute::bindDerived(ClassTy &c) {
+  c.def("__contains__", &PyDictAttribute::dunderContains);
+  c.def("__len__", &PyDictAttribute::dunderLen);
+  c.def_static(
+      "get",
+      [](const nanobind::dict &attributes, DefaultingPyMlirContext context) {
+        SmallVector<MlirNamedAttribute> mlirNamedAttributes;
+        mlirNamedAttributes.reserve(attributes.size());
+        for (std::pair<nanobind::handle, nanobind::handle> it : attributes) {
+          auto &mlirAttr = nanobind::cast<PyAttribute &>(it.second);
+          auto name = nanobind::cast<std::string>(it.first);
+          mlirNamedAttributes.push_back(mlirNamedAttributeGet(
+              mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
+                                toMlirStringRef(name)),
+              mlirAttr));
+        }
+        MlirAttribute attr =
+            mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
+                                  mlirNamedAttributes.data());
+        return PyDictAttribute(context->getRef(), attr);
+      },
+      nanobind::arg("value") = nanobind::dict(),
+      nanobind::arg("context") = nanobind::none(),
+      "Gets an uniqued dict attribute");
+  // The by-name overload is registered first; keep this order.
+  c.def("__getitem__",
+        [](PyDictAttribute &self, const std::string &name)
+            -> nanobind::typed<nanobind::object, PyAttribute> {
+          MlirAttribute attr =
+              mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
+          if (mlirAttributeIsNull(attr))
+            throw nanobind::key_error(
+                "attempt to access a non-existent attribute");
+          return PyAttribute(self.getContext(), attr).maybeDownCast();
+        });
+  c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
+    if (index < 0 || index >= self.dunderLen()) {
+      throw nanobind::index_error("attempt to access out of bounds attribute");
+    }
+    MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
+    // Build the name from pointer + length. MlirStringRef carries an explicit
+    // length and is not guaranteed to be NUL-terminated.
+    MlirStringRef name = mlirIdentifierStr(namedAttr.name);
+    return PyNamedAttribute(namedAttr.attribute,
+                            std::string(name.data, name.length));
+  });
+}
+
+/// Binds DenseFPElementsAttr: only element access via `__getitem__`, which is
+/// implemented by PyDenseFPElementsAttribute::dunderGetItem.
+void PyDenseFPElementsAttribute::bindDerived(ClassTy &c) {
+ c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
+}
+
+/// Binds TypeAttr: wrapping a Type as an attribute and reading it back.
+void PyTypeAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](PyType &value, DefaultingPyMlirContext context) {
+        // mlirTypeAttrGet takes no context, so the attribute lives in the
+        // context that owns `value`. Wrap it with that context rather than
+        // `context`, which may differ and would leave the Python object
+        // holding a reference to the wrong context. `context` is still
+        // accepted for API compatibility.
+        (void)context;
+        MlirAttribute attr = mlirTypeAttrGet(value.get());
+        return PyTypeAttribute(value.getContext(), attr);
+      },
+      nanobind::arg("value"), nanobind::arg("context") = nanobind::none(),
+      "Gets a uniqued Type attribute");
+  c.def_prop_ro(
+      "value",
+      [](PyTypeAttribute &self) -> nanobind::typed<nanobind::object, PyType> {
+        return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
+            .maybeDownCast();
+      },
+      "Returns the value of the Type attribute");
+}
+
+/// Binds UnitAttr: only a static constructor in the given/current context.
+void PyUnitAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](DefaultingPyMlirContext context) {
+        MlirAttribute unit = mlirUnitAttrGet(context->get());
+        return PyUnitAttribute(context->getRef(), unit);
+      },
+      nanobind::arg("context") = nanobind::none(), "Create a Unit attribute.");
+}
+
+/// Binds StridedLayoutAttr: construction from an explicit offset/strides,
+/// a fully dynamic layout of a given rank, and read access to the offset and
+/// strides.
+void PyStridedLayoutAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](int64_t offset, const std::vector<int64_t> &strides,
+         DefaultingPyMlirContext ctx) {
+        MlirAttribute attr = mlirStridedLayoutAttrGet(
+            ctx->get(), offset, strides.size(), strides.data());
+        return PyStridedLayoutAttribute(ctx->getRef(), attr);
+      },
+      nanobind::arg("offset"), nanobind::arg("strides"),
+      nanobind::arg("context") = nanobind::none(),
+      "Gets a strided layout attribute.");
+  c.def_static(
+      "get_fully_dynamic",
+      [](int64_t rank, DefaultingPyMlirContext ctx) {
+        // A negative rank would otherwise be converted to a huge size_t when
+        // sizing the strides vector.
+        if (rank < 0)
+          throw nanobind::value_error("rank must be non-negative");
+        // The same dynamic sentinel is used for the offset and every stride.
+        int64_t dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
+        std::vector<int64_t> strides(static_cast<size_t>(rank), dynamic);
+        MlirAttribute attr = mlirStridedLayoutAttrGet(
+            ctx->get(), dynamic, strides.size(), strides.data());
+        return PyStridedLayoutAttribute(ctx->getRef(), attr);
+      },
+      nanobind::arg("rank"), nanobind::arg("context") = nanobind::none(),
+      "Gets a strided layout attribute with dynamic offset and strides of a "
+      "given rank.");
+  c.def_prop_ro(
+      "offset",
+      [](PyStridedLayoutAttribute &self) {
+        return mlirStridedLayoutAttrGetOffset(self);
+      },
+      "Returns the offset of the strided layout attribute");
+  c.def_prop_ro(
+      "strides",
+      [](PyStridedLayoutAttribute &self) {
+        intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
+        std::vector<int64_t> strides(size);
+        for (intptr_t i = 0; i < size; i++)
+          strides[i] = mlirStridedLayoutAttrGetStride(self, i);
+        return strides;
+      },
+      "Returns the strides of the strided layout attribute");
+}
+
+/// Binds StringAttr: construction from str or bytes (str overload first),
+/// a typed constructor, and access to the value as str or bytes.
+void PyStringAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](const std::string &value, DefaultingPyMlirContext context) {
+        MlirStringRef text = toMlirStringRef(value);
+        return PyStringAttribute(context->getRef(),
+                                 mlirStringAttrGet(context->get(), text));
+      },
+      nanobind::arg("value"), nanobind::arg("context") = nanobind::none(),
+      "Gets a uniqued string attribute");
+  c.def_static(
+      "get",
+      [](const nanobind::bytes &value, DefaultingPyMlirContext context) {
+        MlirStringRef text = toMlirStringRef(value);
+        return PyStringAttribute(context->getRef(),
+                                 mlirStringAttrGet(context->get(), text));
+      },
+      nanobind::arg("value"), nanobind::arg("context") = nanobind::none(),
+      "Gets a uniqued string attribute");
+  c.def_static(
+      "get_typed",
+      [](PyType &type, const std::string &value) {
+        MlirStringRef text = toMlirStringRef(value);
+        return PyStringAttribute(type.getContext(),
+                                 mlirStringAttrTypedGet(type, text));
+      },
+      nanobind::arg("type"), nanobind::arg("value"),
+      "Gets a uniqued string attribute associated to a type");
+  c.def_prop_ro(
+      "value",
+      [](PyStringAttribute &self) {
+        MlirStringRef text = mlirStringAttrGetValue(self);
+        return nanobind::str(text.data, text.length);
+      },
+      "Returns the value of the string attribute");
+  c.def_prop_ro(
+      "value_bytes",
+      [](PyStringAttribute &self) {
+        MlirStringRef text = mlirStringAttrGetValue(self);
+        return nanobind::bytes(text.data, text.length);
+      },
+      "Returns the value of the string attribute as `bytes`");
+}
+
+/// Registers all builtin attribute classes of the `ir` submodule on `m`,
+/// together with the global type casters that map a TypeID to a downcasting
+/// factory. Statement order is significant: classes are bound before the
+/// casters that refer to them.
+void populateIRAttributes(nb::module_ &m) {
+ PyAffineMapAttribute::bind(m);
+ // Dense array attributes and their per-element-type iterator classes.
+ PyDenseBoolArrayAttribute::bind(m);
+ PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
+ PyDenseI8ArrayAttribute::bind(m);
+ PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m);
+ PyDenseI16ArrayAttribute::bind(m);
+ PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m);
+ PyDenseI32ArrayAttribute::bind(m);
+ PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m);
+ PyDenseI64ArrayAttribute::bind(m);
+ PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m);
+ PyDenseF32ArrayAttribute::bind(m);
+ PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
+ PyDenseF64ArrayAttribute::bind(m);
+ PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
+ // A single caster for the dense-array TypeID dispatches to the concrete
+ // element-type class above.
+ PyGlobals::get().registerTypeCaster(
+ mlirDenseArrayAttrGetTypeID(),
+ nb::cast<nb::callable>(nb::cpp_function(denseArrayAttributeCaster)));
+
+ PyArrayAttribute::bind(m);
+ PyArrayAttribute::PyArrayAttributeIterator::bind(m);
+ PyBoolAttribute::bind(m);
+ // DenseElementsAttr is bound with extra type slots (see bindDerived).
+ PyDenseElementsAttribute::bind(m, PyDenseElementsAttribute::slots);
+ PyDenseFPElementsAttribute::bind(m);
+ PyDenseIntElementsAttribute::bind(m);
+ // Int and FP dense elements share one TypeID, so one caster picks the
+ // right subclass.
+ PyGlobals::get().registerTypeCaster(
+ mlirDenseIntOrFPElementsAttrGetTypeID(),
+ nb::cast<nb::callable>(
+ nb::cpp_function(denseIntOrFPElementsAttributeCaster)));
+ PyDenseResourceElementsAttribute::bind(m);
+
+ PyDictAttribute::bind(m);
+ PySymbolRefAttribute::bind(m);
+ // SymbolRef and FlatSymbolRef share the SymbolRef TypeID.
+ PyGlobals::get().registerTypeCaster(
+ mlirSymbolRefAttrGetTypeID(),
+ nb::cast<nb::callable>(
+ nb::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster)));
+
+ PyFlatSymbolRefAttribute::bind(m);
+ PyOpaqueAttribute::bind(m);
+ PyFloatAttribute::bind(m);
+ PyIntegerAttribute::bind(m);
+ PyIntegerSetAttribute::bind(m);
+ PyStringAttribute::bind(m);
+ PyTypeAttribute::bind(m);
+ // Integer and Bool attributes share the IntegerAttr TypeID.
+ PyGlobals::get().registerTypeCaster(
+ mlirIntegerAttrGetTypeID(),
+ nb::cast<nb::callable>(nb::cpp_function(integerOrBoolAttributeCaster)));
+ PyUnitAttribute::bind(m);
+
+ PyStridedLayoutAttribute::bind(m);
+}
+
void populateIRAffine(nb::module_ &m);
-void populateIRAttributes(nb::module_ &m);
void populateIRInterfaces(nb::module_ &m);
void populateIRTypes(nb::module_ &m);
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 4a9fb127ee08c..b508a80189075 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -533,7 +533,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
SOURCES
MainModule.cpp
IRAffine.cpp
- IRAttributes.cpp
IRInterfaces.cpp
IRTypes.cpp
Pass.cpp
@@ -846,8 +845,9 @@ declare_mlir_python_extension(MLIRPythonExtension.MLIRPythonSupport
ADD_TO_PARENT MLIRPythonSources.Core
ROOT_DIR "${PYTHON_SOURCE_DIR}"
SOURCES
- IRCore.cpp
Globals.cpp
+ IRAttributes.cpp
+ IRCore.cpp
)
################################################################################
diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
index 43573cbc305fa..a296b5e814b4b 100644
--- a/mlir/test/python/lib/PythonTestModuleNanobind.cpp
+++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
@@ -15,6 +15,7 @@
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/Diagnostics.h"
#include "mlir/Bindings/Python/IRCore.h"
+#include "mlir/Bindings/Python/IRTypes.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "nanobind/nanobind.h"
@@ -47,6 +48,49 @@ struct PyTestType
}
};
+// Test-only Python class "TestIntegerRankedTensorType". It derives from the
+// core RankedTensorType binding and uses mlirTypeIsARankedIntegerTensor as its
+// isa check.
+struct PyTestIntegerRankedTensorType
+    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<
+          PyTestIntegerRankedTensorType,
+          mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyRankedTensorType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedIntegerTensor;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirRankedTensorTypeGetTypeID;
+  static constexpr const char *pyClassName = "TestIntegerRankedTensorType";
+  using PyConcreteType::PyConcreteType;
+
+  // Adds `get(shape, width, context=None)`. It builds a ranked tensor of
+  // `width`-bit integers with no encoding in the resolved context.
+  static void bindDerived(ClassTy &cls) {
+    cls.def_static(
+        "get",
+        [](std::vector<int64_t> shape, unsigned width,
+           mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
+               ctx) {
+          MlirContext rawContext = ctx.get()->get();
+          MlirType elementType = mlirIntegerTypeGet(rawContext, width);
+          MlirType tensorType =
+              mlirRankedTensorTypeGet(shape.size(), shape.data(), elementType,
+                                      /*encoding=*/mlirAttributeGetNull());
+          return PyTestIntegerRankedTensorType(ctx->getRef(), tensorType);
+        },
+        nb::arg("shape"), nb::arg("width"),
+        nb::arg("context").none() = nb::none());
+  }
+};
+
+// Test-only Python value class "TestTensorValue". Its isa check is
+// mlirTypeIsAPythonTestTestTensorValue, and it adds an `is_null()` method.
+struct PyTestTensorValue
+    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteValue<
+          PyTestTensorValue> {
+  static constexpr IsAFunctionTy isaFunction =
+      mlirTypeIsAPythonTestTestTensorValue;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirRankedTensorTypeGetTypeID;
+  static constexpr const char *pyClassName = "TestTensorValue";
+  using PyConcreteValue::PyConcreteValue;
+
+  static void bindDerived(ClassTy &cls) {
+    cls.def("is_null",
+            [](MlirValue &value) -> bool { return mlirValueIsNull(value); });
+  }
+};
+
class PyTestAttr
: public mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteAttribute<
PyTestAttr> {
@@ -73,18 +117,18 @@ class PyTestAttr
NB_MODULE(_mlirPythonTestNanobind, m) {
m.def(
"register_python_test_dialect",
- [](MlirContext context, bool load) {
+ [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
+ context,
+ bool load) {
MlirDialectHandle pythonTestDialect =
mlirGetDialectHandle__python_test__();
- mlirDialectHandleRegisterDialect(pythonTestDialect, context);
+ mlirDialectHandleRegisterDialect(pythonTestDialect,
+ context.get()->get());
if (load) {
- mlirDialectHandleLoadDialect(pythonTestDialect, context);
+ mlirDialectHandleLoadDialect(pythonTestDialect, context.get()->get());
}
},
- nb::arg("context"), nb::arg("load") = true,
- // clang-format off
- nb::sig("def register_python_test_dialect(context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") ", load: bool = True) -> None"));
- // clang-format on
+ nb::arg("context").none() = nb::none(), nb::arg("load") = true);
m.def(
"register_dialect",
@@ -100,73 +144,16 @@ NB_MODULE(_mlirPythonTestNanobind, m) {
m.def(
"test_diagnostics_with_errors_and_notes",
- [](MlirContext ctx) {
- mlir::python::CollectDiagnosticsToStringScope handler(ctx);
- mlirPythonTestEmitDiagnosticWithNote(ctx);
+ [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
+ ctx) {
+ mlir::python::CollectDiagnosticsToStringScope handler(ctx.get()->get());
+ mlirPythonTestEmitDiagnosticWithNote(ctx.get()->get());
throw nb::value_error(handler.takeMessage().c_str());
},
- // clang-format off
- nb::sig("def test_diagnostics_with_errors_and_notes(arg: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") ", /) -> None"));
- // clang-format on
+ nb::arg("context").none() = nb::none());
PyTestAttr::bind(m);
PyTestType::bind(m);
-
- auto typeCls =
- mlir_type_subclass(m, "TestIntegerRankedTensorType",
- mlirTypeIsARankedIntegerTensor,
- nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("RankedTensorType"))
- .def_classmethod(
- "get",
- [](const nb::object &cls, std::vector<int64_t> shape,
- unsigned width, MlirContext ctx) {
- MlirAttribute encoding = mlirAttributeGetNull();
- return cls(mlirRankedTensorTypeGet(
- shape.size(), shape.data(), mlirIntegerTypeGet(ctx, width),
- encoding));
- },
- // clang-format off
- nb::sig("def get(cls: object, shape: collections.abc.Sequence[int], width: int, context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> object"),
- // clang-format on
- nb::arg("cls"), nb::arg("shape"), nb::arg("width"),
- nb::arg("context").none() = nb::none());
-
- assert(nb::hasattr(typeCls.get_class(), "static_typeid") &&
- "TestIntegerRankedTensorType has no static_typeid");
-
- MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID();
-
- nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
- mlirRankedTensorTypeID, nb::arg("replace") = true)(
- nanobind::cpp_function([typeCls](const nb::object &mlirType) {
- return typeCls.get_class()(mlirType);
- }));
-
- auto valueCls = mlir_value_subclass(m, "TestTensorValue",
- mlirTypeIsAPythonTestTestTensorValue)
- .def("is_null", [](MlirValue &self) {
- return mlirValueIsNull(self);
- });
-
- nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)(
- mlirRankedTensorTypeID)(
- nanobind::cpp_function([valueCls](const nb::object &valueObj) {
- std::optional<nb::object> capsule =
- mlirApiObjectToCapsule(valueObj);
- assert(capsule.has_value() && "capsule is not null");
- MlirValue v = mlirPythonCapsuleToValue(capsule.value().ptr());
- MlirType t = mlirValueGetType(v);
- // This is hyper-specific in order to exercise/test registering a
- // value caster from cpp (but only for a single test case; see
- // testTensorValue python_test.py).
- if (mlirShapedTypeHasStaticShape(t) &&
- mlirShapedTypeGetDimSize(t, 0) == 1 &&
- mlirShapedTypeGetDimSize(t, 1) == 2 &&
- mlirShapedTypeGetDimSize(t, 2) == 3)
- return valueCls.get_class()(valueObj);
- return valueObj;
- }));
+ PyTestIntegerRankedTensorType::bind(m);
+ PyTestTensorValue::bind(m);
}
>From 77875ff8160e8515883280859e3b806008ff7ec9 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Mon, 29 Dec 2025 11:14:00 -0800
Subject: [PATCH 2/2] [mlir][Python] port dialect extensions to use core
PyConcreteType, PyConcreteAttribute
---
mlir/lib/Bindings/Python/DialectAMDGPU.cpp | 110 ++-
mlir/lib/Bindings/Python/DialectGPU.cpp | 152 ++--
mlir/lib/Bindings/Python/DialectLLVM.cpp | 297 ++++---
mlir/lib/Bindings/Python/DialectNVGPU.cpp | 49 +-
mlir/lib/Bindings/Python/DialectPDL.cpp | 228 +++--
mlir/lib/Bindings/Python/DialectQuant.cpp | 809 ++++++++++--------
mlir/lib/Bindings/Python/DialectSMT.cpp | 89 +-
.../Bindings/Python/DialectSparseTensor.cpp | 234 ++---
mlir/lib/Bindings/Python/DialectTransform.cpp | 248 +++---
.../dialects/transform/extras/__init__.py | 11 +-
mlir/test/python/dialects/pdl_types.py | 211 ++---
11 files changed, 1406 insertions(+), 1032 deletions(-)
diff --git a/mlir/lib/Bindings/Python/DialectAMDGPU.cpp b/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
index 26ffc0e427e41..26115c3635b7b 100644
--- a/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
@@ -8,58 +8,96 @@
#include "mlir-c/Dialect/AMDGPU.h"
#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "nanobind/nanobind.h"
namespace nb = nanobind;
using namespace llvm;
-using namespace mlir;
-using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
-static void populateDialectAMDGPUSubmodule(const nb::module_ &m) {
- auto amdgpuTDMBaseType =
- mlir_type_subclass(m, "TDMBaseType", mlirTypeIsAAMDGPUTDMBaseType,
- mlirAMDGPUTDMBaseTypeGetTypeID);
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace amdgpu {
+// Python class "TDMBaseType" for the AMDGPU dialect type of the same name.
+struct TDMBaseType : PyConcreteType<TDMBaseType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAAMDGPUTDMBaseType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirAMDGPUTDMBaseTypeGetTypeID;
+  static constexpr const char *pyClassName = "TDMBaseType";
+  using PyConcreteType::PyConcreteType;
- amdgpuTDMBaseType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirType elementType, MlirContext ctx) {
- return cls(mlirAMDGPUTDMBaseTypeGet(ctx, elementType));
- },
- "Gets an instance of TDMBaseType in the same context", nb::arg("cls"),
- nb::arg("element_type"), nb::arg("ctx") = nb::none());
+  // Adds `get(element_type, context=None)`.
+  static void bindDerived(ClassTy &cls) {
+    cls.def_static(
+        "get",
+        [](const PyType &elementType, DefaultingPyMlirContext context) {
+          MlirType baseType =
+              mlirAMDGPUTDMBaseTypeGet(context.get()->get(), elementType);
+          return TDMBaseType(context->getRef(), baseType);
+        },
+        "Gets an instance of TDMBaseType in the same context",
+        nb::arg("element_type"), nb::arg("context").none() = nb::none());
+  }
+};
- auto amdgpuTDMDescriptorType = mlir_type_subclass(
- m, "TDMDescriptorType", mlirTypeIsAAMDGPUTDMDescriptorType,
- mlirAMDGPUTDMDescriptorTypeGetTypeID);
+// Python class "TDMDescriptorType" for the AMDGPU dialect type of that name.
+struct TDMDescriptorType : PyConcreteType<TDMDescriptorType> {
+  static constexpr IsAFunctionTy isaFunction =
+      mlirTypeIsAAMDGPUTDMDescriptorType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirAMDGPUTDMDescriptorTypeGetTypeID;
+  static constexpr const char *pyClassName = "TDMDescriptorType";
+  using PyConcreteType::PyConcreteType;
- amdgpuTDMDescriptorType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirAMDGPUTDMDescriptorTypeGet(ctx));
- },
- "Gets an instance of TDMDescriptorType in the same context",
- nb::arg("cls"), nb::arg("ctx") = nb::none());
+  // Adds the parameterless `get(context=None)` factory.
+  static void bindDerived(ClassTy &cls) {
+    cls.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType descriptorType =
+              mlirAMDGPUTDMDescriptorTypeGet(context.get()->get());
+          return TDMDescriptorType(context->getRef(), descriptorType);
+        },
+        "Gets an instance of TDMDescriptorType in the same context",
+        nb::arg("context").none() = nb::none());
+  }
+};
- auto amdgpuTDMGatherBaseType = mlir_type_subclass(
- m, "TDMGatherBaseType", mlirTypeIsAAMDGPUTDMGatherBaseType,
- mlirAMDGPUTDMGatherBaseTypeGetTypeID);
+// Python class "TDMGatherBaseType" for the AMDGPU dialect type of that name.
+struct TDMGatherBaseType : PyConcreteType<TDMGatherBaseType> {
+  static constexpr IsAFunctionTy isaFunction =
+      mlirTypeIsAAMDGPUTDMGatherBaseType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirAMDGPUTDMGatherBaseTypeGetTypeID;
+  static constexpr const char *pyClassName = "TDMGatherBaseType";
+  using PyConcreteType::PyConcreteType;
- amdgpuTDMGatherBaseType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirType elementType, MlirType indexType,
- MlirContext ctx) {
- return cls(mlirAMDGPUTDMGatherBaseTypeGet(ctx, elementType, indexType));
- },
- "Gets an instance of TDMGatherBaseType in the same context",
- nb::arg("cls"), nb::arg("element_type"), nb::arg("index_type"),
- nb::arg("ctx") = nb::none());
+  // Adds `get(element_type, index_type, context=None)`.
+  static void bindDerived(ClassTy &cls) {
+    cls.def_static(
+        "get",
+        [](const PyType &elementType, const PyType &indexType,
+           DefaultingPyMlirContext context) {
+          MlirType gatherType = mlirAMDGPUTDMGatherBaseTypeGet(
+              context.get()->get(), elementType, indexType);
+          return TDMGatherBaseType(context->getRef(), gatherType);
+        },
+        "Gets an instance of TDMGatherBaseType in the same context",
+        nb::arg("element_type"), nb::arg("index_type"),
+        nb::arg("context").none() = nb::none());
+  }
};
+// Adds the AMDGPU dialect's Python type classes to `mod`.
+static void populateDialectAMDGPUSubmodule(nb::module_ &mod) {
+  TDMBaseType::bind(mod);
+  TDMDescriptorType::bind(mod);
+  TDMGatherBaseType::bind(mod);
+}
+} // namespace amdgpu
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
+
NB_MODULE(_mlirDialectsAMDGPU, m) {
m.doc() = "MLIR AMDGPU dialect.";
- populateDialectAMDGPUSubmodule(m);
+  // Qualify via MLIR_BINDINGS_PYTHON_DOMAIN, the namespace the bindings are
+  // declared in. Hard-coding `mlir` only compiles when the macro expands to it.
+  mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::amdgpu::
+      populateDialectAMDGPUSubmodule(m);
}
diff --git a/mlir/lib/Bindings/Python/DialectGPU.cpp b/mlir/lib/Bindings/Python/DialectGPU.cpp
index 2568d535edb5a..ea3748cc88b85 100644
--- a/mlir/lib/Bindings/Python/DialectGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectGPU.cpp
@@ -9,83 +9,105 @@
#include "mlir-c/Dialect/GPU.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace nanobind::literals;
-
-using namespace mlir;
-using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace gpu {
// -----------------------------------------------------------------------------
-// Module initialization.
+// AsyncTokenType
// -----------------------------------------------------------------------------
-NB_MODULE(_mlirDialectsGPU, m) {
- m.doc() = "MLIR GPU Dialect";
- //===-------------------------------------------------------------------===//
- // AsyncTokenType
- //===-------------------------------------------------------------------===//
+// Python class "AsyncTokenType" for the GPU dialect's async token type
+// (isa check: mlirTypeIsAGPUAsyncTokenType).
+struct AsyncTokenType : PyConcreteType<AsyncTokenType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAGPUAsyncTokenType;
+  static constexpr const char *pyClassName = "AsyncTokenType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &cls) {
+    cls.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType tokenType = mlirGPUAsyncTokenTypeGet(context.get()->get());
+          return AsyncTokenType(context->getRef(), tokenType);
+        },
+        "Gets an instance of AsyncTokenType in the same context",
+        nb::arg("context").none() = nb::none());
+  }
+};
+
+//===-------------------------------------------------------------------===//
+// ObjectAttr
+//===-------------------------------------------------------------------===//
+
+// Python class "ObjectAttr" for the GPU dialect's object attribute. It has a
+// static `get` factory plus read-only target/format/object/properties/kernels
+// accessors.
+struct ObjectAttr : PyConcreteAttribute<ObjectAttr> {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAGPUObjectAttr;
+  static constexpr const char *pyClassName = "ObjectAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
- auto mlirGPUAsyncTokenType =
- mlir_type_subclass(m, "AsyncTokenType", mlirTypeIsAGPUAsyncTokenType);
+  static void bindDerived(ClassTy &c) {
+    // The attribute is always created in `target`'s context, so the Python
+    // wrapper must reference that same context. A separately resolved
+    // `context` argument could disagree with it. Deriving it from `target`
+    // also matches the pre-port signature, which had no `context` parameter.
+    c.def_static(
+        "get",
+        [](MlirAttribute target, uint32_t format, const nb::bytes &object,
+           std::optional<MlirAttribute> mlirObjectProps,
+           std::optional<MlirAttribute> mlirKernelsAttr) {
+          MlirContext mlirContext = mlirAttributeGetContext(target);
+          MlirStringRef objectStrRef = mlirStringRefCreate(
+              static_cast<const char *>(object.data()), object.size());
+          return ObjectAttr(
+              PyMlirContext::forContext(mlirContext),
+              mlirGPUObjectAttrGetWithKernels(
+                  mlirContext, target, format, objectStrRef,
+                  mlirObjectProps.value_or(MlirAttribute{nullptr}),
+                  mlirKernelsAttr.value_or(MlirAttribute{nullptr})));
+        },
+        "target"_a, "format"_a, "object"_a, "properties"_a = nb::none(),
+        "kernels"_a = nb::none(), "Gets a gpu.object from parameters.");
- mlirGPUAsyncTokenType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirGPUAsyncTokenTypeGet(ctx));
- },
- "Gets an instance of AsyncTokenType in the same context", nb::arg("cls"),
- nb::arg("ctx") = nb::none());
+    c.def_prop_ro("target", [](MlirAttribute self) {
+      return mlirGPUObjectAttrGetTarget(self);
+    });
+    c.def_prop_ro("format", [](MlirAttribute self) {
+      return mlirGPUObjectAttrGetFormat(self);
+    });
+    c.def_prop_ro("object", [](MlirAttribute self) {
+      MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
+      return nb::bytes(stringRef.data, stringRef.length);
+    });
+    // Optional components are surfaced as None when absent.
+    c.def_prop_ro("properties", [](MlirAttribute self) -> nb::object {
+      if (mlirGPUObjectAttrHasProperties(self))
+        return nb::cast(mlirGPUObjectAttrGetProperties(self));
+      return nb::none();
+    });
+    c.def_prop_ro("kernels", [](MlirAttribute self) -> nb::object {
+      if (mlirGPUObjectAttrHasKernels(self))
+        return nb::cast(mlirGPUObjectAttrGetKernels(self));
+      return nb::none();
+    });
+  }
+};
+} // namespace gpu
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
- //===-------------------------------------------------------------------===//
- // ObjectAttr
- //===-------------------------------------------------------------------===//
+// -----------------------------------------------------------------------------
+// Module initialization.
+// -----------------------------------------------------------------------------
+
+NB_MODULE(_mlirDialectsGPU, m) {
+  m.doc() = "MLIR GPU Dialect";
- mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr)
- .def_classmethod(
- "get",
- [](const nb::object &cls, MlirAttribute target, uint32_t format,
- const nb::bytes &object,
- std::optional<MlirAttribute> mlirObjectProps,
- std::optional<MlirAttribute> mlirKernelsAttr) {
- MlirStringRef objectStrRef = mlirStringRefCreate(
- static_cast<char *>(const_cast<void *>(object.data())),
- object.size());
- return cls(mlirGPUObjectAttrGetWithKernels(
- mlirAttributeGetContext(target), target, format, objectStrRef,
- mlirObjectProps.has_value() ? *mlirObjectProps
- : MlirAttribute{nullptr},
- mlirKernelsAttr.has_value() ? *mlirKernelsAttr
- : MlirAttribute{nullptr}));
- },
- "cls"_a, "target"_a, "format"_a, "object"_a,
- "properties"_a = nb::none(), "kernels"_a = nb::none(),
- "Gets a gpu.object from parameters.")
- .def_property_readonly(
- "target",
- [](MlirAttribute self) { return mlirGPUObjectAttrGetTarget(self); })
- .def_property_readonly(
- "format",
- [](MlirAttribute self) { return mlirGPUObjectAttrGetFormat(self); })
- .def_property_readonly(
- "object",
- [](MlirAttribute self) {
- MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
- return nb::bytes(stringRef.data, stringRef.length);
- })
- .def_property_readonly("properties",
- [](MlirAttribute self) -> nb::object {
- if (mlirGPUObjectAttrHasProperties(self))
- return nb::cast(
- mlirGPUObjectAttrGetProperties(self));
- return nb::none();
- })
- .def_property_readonly("kernels", [](MlirAttribute self) -> nb::object {
- if (mlirGPUObjectAttrHasKernels(self))
- return nb::cast(mlirGPUObjectAttrGetKernels(self));
- return nb::none();
- });
+  // Qualify via MLIR_BINDINGS_PYTHON_DOMAIN, the namespace these classes are
+  // declared in. Hard-coding `mlir` only compiles when the macro expands to it.
+  mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::gpu::AsyncTokenType::bind(m);
+  mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::gpu::ObjectAttr::bind(m);
}
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index 05681cecf82b3..d4eb078c0f55c 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -13,149 +13,176 @@
#include "mlir-c/Support.h"
#include "mlir-c/Target/LLVMIR.h"
#include "mlir/Bindings/Python/Diagnostics.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace nanobind::literals;
-
using namespace llvm;
using namespace mlir;
-using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
-static void populateDialectLLVMSubmodule(nanobind::module_ &m) {
-
- //===--------------------------------------------------------------------===//
- // StructType
- //===--------------------------------------------------------------------===//
-
- auto llvmStructType = mlir_type_subclass(
- m, "StructType", mlirTypeIsALLVMStructType, mlirLLVMStructTypeGetTypeID);
-
- llvmStructType
- .def_classmethod(
- "get_literal",
- [](const nb::object &cls, const std::vector<MlirType> &elements,
- bool packed, MlirLocation loc) {
- CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
-
- MlirType type = mlirLLVMStructTypeLiteralGetChecked(
- loc, elements.size(), elements.data(), packed);
- if (mlirTypeIsNull(type)) {
- throw nb::value_error(scope.takeMessage().c_str());
- }
- return cls(type);
- },
- "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
- "loc"_a = nb::none())
- .def_classmethod(
- "get_literal_unchecked",
- [](const nb::object &cls, const std::vector<MlirType> &elements,
- bool packed, MlirContext context) {
- CollectDiagnosticsToStringScope scope(context);
-
- MlirType type = mlirLLVMStructTypeLiteralGet(
- context, elements.size(), elements.data(), packed);
- if (mlirTypeIsNull(type)) {
- throw nb::value_error(scope.takeMessage().c_str());
- }
- return cls(type);
- },
- "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
- "context"_a = nb::none());
-
- llvmStructType.def_classmethod(
- "get_identified",
- [](const nb::object &cls, const std::string &name, MlirContext context) {
- return cls(mlirLLVMStructTypeIdentifiedGet(
- context, mlirStringRefCreate(name.data(), name.size())));
- },
- "cls"_a, "name"_a, nb::kw_only(), "context"_a = nb::none());
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace llvm {
+//===--------------------------------------------------------------------===//
+// StructType
+//===--------------------------------------------------------------------===//
+
+// Python class "StructType" for !llvm.struct. It offers factories for the
+// literal, identified and opaque forms, a `set_body` mutator, and
+// name/body/packed/opaque accessors.
+struct StructType : PyConcreteType<StructType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsALLVMStructType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirLLVMStructTypeGetTypeID;
+  static constexpr const char *pyClassName = "StructType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    // The struct is created in `loc`'s context, so the wrapper takes that same
+    // context. A separately resolved default context need not match it. This
+    // keeps the pre-port signature (no `context` parameter).
+    c.def_static(
+        "get_literal",
+        [](const std::vector<MlirType> &elements, bool packed,
+           MlirLocation loc) {
+          MlirContext mlirContext = mlirLocationGetContext(loc);
+          python::CollectDiagnosticsToStringScope scope(mlirContext);
+
+          MlirType type = mlirLLVMStructTypeLiteralGetChecked(
+              loc, elements.size(), elements.data(), packed);
+          if (mlirTypeIsNull(type)) {
+            throw nb::value_error(scope.takeMessage().c_str());
+          }
+          return StructType(PyMlirContext::forContext(mlirContext), type);
+        },
+        "elements"_a, nb::kw_only(), "packed"_a = false, "loc"_a = nb::none());
+
+    c.def_static(
+        "get_literal_unchecked",
+        [](const std::vector<MlirType> &elements, bool packed,
+           DefaultingPyMlirContext context) {
+          python::CollectDiagnosticsToStringScope scope(context.get()->get());
+
+          MlirType type = mlirLLVMStructTypeLiteralGet(
+              context.get()->get(), elements.size(), elements.data(), packed);
+          if (mlirTypeIsNull(type)) {
+            throw nb::value_error(scope.takeMessage().c_str());
+          }
+          return StructType(context->getRef(), type);
+        },
+        "elements"_a, nb::kw_only(), "packed"_a = false,
+        "context"_a = nb::none());
+
+    c.def_static(
+        "get_identified",
+        [](const std::string &name, DefaultingPyMlirContext context) {
+          return StructType(context->getRef(),
+                            mlirLLVMStructTypeIdentifiedGet(
+                                context.get()->get(),
+                                mlirStringRefCreate(name.data(), name.size())));
+        },
+        "name"_a, nb::kw_only(), "context"_a = nb::none());
+
+    c.def_static(
+        "get_opaque",
+        [](const std::string &name, DefaultingPyMlirContext context) {
+          return StructType(context->getRef(),
+                            mlirLLVMStructTypeOpaqueGet(
+                                context.get()->get(),
+                                mlirStringRefCreate(name.data(), name.size())));
+        },
+        "name"_a, "context"_a = nb::none());
+
+    c.def(
+        "set_body",
+        [](MlirType self, const std::vector<MlirType> &elements, bool packed) {
+          MlirLogicalResult result = mlirLLVMStructTypeSetBody(
+              self, elements.size(), elements.data(), packed);
+          if (!mlirLogicalResultIsSuccess(result)) {
+            throw nb::value_error(
+                "Struct body already set to different content.");
+          }
+        },
+        "elements"_a, nb::kw_only(), "packed"_a = false);
+
+    c.def_static(
+        "new_identified",
+        [](const std::string &name, const std::vector<MlirType> &elements,
+           bool packed, DefaultingPyMlirContext context) {
+          return StructType(context->getRef(),
+                            mlirLLVMStructTypeIdentifiedNewGet(
+                                context.get()->get(),
+                                mlirStringRefCreate(name.data(), name.length()),
+                                elements.size(), elements.data(), packed));
+        },
+        "name"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
+        "context"_a = nb::none());
+
+    c.def_prop_ro("name", [](PyType type) -> std::optional<std::string> {
+      // Literal structs have no identifier; report None for them.
+      if (mlirLLVMStructTypeIsLiteral(type))
+        return std::nullopt;
+
+      MlirStringRef stringRef = mlirLLVMStructTypeGetIdentifier(type);
+      return StringRef(stringRef.data, stringRef.length).str();
+    });
+
+    c.def_prop_ro("body", [](PyType type) -> nb::object {
+      // Don't crash in absence of a body.
+      if (mlirLLVMStructTypeIsOpaque(type))
+        return nb::none();
+
+      nb::list body;
+      for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type);
+           i < e; ++i) {
+        body.append(mlirLLVMStructTypeGetElementType(type, i));
+      }
+      return body;
+    });
+
+    c.def_prop_ro("packed",
+                  [](PyType type) { return mlirLLVMStructTypeIsPacked(type); });
+
+    c.def_prop_ro("opaque",
+                  [](PyType type) { return mlirLLVMStructTypeIsOpaque(type); });
+  }
+};
+
+//===--------------------------------------------------------------------===//
+// PointerType
+//===--------------------------------------------------------------------===//
+
+struct PointerType : PyConcreteType<PointerType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsALLVMPointerType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirLLVMPointerTypeGetTypeID;
+  static constexpr const char *pyClassName = "PointerType";
+  using PyConcreteType::PyConcreteType;
+
+  // Adds `get(address_space=None, *, context=None)`, where an omitted address
+  // space means 0, and the read-only `address_space` property.
+  static void bindDerived(ClassTy &cls) {
+    cls.def_static(
+        "get",
+        [](std::optional<unsigned> addressSpace,
+           DefaultingPyMlirContext context) {
+          MlirContext mlirCtx = context.get()->get();
+          python::CollectDiagnosticsToStringScope diagnostics(mlirCtx);
+          MlirType ptrType =
+              mlirLLVMPointerTypeGet(mlirCtx, addressSpace.value_or(0));
+          if (!mlirTypeIsNull(ptrType))
+            return PointerType(context->getRef(), ptrType);
+          throw nb::value_error(diagnostics.takeMessage().c_str());
+        },
+        "address_space"_a = nb::none(), nb::kw_only(),
+        "context"_a = nb::none());
+    cls.def_prop_ro("address_space", [](PyType type) {
+      return mlirLLVMPointerTypeGetAddressSpace(type);
+    });
+  }
+};
- llvmStructType.def_classmethod(
- "get_opaque",
- [](const nb::object &cls, const std::string &name, MlirContext context) {
- return cls(mlirLLVMStructTypeOpaqueGet(
- context, mlirStringRefCreate(name.data(), name.size())));
- },
- "cls"_a, "name"_a, "context"_a = nb::none());
-
- llvmStructType.def(
- "set_body",
- [](MlirType self, const std::vector<MlirType> &elements, bool packed) {
- MlirLogicalResult result = mlirLLVMStructTypeSetBody(
- self, elements.size(), elements.data(), packed);
- if (!mlirLogicalResultIsSuccess(result)) {
- throw nb::value_error(
- "Struct body already set to different content.");
- }
- },
- "elements"_a, nb::kw_only(), "packed"_a = false);
-
- llvmStructType.def_classmethod(
- "new_identified",
- [](const nb::object &cls, const std::string &name,
- const std::vector<MlirType> &elements, bool packed, MlirContext ctx) {
- return cls(mlirLLVMStructTypeIdentifiedNewGet(
- ctx, mlirStringRefCreate(name.data(), name.length()),
- elements.size(), elements.data(), packed));
- },
- "cls"_a, "name"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
- "context"_a = nb::none());
-
- llvmStructType.def_property_readonly(
- "name", [](MlirType type) -> std::optional<std::string> {
- if (mlirLLVMStructTypeIsLiteral(type))
- return std::nullopt;
-
- MlirStringRef stringRef = mlirLLVMStructTypeGetIdentifier(type);
- return StringRef(stringRef.data, stringRef.length).str();
- });
-
- llvmStructType.def_property_readonly("body", [](MlirType type) -> nb::object {
- // Don't crash in absence of a body.
- if (mlirLLVMStructTypeIsOpaque(type))
- return nb::none();
-
- nb::list body;
- for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type); i < e;
- ++i) {
- body.append(mlirLLVMStructTypeGetElementType(type, i));
- }
- return body;
- });
-
- llvmStructType.def_property_readonly(
- "packed", [](MlirType type) { return mlirLLVMStructTypeIsPacked(type); });
-
- llvmStructType.def_property_readonly(
- "opaque", [](MlirType type) { return mlirLLVMStructTypeIsOpaque(type); });
-
- //===--------------------------------------------------------------------===//
- // PointerType
- //===--------------------------------------------------------------------===//
-
- mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType,
- mlirLLVMPointerTypeGetTypeID)
- .def_classmethod(
- "get",
- [](const nb::object &cls, std::optional<unsigned> addressSpace,
- MlirContext context) {
- CollectDiagnosticsToStringScope scope(context);
- MlirType type = mlirLLVMPointerTypeGet(
- context, addressSpace.has_value() ? *addressSpace : 0);
- if (mlirTypeIsNull(type)) {
- throw nb::value_error(scope.takeMessage().c_str());
- }
- return cls(type);
- },
- "cls"_a, "address_space"_a = nb::none(), nb::kw_only(),
- "context"_a = nb::none())
- .def_property_readonly("address_space", [](MlirType type) {
- return mlirLLVMPointerTypeGetAddressSpace(type);
- });
+static void populateDialectLLVMSubmodule(nanobind::module_ &m) {
+ StructType::bind(m);
+ PointerType::bind(m);
m.def(
"translate_module_to_llvmir",
@@ -167,9 +194,13 @@ static void populateDialectLLVMSubmodule(nanobind::module_ &m) {
// clang-format on
"module"_a, nb::rv_policy::take_ownership);
}
+} // namespace llvm
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
NB_MODULE(_mlirDialectsLLVM, m) {
m.doc() = "MLIR LLVM Dialect";
- populateDialectLLVMSubmodule(m);
+  // Qualify via MLIR_BINDINGS_PYTHON_DOMAIN instead of hard-coding `mlir`.
+  python::MLIR_BINDINGS_PYTHON_DOMAIN::llvm::populateDialectLLVMSubmodule(m);
}
diff --git a/mlir/lib/Bindings/Python/DialectNVGPU.cpp b/mlir/lib/Bindings/Python/DialectNVGPU.cpp
index 18917416412c1..6387d430abbf5 100644
--- a/mlir/lib/Bindings/Python/DialectNVGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectNVGPU.cpp
@@ -8,34 +8,47 @@
#include "mlir-c/Dialect/NVGPU.h"
#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace llvm;
-using namespace mlir;
-using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
-static void populateDialectNVGPUSubmodule(const nb::module_ &m) {
- auto nvgpuTensorMapDescriptorType = mlir_type_subclass(
- m, "TensorMapDescriptorType", mlirTypeIsANVGPUTensorMapDescriptorType);
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace nvgpu {
+struct TensorMapDescriptorType : PyConcreteType<TensorMapDescriptorType> {
+ static constexpr IsAFunctionTy isaFunction =
+ mlirTypeIsANVGPUTensorMapDescriptorType;
+ static constexpr const char *pyClassName = "TensorMapDescriptorType";
+ using PyConcreteType::PyConcreteType;
- nvgpuTensorMapDescriptorType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirType tensorMemrefType, int swizzle,
- int l2promo, int oobFill, int interleave, MlirContext ctx) {
- return cls(mlirNVGPUTensorMapDescriptorTypeGet(
- ctx, tensorMemrefType, swizzle, l2promo, oobFill, interleave));
- },
- "Gets an instance of TensorMapDescriptorType in the same context",
- nb::arg("cls"), nb::arg("tensor_type"), nb::arg("swizzle"),
- nb::arg("l2promo"), nb::arg("oob_fill"), nb::arg("interleave"),
- nb::arg("ctx") = nb::none());
-}
+  /// Binds the static factory `TensorMapDescriptorType.get`. The descriptor
+  /// type is created in the context obtained from `context` (default None).
+  // NOTE(review): the removed binding named this keyword `ctx`; Python
+  // callers that passed `ctx=` need to switch to `context=`.
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](const PyType &tensorMemrefType, int swizzle, int l2promo,
+           int oobFill, int interleave, DefaultingPyMlirContext context) {
+          MlirContext mlirContext = context.get()->get();
+          MlirType descriptorType = mlirNVGPUTensorMapDescriptorTypeGet(
+              mlirContext, tensorMemrefType, swizzle, l2promo, oobFill,
+              interleave);
+          return TensorMapDescriptorType(context->getRef(), descriptorType);
+        },
+        "Gets an instance of TensorMapDescriptorType in the same context",
+        nb::arg("tensor_type"), nb::arg("swizzle"), nb::arg("l2promo"),
+        nb::arg("oob_fill"), nb::arg("interleave"),
+        nb::arg("context").none() = nb::none());
+  }
+};
+} // namespace nvgpu
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
NB_MODULE(_mlirDialectsNVGPU, m) {
m.doc() = "MLIR NVGPU dialect.";
- populateDialectNVGPUSubmodule(m);
+  // Name the binding namespace via MLIR_BINDINGS_PYTHON_DOMAIN (the macro it
+  // was opened with) instead of hard-coding `mlir`, so a non-default domain
+  // still resolves.
+  mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::nvgpu::TensorMapDescriptorType::
+      bind(m);
}
diff --git a/mlir/lib/Bindings/Python/DialectPDL.cpp b/mlir/lib/Bindings/Python/DialectPDL.cpp
index 1acb41080f711..d2ed3b141d724 100644
--- a/mlir/lib/Bindings/Python/DialectPDL.cpp
+++ b/mlir/lib/Bindings/Python/DialectPDL.cpp
@@ -8,98 +8,160 @@
#include "mlir-c/Dialect/PDL.h"
#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace llvm;
-using namespace mlir;
-using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
-static void populateDialectPDLSubmodule(const nanobind::module_ &m) {
- //===-------------------------------------------------------------------===//
- // PDLType
- //===-------------------------------------------------------------------===//
-
- auto pdlType = mlir_type_subclass(m, "PDLType", mlirTypeIsAPDLType);
-
- //===-------------------------------------------------------------------===//
- // AttributeType
- //===-------------------------------------------------------------------===//
-
- auto attributeType =
- mlir_type_subclass(m, "AttributeType", mlirTypeIsAPDLAttributeType);
- attributeType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirPDLAttributeTypeGet(ctx));
- },
- "Get an instance of AttributeType in given context.", nb::arg("cls"),
- nb::arg("context") = nb::none());
-
- //===-------------------------------------------------------------------===//
- // OperationType
- //===-------------------------------------------------------------------===//
-
- auto operationType =
- mlir_type_subclass(m, "OperationType", mlirTypeIsAPDLOperationType);
- operationType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirPDLOperationTypeGet(ctx));
- },
- "Get an instance of OperationType in given context.", nb::arg("cls"),
- nb::arg("context") = nb::none());
-
- //===-------------------------------------------------------------------===//
- // RangeType
- //===-------------------------------------------------------------------===//
-
- auto rangeType = mlir_type_subclass(m, "RangeType", mlirTypeIsAPDLRangeType);
- rangeType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirType elementType) {
- return cls(mlirPDLRangeTypeGet(elementType));
- },
- "Gets an instance of RangeType in the same context as the provided "
- "element type.",
- nb::arg("cls"), nb::arg("element_type"));
- rangeType.def_property_readonly(
- "element_type",
- [](MlirType type) { return mlirPDLRangeTypeGetElementType(type); },
- nb::sig(
- "def element_type(self) -> " MAKE_MLIR_PYTHON_QUALNAME("ir.Type")),
- "Get the element type.");
-
- //===-------------------------------------------------------------------===//
- // TypeType
- //===-------------------------------------------------------------------===//
-
- auto typeType = mlir_type_subclass(m, "TypeType", mlirTypeIsAPDLTypeType);
- typeType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirPDLTypeTypeGet(ctx));
- },
- "Get an instance of TypeType in given context.", nb::arg("cls"),
- nb::arg("context") = nb::none());
-
- //===-------------------------------------------------------------------===//
- // ValueType
- //===-------------------------------------------------------------------===//
-
- auto valueType = mlir_type_subclass(m, "ValueType", mlirTypeIsAPDLValueType);
- valueType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirPDLValueTypeGet(ctx));
- },
- "Get an instance of TypeType in given context.", nb::arg("cls"),
- nb::arg("context") = nb::none());
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace pdl {
+
+//===-------------------------------------------------------------------===//
+// PDLType
+//===-------------------------------------------------------------------===//
+
+/// Python class `PDLType`, recognized via mlirTypeIsAPDLType. It adds no
+/// members of its own beyond those supplied by PyConcreteType.
+struct PDLType : PyConcreteType<PDLType> {
+  static constexpr const char *pyClassName = "PDLType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLType;
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy & /*c*/) {
+    // Intentionally empty: there is nothing PDLType-specific to expose.
+  }
+};
+
+//===-------------------------------------------------------------------===//
+// AttributeType
+//===-------------------------------------------------------------------===//
+
+/// Python class `AttributeType`, recognized via mlirTypeIsAPDLAttributeType
+/// and constructed through the static `get` factory.
+struct AttributeType : PyConcreteType<AttributeType> {
+  static constexpr const char *pyClassName = "AttributeType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLAttributeType;
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    // `context` defaults to None; DefaultingPyMlirContext is assumed to fall
+    // back to the ambient context in that case.
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType attrType = mlirPDLAttributeTypeGet(context.get()->get());
+          return AttributeType(context->getRef(), attrType);
+        },
+        "Get an instance of AttributeType in given context.",
+        nb::arg("context").none() = nb::none());
+  }
+};
+
+//===-------------------------------------------------------------------===//
+// OperationType
+//===-------------------------------------------------------------------===//
+
+/// Python class `OperationType`, recognized via mlirTypeIsAPDLOperationType
+/// and constructed through the static `get` factory.
+struct OperationType : PyConcreteType<OperationType> {
+  static constexpr const char *pyClassName = "OperationType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLOperationType;
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    // `context` defaults to None; DefaultingPyMlirContext is assumed to fall
+    // back to the ambient context in that case.
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType opType = mlirPDLOperationTypeGet(context.get()->get());
+          return OperationType(context->getRef(), opType);
+        },
+        "Get an instance of OperationType in given context.",
+        nb::arg("context").none() = nb::none());
+  }
+};
+
+//===-------------------------------------------------------------------===//
+// RangeType
+//===-------------------------------------------------------------------===//
+
+/// Python class `RangeType`: created from an element type via `get`, which
+/// is exposed back through the `element_type` property.
+struct RangeType : PyConcreteType<RangeType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLRangeType;
+  static constexpr const char *pyClassName = "RangeType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](PyType &elementType, const nb::object & /*context*/) {
+          // mlirPDLRangeTypeGet takes no context: the range type is created
+          // in `elementType`'s context, so the Python wrapper must hold that
+          // same context. Resolving a separate (ambient) context here would
+          // make an active context mandatory and could pair the type with a
+          // context that does not own it. `context` is still accepted so
+          // keyword callers keep working, but it is not consulted.
+          return RangeType(elementType.getContext(),
+                           mlirPDLRangeTypeGet(elementType));
+        },
+        "Gets an instance of RangeType in the same context as the provided "
+        "element type.",
+        nb::arg("element_type"), nb::arg("context").none() = nb::none());
+    c.def_prop_ro(
+        "element_type",
+        [](PyType &type) {
+          // Downcast so Python receives the most specific registered Type
+          // subclass rather than a plain ir.Type.
+          return PyType(type.getContext(),
+                        mlirPDLRangeTypeGetElementType(type))
+              .maybeDownCast();
+        },
+        nb::sig(
+            "def element_type(self) -> " MAKE_MLIR_PYTHON_QUALNAME("ir.Type")),
+        "Get the element type.");
+  }
+};
+
+//===-------------------------------------------------------------------===//
+// TypeType
+//===-------------------------------------------------------------------===//
+
+/// Python class `TypeType`, recognized via mlirTypeIsAPDLTypeType and
+/// constructed through the static `get` factory.
+struct TypeType : PyConcreteType<TypeType> {
+  static constexpr const char *pyClassName = "TypeType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLTypeType;
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    // `context` defaults to None; DefaultingPyMlirContext is assumed to fall
+    // back to the ambient context in that case.
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType pdlTypeType = mlirPDLTypeTypeGet(context.get()->get());
+          return TypeType(context->getRef(), pdlTypeType);
+        },
+        "Get an instance of TypeType in given context.",
+        nb::arg("context").none() = nb::none());
+  }
+};
+
+//===-------------------------------------------------------------------===//
+// ValueType
+//===-------------------------------------------------------------------===//
+
+/// Python class `ValueType`, recognized via mlirTypeIsAPDLValueType and
+/// constructed through the static `get` factory.
+struct ValueType : PyConcreteType<ValueType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLValueType;
+  static constexpr const char *pyClassName = "ValueType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          return ValueType(context->getRef(),
+                           mlirPDLValueTypeGet(context.get()->get()));
+        },
+        "Get an instance of ValueType in given context.",
+        nb::arg("context").none() = nb::none());
+  }
+};
+
+/// Calls `bind(m)` on each class in `Types`, left to right.
+template <typename... Types>
+static void bindPDLTypes(nanobind::module_ &m) {
+  (Types::bind(m), ...);
+}
+
+/// Registers every PDL dialect type class on the extension module `m`.
+static void populateDialectPDLSubmodule(nanobind::module_ &m) {
+  bindPDLTypes<PDLType, AttributeType, OperationType, RangeType, TypeType,
+               ValueType>(m);
}
+} // namespace pdl
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
NB_MODULE(_mlirDialectsPDL, m) {
m.doc() = "MLIR PDL dialect.";
- populateDialectPDLSubmodule(m);
+  // Populate the extension module with the PDL type classes.
+  namespace pdl_bindings = mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::pdl;
+  pdl_bindings::populateDialectPDLSubmodule(m);
}
diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp
index a5220fcc00604..0d60ef49f77fa 100644
--- a/mlir/lib/Bindings/Python/DialectQuant.cpp
+++ b/mlir/lib/Bindings/Python/DialectQuant.cpp
@@ -6,385 +6,484 @@
//
//===----------------------------------------------------------------------===//
-#include <cstdint>
#include <vector>
-#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/Dialect/Quant.h"
#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace llvm;
-using namespace mlir;
using namespace mlir::python::nanobind_adaptors;
-static void populateDialectQuantSubmodule(const nb::module_ &m) {
- //===-------------------------------------------------------------------===//
- // QuantizedType
- //===-------------------------------------------------------------------===//
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace quant {
+//===-------------------------------------------------------------------===//
+// QuantizedType
+//===-------------------------------------------------------------------===//
- auto quantizedType =
- mlir_type_subclass(m, "QuantizedType", mlirTypeIsAQuantizedType);
- quantizedType.def_staticmethod(
- "default_minimum_for_integer",
- [](bool isSigned, unsigned integralWidth) {
- return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned,
- integralWidth);
- },
- "Default minimum value for the integer with the specified signedness and "
- "bit width.",
- nb::arg("is_signed"), nb::arg("integral_width"));
- quantizedType.def_staticmethod(
- "default_maximum_for_integer",
- [](bool isSigned, unsigned integralWidth) {
- return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned,
- integralWidth);
- },
- "Default maximum value for the integer with the specified signedness and "
- "bit width.",
- nb::arg("is_signed"), nb::arg("integral_width"));
- quantizedType.def_property_readonly(
- "expressed_type",
- [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); },
- "Type expressed by this quantized type.");
- quantizedType.def_property_readonly(
- "flags", [](MlirType type) { return mlirQuantizedTypeGetFlags(type); },
- "Flags of this quantized type (named accessors should be preferred to "
- "this)");
- quantizedType.def_property_readonly(
- "is_signed",
- [](MlirType type) { return mlirQuantizedTypeIsSigned(type); },
- "Signedness of this quantized type.");
- quantizedType.def_property_readonly(
- "storage_type",
- [](MlirType type) { return mlirQuantizedTypeGetStorageType(type); },
- "Storage type backing this quantized type.");
- quantizedType.def_property_readonly(
- "storage_type_min",
- [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMin(type); },
- "The minimum value held by the storage type of this quantized type.");
- quantizedType.def_property_readonly(
- "storage_type_max",
- [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMax(type); },
- "The maximum value held by the storage type of this quantized type.");
- quantizedType.def_property_readonly(
- "storage_type_integral_width",
- [](MlirType type) {
- return mlirQuantizedTypeGetStorageTypeIntegralWidth(type);
- },
- "The bitwidth of the storage type of this quantized type.");
- quantizedType.def(
- "is_compatible_expressed_type",
- [](MlirType type, MlirType candidate) {
- return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate);
- },
- "Checks whether the candidate type can be expressed by this quantized "
- "type.",
- nb::arg("candidate"));
- quantizedType.def_property_readonly(
- "quantized_element_type",
- [](MlirType type) {
- return mlirQuantizedTypeGetQuantizedElementType(type);
- },
- "Element type of this quantized type expressed as quantized type.");
- quantizedType.def(
- "cast_from_storage_type",
- [](MlirType type, MlirType candidate) {
- MlirType castResult =
- mlirQuantizedTypeCastFromStorageType(type, candidate);
- if (!mlirTypeIsNull(castResult))
- return castResult;
- throw nb::type_error("Invalid cast.");
- },
- "Casts from a type based on the storage type of this quantized type to a "
- "corresponding type based on the quantized type. Raises TypeError if the "
- "cast is not valid.",
- nb::arg("candidate"));
- quantizedType.def_staticmethod(
- "cast_to_storage_type",
- [](MlirType type) {
- MlirType castResult = mlirQuantizedTypeCastToStorageType(type);
- if (!mlirTypeIsNull(castResult))
- return castResult;
- throw nb::type_error("Invalid cast.");
- },
- "Casts from a type based on a quantized type to a corresponding type "
- "based on the storage type of this quantized type. Raises TypeError if "
- "the cast is not valid.",
- nb::arg("type"));
- quantizedType.def(
- "cast_from_expressed_type",
- [](MlirType type, MlirType candidate) {
- MlirType castResult =
- mlirQuantizedTypeCastFromExpressedType(type, candidate);
- if (!mlirTypeIsNull(castResult))
- return castResult;
- throw nb::type_error("Invalid cast.");
- },
- "Casts from a type based on the expressed type of this quantized type to "
- "a corresponding type based on the quantized type. Raises TypeError if "
- "the cast is not valid.",
- nb::arg("candidate"));
- quantizedType.def_staticmethod(
- "cast_to_expressed_type",
- [](MlirType type) {
- MlirType castResult = mlirQuantizedTypeCastToExpressedType(type);
- if (!mlirTypeIsNull(castResult))
- return castResult;
- throw nb::type_error("Invalid cast.");
- },
- "Casts from a type based on a quantized type to a corresponding type "
- "based on the expressed type of this quantized type. Raises TypeError if "
- "the cast is not valid.",
- nb::arg("type"));
- quantizedType.def(
- "cast_expressed_to_storage_type",
- [](MlirType type, MlirType candidate) {
- MlirType castResult =
- mlirQuantizedTypeCastExpressedToStorageType(type, candidate);
- if (!mlirTypeIsNull(castResult))
- return castResult;
- throw nb::type_error("Invalid cast.");
- },
- "Casts from a type based on the expressed type of this quantized type to "
- "a corresponding type based on the storage type. Raises TypeError if the "
- "cast is not valid.",
- nb::arg("candidate"));
+struct QuantizedType : PyConcreteType<QuantizedType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAQuantizedType;
+ static constexpr const char *pyClassName = "QuantizedType";
+ using PyConcreteType::PyConcreteType;
- quantizedType.get_class().attr("FLAG_SIGNED") =
- mlirQuantizedTypeGetSignedFlag();
+  /// Binds the members shared by every quantized type: the FLAG_SIGNED class
+  /// constant, static default-range helpers, read-only accessors and the
+  /// storage/expressed-type casts.
+  static void bindDerived(ClassTy &c) {
+    // Keep QuantizedType.FLAG_SIGNED, which the mlir_type_subclass-based
+    // binding set on the class (presumably the signedness bit of `flags`).
+    // Assigning it again elsewhere with the same value would be harmless.
+    c.attr("FLAG_SIGNED") = mlirQuantizedTypeGetSignedFlag();
+
+    c.def_static(
+        "default_minimum_for_integer",
+        [](bool isSigned, unsigned integralWidth) {
+          return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned,
+                                                              integralWidth);
+        },
+        "Default minimum value for the integer with the specified signedness "
+        "and bit width.",
+        nb::arg("is_signed"), nb::arg("integral_width"));
+    c.def_static(
+        "default_maximum_for_integer",
+        [](bool isSigned, unsigned integralWidth) {
+          return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned,
+                                                              integralWidth);
+        },
+        "Default maximum value for the integer with the specified signedness "
+        "and bit width.",
+        nb::arg("is_signed"), nb::arg("integral_width"));
+
+    // Type-valued results go through maybeDownCast() so Python receives the
+    // most specific registered Type subclass instead of a plain ir.Type.
+    c.def_prop_ro(
+        "expressed_type",
+        [](PyType type) {
+          return PyType(type.getContext(),
+                        mlirQuantizedTypeGetExpressedType(type))
+              .maybeDownCast();
+        },
+        "Type expressed by this quantized type.");
+    c.def_prop_ro(
+        "flags",
+        [](const PyType &type) { return mlirQuantizedTypeGetFlags(type); },
+        "Flags of this quantized type (named accessors should be preferred to "
+        "this)");
+    c.def_prop_ro(
+        "is_signed",
+        [](const PyType &type) { return mlirQuantizedTypeIsSigned(type); },
+        "Signedness of this quantized type.");
+    c.def_prop_ro(
+        "storage_type",
+        [](PyType type) {
+          return PyType(type.getContext(),
+                        mlirQuantizedTypeGetStorageType(type))
+              .maybeDownCast();
+        },
+        "Storage type backing this quantized type.");
+    c.def_prop_ro(
+        "storage_type_min",
+        [](const PyType &type) {
+          return mlirQuantizedTypeGetStorageTypeMin(type);
+        },
+        "The minimum value held by the storage type of this quantized type.");
+    c.def_prop_ro(
+        "storage_type_max",
+        [](const PyType &type) {
+          return mlirQuantizedTypeGetStorageTypeMax(type);
+        },
+        "The maximum value held by the storage type of this quantized type.");
+    c.def_prop_ro(
+        "storage_type_integral_width",
+        [](const PyType &type) {
+          return mlirQuantizedTypeGetStorageTypeIntegralWidth(type);
+        },
+        "The bitwidth of the storage type of this quantized type.");
+    c.def(
+        "is_compatible_expressed_type",
+        [](const PyType &type, const PyType &candidate) {
+          return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate);
+        },
+        "Checks whether the candidate type can be expressed by this quantized "
+        "type.",
+        nb::arg("candidate"));
+    c.def_prop_ro(
+        "quantized_element_type",
+        [](PyType type) {
+          return PyType(type.getContext(),
+                        mlirQuantizedTypeGetQuantizedElementType(type))
+              .maybeDownCast();
+        },
+        "Element type of this quantized type expressed as quantized type.");
+    c.def(
+        "cast_from_storage_type",
+        [](PyType type, const PyType &candidate) {
+          MlirType castResult =
+              mlirQuantizedTypeCastFromStorageType(type, candidate);
+          // The result is "a corresponding type based on the quantized type"
+          // (e.g. a shaped type when `candidate` is shaped), so it is not
+          // necessarily a QuantizedType; wrapping it as one would mislabel it.
+          if (!mlirTypeIsNull(castResult))
+            return PyType(type.getContext(), castResult).maybeDownCast();
+          throw nb::type_error("Invalid cast.");
+        },
+        "Casts from a type based on the storage type of this quantized type to "
+        "a corresponding type based on the quantized type. Raises TypeError if "
+        "the cast is not valid.",
+        nb::arg("candidate"));
+    c.def_static(
+        "cast_to_storage_type",
+        [](const PyType &type) {
+          MlirType castResult = mlirQuantizedTypeCastToStorageType(type);
+          if (!mlirTypeIsNull(castResult))
+            return castResult;
+          throw nb::type_error("Invalid cast.");
+        },
+        "Casts from a type based on a quantized type to a corresponding type "
+        "based on the storage type of this quantized type. Raises TypeError if "
+        "the cast is not valid.",
+        nb::arg("type"));
+    c.def(
+        "cast_from_expressed_type",
+        [](PyType type, const PyType &candidate) {
+          MlirType castResult =
+              mlirQuantizedTypeCastFromExpressedType(type, candidate);
+          if (!mlirTypeIsNull(castResult))
+            return PyType(type.getContext(), castResult).maybeDownCast();
+          throw nb::type_error("Invalid cast.");
+        },
+        "Casts from a type based on the expressed type of this quantized type "
+        "to a corresponding type based on the quantized type. Raises TypeError "
+        "if the cast is not valid.",
+        nb::arg("candidate"));
+    c.def_static(
+        "cast_to_expressed_type",
+        [](const PyType &type) {
+          MlirType castResult = mlirQuantizedTypeCastToExpressedType(type);
+          if (!mlirTypeIsNull(castResult))
+            return castResult;
+          throw nb::type_error("Invalid cast.");
+        },
+        "Casts from a type based on a quantized type to a corresponding type "
+        "based on the expressed type of this quantized type. Raises TypeError "
+        "if the cast is not valid.",
+        nb::arg("type"));
+    c.def(
+        "cast_expressed_to_storage_type",
+        [](PyType type, const PyType &candidate) {
+          MlirType castResult =
+              mlirQuantizedTypeCastExpressedToStorageType(type, candidate);
+          if (!mlirTypeIsNull(castResult))
+            return PyType(type.getContext(), castResult).maybeDownCast();
+          throw nb::type_error("Invalid cast.");
+        },
+        "Casts from a type based on the expressed type of this quantized type "
+        "to a corresponding type based on the storage type. Raises TypeError "
+        "if the cast is not valid.",
+        nb::arg("candidate"));
+  }
+};
- //===-------------------------------------------------------------------===//
- // AnyQuantizedType
- //===-------------------------------------------------------------------===//
+//===-------------------------------------------------------------------===//
+// AnyQuantizedType
+//===-------------------------------------------------------------------===//
- auto anyQuantizedType =
- mlir_type_subclass(m, "AnyQuantizedType", mlirTypeIsAAnyQuantizedType,
- quantizedType.get_class());
- anyQuantizedType.def_classmethod(
- "get",
- [](const nb::object &cls, unsigned flags, MlirType storageType,
- MlirType expressedType, int64_t storageTypeMin,
- int64_t storageTypeMax) {
- return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType,
- storageTypeMin, storageTypeMax));
- },
- "Gets an instance of AnyQuantizedType in the same context as the "
- "provided storage type.",
- nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
- nb::arg("expressed_type"), nb::arg("storage_type_min"),
- nb::arg("storage_type_max"));
+struct AnyQuantizedType : PyConcreteType<AnyQuantizedType, QuantizedType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAAnyQuantizedType;
+ static constexpr const char *pyClassName = "AnyQuantizedType";
+ using PyConcreteType::PyConcreteType;
- //===-------------------------------------------------------------------===//
- // UniformQuantizedType
- //===-------------------------------------------------------------------===//
+  /// Binds AnyQuantizedType.get, which builds the type from flags, the
+  /// storage and expressed types, and the storage value range (no scale or
+  /// zero point).
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](unsigned flags, PyType &storageType, const PyType &expressedType,
+           int64_t storageTypeMin, int64_t storageTypeMax,
+           const nb::object & /*context*/) {
+          // mlirAnyQuantizedTypeGet takes no context: the result lives in
+          // `storageType`'s context, so the wrapper must reference that same
+          // context. Resolving a separate (ambient) context would make an
+          // active context mandatory and could attach the wrong one.
+          // `context` is accepted only so keyword callers keep working.
+          return AnyQuantizedType(
+              storageType.getContext(),
+              mlirAnyQuantizedTypeGet(flags, storageType, expressedType,
+                                      storageTypeMin, storageTypeMax));
+        },
+        "Gets an instance of AnyQuantizedType in the same context as the "
+        "provided storage type.",
+        nb::arg("flags"), nb::arg("storage_type"), nb::arg("expressed_type"),
+        nb::arg("storage_type_min"), nb::arg("storage_type_max"),
+        nb::arg("context").none() = nb::none());
+  }
+};
- auto uniformQuantizedType = mlir_type_subclass(
- m, "UniformQuantizedType", mlirTypeIsAUniformQuantizedType,
- quantizedType.get_class());
- uniformQuantizedType.def_classmethod(
- "get",
- [](const nb::object &cls, unsigned flags, MlirType storageType,
- MlirType expressedType, double scale, int64_t zeroPoint,
- int64_t storageTypeMin, int64_t storageTypeMax) {
- return cls(mlirUniformQuantizedTypeGet(flags, storageType,
- expressedType, scale, zeroPoint,
- storageTypeMin, storageTypeMax));
- },
- "Gets an instance of UniformQuantizedType in the same context as the "
- "provided storage type.",
- nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
- nb::arg("expressed_type"), nb::arg("scale"), nb::arg("zero_point"),
- nb::arg("storage_type_min"), nb::arg("storage_type_max"));
- uniformQuantizedType.def_property_readonly(
- "scale",
- [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); },
- "The scale designates the difference between the real values "
- "corresponding to consecutive quantized values differing by 1.");
- uniformQuantizedType.def_property_readonly(
- "zero_point",
- [](MlirType type) { return mlirUniformQuantizedTypeGetZeroPoint(type); },
- "The storage value corresponding to the real value 0 in the affine "
- "equation.");
- uniformQuantizedType.def_property_readonly(
- "is_fixed_point",
- [](MlirType type) { return mlirUniformQuantizedTypeIsFixedPoint(type); },
- "Fixed point values are real numbers divided by a scale.");
+//===-------------------------------------------------------------------===//
+// UniformQuantizedType
+//===-------------------------------------------------------------------===//
- //===-------------------------------------------------------------------===//
- // UniformQuantizedPerAxisType
- //===-------------------------------------------------------------------===//
- auto uniformQuantizedPerAxisType = mlir_type_subclass(
- m, "UniformQuantizedPerAxisType", mlirTypeIsAUniformQuantizedPerAxisType,
- quantizedType.get_class());
- uniformQuantizedPerAxisType.def_classmethod(
- "get",
- [](const nb::object &cls, unsigned flags, MlirType storageType,
- MlirType expressedType, std::vector<double> scales,
- std::vector<int64_t> zeroPoints, int32_t quantizedDimension,
- int64_t storageTypeMin, int64_t storageTypeMax) {
- if (scales.size() != zeroPoints.size())
- throw nb::value_error(
- "Mismatching number of scales and zero points.");
- auto nDims = static_cast<intptr_t>(scales.size());
- return cls(mlirUniformQuantizedPerAxisTypeGet(
- flags, storageType, expressedType, nDims, scales.data(),
- zeroPoints.data(), quantizedDimension, storageTypeMin,
- storageTypeMax));
- },
- "Gets an instance of UniformQuantizedPerAxisType in the same context as "
- "the provided storage type.",
- nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
- nb::arg("expressed_type"), nb::arg("scales"), nb::arg("zero_points"),
- nb::arg("quantized_dimension"), nb::arg("storage_type_min"),
- nb::arg("storage_type_max"));
- uniformQuantizedPerAxisType.def_property_readonly(
- "scales",
- [](MlirType type) {
- intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
- std::vector<double> scales;
- scales.reserve(nDim);
- for (intptr_t i = 0; i < nDim; ++i) {
- double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i);
- scales.push_back(scale);
- }
- return scales;
- },
- "The scales designate the difference between the real values "
- "corresponding to consecutive quantized values differing by 1. The ith "
- "scale corresponds to the ith slice in the quantized_dimension.");
- uniformQuantizedPerAxisType.def_property_readonly(
- "zero_points",
- [](MlirType type) {
- intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
- std::vector<int64_t> zeroPoints;
- zeroPoints.reserve(nDim);
- for (intptr_t i = 0; i < nDim; ++i) {
- int64_t zeroPoint =
- mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i);
- zeroPoints.push_back(zeroPoint);
- }
- return zeroPoints;
- },
- "the storage values corresponding to the real value 0 in the affine "
- "equation. The ith zero point corresponds to the ith slice in the "
- "quantized_dimension.");
- uniformQuantizedPerAxisType.def_property_readonly(
- "quantized_dimension",
- [](MlirType type) {
- return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type);
- },
- "Specifies the dimension of the shape that the scales and zero points "
- "correspond to.");
- uniformQuantizedPerAxisType.def_property_readonly(
- "is_fixed_point",
- [](MlirType type) {
- return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type);
- },
- "Fixed point values are real numbers divided by a scale.");
+struct UniformQuantizedType
+ : PyConcreteType<UniformQuantizedType, QuantizedType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUniformQuantizedType;
+ static constexpr const char *pyClassName = "UniformQuantizedType";
+ using PyConcreteType::PyConcreteType;
- //===-------------------------------------------------------------------===//
- // UniformQuantizedSubChannelType
- //===-------------------------------------------------------------------===//
- auto uniformQuantizedSubChannelType = mlir_type_subclass(
- m, "UniformQuantizedSubChannelType",
- mlirTypeIsAUniformQuantizedSubChannelType, quantizedType.get_class());
- uniformQuantizedSubChannelType.def_classmethod(
- "get",
- [](const nb::object &cls, unsigned flags, MlirType storageType,
- MlirType expressedType, MlirAttribute scales, MlirAttribute zeroPoints,
- std::vector<int32_t> quantizedDimensions,
- std::vector<int64_t> blockSizes, int64_t storageTypeMin,
- int64_t storageTypeMax) {
- return cls(mlirUniformQuantizedSubChannelTypeGet(
- flags, storageType, expressedType, scales, zeroPoints,
- static_cast<intptr_t>(blockSizes.size()),
- quantizedDimensions.data(), blockSizes.data(), storageTypeMin,
- storageTypeMax));
- },
- "Gets an instance of UniformQuantizedSubChannel in the same context as "
- "the provided storage type.",
- nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
- nb::arg("expressed_type"), nb::arg("scales"), nb::arg("zero_points"),
- nb::arg("quantized_dimensions"), nb::arg("block_sizes"),
- nb::arg("storage_type_min"), nb::arg("storage_type_max"));
- uniformQuantizedSubChannelType.def_property_readonly(
- "quantized_dimensions",
- [](MlirType type) {
- intptr_t nDim =
- mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
- std::vector<int32_t> quantizedDimensions;
- quantizedDimensions.reserve(nDim);
- for (intptr_t i = 0; i < nDim; ++i) {
- quantizedDimensions.push_back(
- mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(type, i));
- }
- return quantizedDimensions;
- },
- "Gets the quantized dimensions. Each element in the returned list "
- "represents an axis of the quantized data tensor that has a specified "
- "block size. The order of elements corresponds to the order of block "
- "sizes returned by 'block_sizes' method. It means that the data tensor "
- "is quantized along the i-th dimension in the returned list using the "
- "i-th block size from block_sizes method.");
- uniformQuantizedSubChannelType.def_property_readonly(
- "block_sizes",
- [](MlirType type) {
- intptr_t nDim =
- mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
- std::vector<int64_t> blockSizes;
- blockSizes.reserve(nDim);
- for (intptr_t i = 0; i < nDim; ++i) {
- blockSizes.push_back(
- mlirUniformQuantizedSubChannelTypeGetBlockSize(type, i));
- }
- return blockSizes;
- },
- "Gets the block sizes for the quantized dimensions. The i-th element in "
- "the returned list corresponds to the block size for the i-th dimension "
- "in the list returned by quantized_dimensions method.");
- uniformQuantizedSubChannelType.def_property_readonly(
- "scales",
- [](MlirType type) -> MlirAttribute {
- return mlirUniformQuantizedSubChannelTypeGetScales(type);
- },
- "The scales of the quantized type.");
- uniformQuantizedSubChannelType.def_property_readonly(
- "zero_points",
- [](MlirType type) -> MlirAttribute {
- return mlirUniformQuantizedSubChannelTypeGetZeroPoints(type);
- },
- "The zero points of the quantized type.");
+ /// Binds `get` plus the `scale`, `zero_point` and `is_fixed_point`
+ /// read-only properties.
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](unsigned flags, const PyType &storageType,
+ const PyType &expressedType, double scale, int64_t zeroPoint,
+ int64_t storageTypeMin, int64_t storageTypeMax,
+ DefaultingPyMlirContext /*context*/) {
+ MlirType t = mlirUniformQuantizedTypeGet(
+ flags, storageType, expressedType, scale, zeroPoint,
+ storageTypeMin, storageTypeMax);
+ // The new type lives in the storage type's context, which need not
+ // be the ambient/explicit `context`. Bind the wrapper to the context
+ // that actually owns `t` so the Python object reports (and keeps
+ // alive) the right MlirContext.
+ return UniformQuantizedType(
+ PyMlirContext::forContext(mlirTypeGetContext(t)), t);
+ },
+ "Gets an instance of UniformQuantizedType in the same context as the "
+ "provided storage type.",
+ nb::arg("flags"), nb::arg("storage_type"), nb::arg("expressed_type"),
+ nb::arg("scale"), nb::arg("zero_point"), nb::arg("storage_type_min"),
+ nb::arg("storage_type_max"), nb::arg("context") = nb::none());
+ c.def_prop_ro(
+ "scale",
+ [](const PyType &type) {
+ return mlirUniformQuantizedTypeGetScale(type);
+ },
+ "The scale designates the difference between the real values "
+ "corresponding to consecutive quantized values differing by 1.");
+ c.def_prop_ro(
+ "zero_point",
+ [](const PyType &type) {
+ return mlirUniformQuantizedTypeGetZeroPoint(type);
+ },
+ "The storage value corresponding to the real value 0 in the affine "
+ "equation.");
+ c.def_prop_ro(
+ "is_fixed_point",
+ [](const PyType &type) {
+ return mlirUniformQuantizedTypeIsFixedPoint(type);
+ },
+ "Fixed point values are real numbers divided by a scale.");
+ }
+};
- //===-------------------------------------------------------------------===//
- // CalibratedQuantizedType
- //===-------------------------------------------------------------------===//
+//===-------------------------------------------------------------------===//
+// UniformQuantizedPerAxisType
+//===-------------------------------------------------------------------===//
- auto calibratedQuantizedType = mlir_type_subclass(
- m, "CalibratedQuantizedType", mlirTypeIsACalibratedQuantizedType,
- quantizedType.get_class());
- calibratedQuantizedType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirType expressedType, double min,
- double max) {
- return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max));
- },
- "Gets an instance of CalibratedQuantizedType in the same context as the "
- "provided expressed type.",
- nb::arg("cls"), nb::arg("expressed_type"), nb::arg("min"),
- nb::arg("max"));
- calibratedQuantizedType.def_property_readonly("min", [](MlirType type) {
- return mlirCalibratedQuantizedTypeGetMin(type);
- });
- calibratedQuantizedType.def_property_readonly("max", [](MlirType type) {
- return mlirCalibratedQuantizedTypeGetMax(type);
- });
+/// Python binding for `UniformQuantizedPerAxisType`: a quantized type with one
+/// scale / zero point pair per slice along `quantized_dimension`.
+struct UniformQuantizedPerAxisType
+ : PyConcreteType<UniformQuantizedPerAxisType, QuantizedType> {
+ static constexpr IsAFunctionTy isaFunction =
+ mlirTypeIsAUniformQuantizedPerAxisType;
+ static constexpr const char *pyClassName = "UniformQuantizedPerAxisType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](unsigned flags, const PyType &storageType,
+ const PyType &expressedType, std::vector<double> scales,
+ std::vector<int64_t> zeroPoints, int32_t quantizedDimension,
+ int64_t storageTypeMin, int64_t storageTypeMax,
+ DefaultingPyMlirContext /*context*/) {
+ // A single count (`nDims`) is passed for both arrays below.
+ if (scales.size() != zeroPoints.size())
+ throw nb::value_error(
+ "Mismatching number of scales and zero points.");
+ auto nDims = static_cast<intptr_t>(scales.size());
+ MlirType t = mlirUniformQuantizedPerAxisTypeGet(
+ flags, storageType, expressedType, nDims, scales.data(),
+ zeroPoints.data(), quantizedDimension, storageTypeMin,
+ storageTypeMax);
+ // The new type lives in the storage type's context; bind the wrapper
+ // to the context that owns `t` rather than the ambient one.
+ return UniformQuantizedPerAxisType(
+ PyMlirContext::forContext(mlirTypeGetContext(t)), t);
+ },
+ "Gets an instance of UniformQuantizedPerAxisType in the same context "
+ "as the provided storage type.",
+ nb::arg("flags"), nb::arg("storage_type"), nb::arg("expressed_type"),
+ nb::arg("scales"), nb::arg("zero_points"),
+ nb::arg("quantized_dimension"), nb::arg("storage_type_min"),
+ nb::arg("storage_type_max"), nb::arg("context") = nb::none());
+ c.def_prop_ro(
+ "scales",
+ [](const PyType &type) {
+ intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
+ std::vector<double> scales;
+ scales.reserve(nDim);
+ for (intptr_t i = 0; i < nDim; ++i)
+ scales.push_back(mlirUniformQuantizedPerAxisTypeGetScale(type, i));
+ return scales;
+ },
+ "The scales designate the difference between the real values "
+ "corresponding to consecutive quantized values differing by 1. The ith "
+ "scale corresponds to the ith slice in the quantized_dimension.");
+ c.def_prop_ro(
+ "zero_points",
+ [](const PyType &type) {
+ intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
+ std::vector<int64_t> zeroPoints;
+ zeroPoints.reserve(nDim);
+ for (intptr_t i = 0; i < nDim; ++i)
+ zeroPoints.push_back(
+ mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i));
+ return zeroPoints;
+ },
+ "the storage values corresponding to the real value 0 in the affine "
+ "equation. The ith zero point corresponds to the ith slice in the "
+ "quantized_dimension.");
+ c.def_prop_ro(
+ "quantized_dimension",
+ [](const PyType &type) {
+ return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type);
+ },
+ "Specifies the dimension of the shape that the scales and zero points "
+ "correspond to.");
+ c.def_prop_ro(
+ "is_fixed_point",
+ [](const PyType &type) {
+ return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type);
+ },
+ "Fixed point values are real numbers divided by a scale.");
+ }
+};
+
+//===-------------------------------------------------------------------===//
+// UniformQuantizedSubChannelType
+//===-------------------------------------------------------------------===//
+
+/// Python binding for `UniformQuantizedSubChannelType`: scales and zero
+/// points are supplied as attributes, and the data is quantized in blocks
+/// along the listed `quantized_dimensions`.
+struct UniformQuantizedSubChannelType
+ : PyConcreteType<UniformQuantizedSubChannelType, QuantizedType> {
+ static constexpr IsAFunctionTy isaFunction =
+ mlirTypeIsAUniformQuantizedSubChannelType;
+ static constexpr const char *pyClassName = "UniformQuantizedSubChannelType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](unsigned flags, const PyType &storageType,
+ const PyType &expressedType, MlirAttribute scales,
+ MlirAttribute zeroPoints, std::vector<int32_t> quantizedDimensions,
+ std::vector<int64_t> blockSizes, int64_t storageTypeMin,
+ int64_t storageTypeMax, DefaultingPyMlirContext /*context*/) {
+ // Only one count (blockSizes.size()) is handed to the C API for both
+ // arrays, so a shorter `quantized_dimensions` would be over-read.
+ if (quantizedDimensions.size() != blockSizes.size())
+ throw nb::value_error(
+ "Mismatching number of quantized dimensions and block sizes.");
+ MlirType t = mlirUniformQuantizedSubChannelTypeGet(
+ flags, storageType, expressedType, scales, zeroPoints,
+ static_cast<intptr_t>(blockSizes.size()),
+ quantizedDimensions.data(), blockSizes.data(), storageTypeMin,
+ storageTypeMax);
+ // The new type lives in the storage type's context; bind the wrapper
+ // to the context that owns `t` rather than the ambient one.
+ return UniformQuantizedSubChannelType(
+ PyMlirContext::forContext(mlirTypeGetContext(t)), t);
+ },
+ "Gets an instance of UniformQuantizedSubChannel in the same context as "
+ "the provided storage type.",
+ nb::arg("flags"), nb::arg("storage_type"), nb::arg("expressed_type"),
+ nb::arg("scales"), nb::arg("zero_points"),
+ nb::arg("quantized_dimensions"), nb::arg("block_sizes"),
+ nb::arg("storage_type_min"), nb::arg("storage_type_max"),
+ nb::arg("context") = nb::none());
+ c.def_prop_ro(
+ "quantized_dimensions",
+ [](const PyType &type) {
+ intptr_t nDim =
+ mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
+ std::vector<int32_t> quantizedDimensions;
+ quantizedDimensions.reserve(nDim);
+ for (intptr_t i = 0; i < nDim; ++i)
+ quantizedDimensions.push_back(
+ mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(type, i));
+ return quantizedDimensions;
+ },
+ "Gets the quantized dimensions. Each element in the returned list "
+ "represents an axis of the quantized data tensor that has a specified "
+ "block size. The order of elements corresponds to the order of block "
+ "sizes returned by 'block_sizes' method. It means that the data tensor "
+ "is quantized along the i-th dimension in the returned list using the "
+ "i-th block size from block_sizes method.");
+ c.def_prop_ro(
+ "block_sizes",
+ [](const PyType &type) {
+ intptr_t nDim =
+ mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
+ std::vector<int64_t> blockSizes;
+ blockSizes.reserve(nDim);
+ for (intptr_t i = 0; i < nDim; ++i)
+ blockSizes.push_back(
+ mlirUniformQuantizedSubChannelTypeGetBlockSize(type, i));
+ return blockSizes;
+ },
+ "Gets the block sizes for the quantized dimensions. The i-th element "
+ "in the returned list corresponds to the block size for the i-th "
+ "dimension in the list returned by quantized_dimensions method.");
+ c.def_prop_ro(
+ "scales",
+ [](const PyType &type) -> MlirAttribute {
+ return mlirUniformQuantizedSubChannelTypeGetScales(type);
+ },
+ "The scales of the quantized type.");
+ c.def_prop_ro(
+ "zero_points",
+ [](const PyType &type) -> MlirAttribute {
+ return mlirUniformQuantizedSubChannelTypeGetZeroPoints(type);
+ },
+ "The zero points of the quantized type.");
+ }
+};
+
+//===-------------------------------------------------------------------===//
+// CalibratedQuantizedType
+//===-------------------------------------------------------------------===//
+
+/// Python binding for `CalibratedQuantizedType`: an expressed type annotated
+/// with a calibrated `[min, max]` range.
+struct CalibratedQuantizedType
+ : PyConcreteType<CalibratedQuantizedType, QuantizedType> {
+ static constexpr IsAFunctionTy isaFunction =
+ mlirTypeIsACalibratedQuantizedType;
+ static constexpr const char *pyClassName = "CalibratedQuantizedType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](const PyType &expressedType, double min, double max,
+ DefaultingPyMlirContext /*context*/) {
+ MlirType t = mlirCalibratedQuantizedTypeGet(expressedType, min, max);
+ // The new type lives in the expressed type's context; bind the
+ // wrapper to the context that owns `t` rather than the ambient one.
+ return CalibratedQuantizedType(
+ PyMlirContext::forContext(mlirTypeGetContext(t)), t);
+ },
+ "Gets an instance of CalibratedQuantizedType in the same context as "
+ "the provided expressed type.",
+ nb::arg("expressed_type"), nb::arg("min"), nb::arg("max"),
+ nb::arg("context") = nb::none());
+ c.def_prop_ro("min", [](const PyType &type) {
+ return mlirCalibratedQuantizedTypeGetMin(type);
+ });
+ c.def_prop_ro("max", [](const PyType &type) {
+ return mlirCalibratedQuantizedTypeGetMax(type);
+ });
+ }
+};
+
+static void populateDialectQuantSubmodule(nb::module_ &m) {
+ // The common base goes first; the class attribute FLAG_SIGNED is attached
+ // to it right after registration, before any subclass is bound.
+ QuantizedType::bind(m);
+ nb::object quantizedTypeCls = m.attr("QuantizedType");
+ quantizedTypeCls.attr("FLAG_SIGNED") = mlirQuantizedTypeGetSignedFlag();
+
+ // Concrete quantized types, each deriving from QuantizedType.
+ AnyQuantizedType::bind(m);
+ UniformQuantizedType::bind(m);
+ UniformQuantizedPerAxisType::bind(m);
+ UniformQuantizedSubChannelType::bind(m);
+ CalibratedQuantizedType::bind(m);
}
+} // namespace quant
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
NB_MODULE(_mlirDialectsQuant, m) {
m.doc() = "MLIR Quantization dialect";
- populateDialectQuantSubmodule(m);
+ // Qualify through the domain macro so this matches the namespace above
+ // regardless of how MLIR_BINDINGS_PYTHON_DOMAIN is configured.
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::quant::populateDialectQuantSubmodule(m);
}
diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp
index a87918a05b126..39490155d5216 100644
--- a/mlir/lib/Bindings/Python/DialectSMT.cpp
+++ b/mlir/lib/Bindings/Python/DialectSMT.cpp
@@ -13,44 +13,77 @@
#include "mlir-c/Support.h"
#include "mlir-c/Target/ExportSMTLIB.h"
#include "mlir/Bindings/Python/Diagnostics.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace nanobind::literals;
-
using namespace mlir;
-using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
-static void populateDialectSMTSubmodule(nanobind::module_ &m) {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace smt {
+/// Python binding for the SMT dialect's boolean type.
+struct BoolType : PyConcreteType<BoolType> {
+ static constexpr const char *pyClassName = "BoolType";
+ static constexpr IsAFunctionTy isaFunction = mlirSMTTypeIsABool;
+ using PyConcreteType::PyConcreteType;
+
+ /// Exposes `BoolType.get(context=None)`.
+ static void bindDerived(ClassTy &c) {
+ auto factory = [](DefaultingPyMlirContext ctx) {
+ PyMlirContext *pyCtx = ctx.get();
+ MlirType boolTy = mlirSMTTypeGetBool(pyCtx->get());
+ return BoolType(pyCtx->getRef(), boolTy);
+ };
+ c.def_static("get", factory, nb::arg("context").none() = nb::none());
+ }
+};
+
+/// Python binding for the SMT dialect's fixed-width bit-vector type.
+struct BitVectorType : PyConcreteType<BitVectorType> {
+ static constexpr const char *pyClassName = "BitVectorType";
+ static constexpr IsAFunctionTy isaFunction = mlirSMTTypeIsABitVector;
+ using PyConcreteType::PyConcreteType;
- auto smtBoolType =
- mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool)
- .def_staticmethod(
- "get",
- [](MlirContext context) { return mlirSMTTypeGetBool(context); },
- "context"_a = nb::none());
- auto smtBitVectorType =
- mlir_type_subclass(m, "BitVectorType", mlirSMTTypeIsABitVector)
- .def_staticmethod(
- "get",
- [](int32_t width, MlirContext context) {
- return mlirSMTTypeGetBitVector(context, width);
- },
- "width"_a, "context"_a = nb::none());
- auto smtIntType =
- mlir_type_subclass(m, "IntType", mlirSMTTypeIsAInt)
- .def_staticmethod(
- "get",
- [](MlirContext context) { return mlirSMTTypeGetInt(context); },
- "context"_a = nb::none());
+ /// Exposes `BitVectorType.get(width, context=None)`.
+ static void bindDerived(ClassTy &c) {
+ auto factory = [](int32_t bitWidth, DefaultingPyMlirContext ctx) {
+ PyMlirContext *pyCtx = ctx.get();
+ MlirType bvTy = mlirSMTTypeGetBitVector(pyCtx->get(), bitWidth);
+ return BitVectorType(pyCtx->getRef(), bvTy);
+ };
+ c.def_static("get", factory, nb::arg("width"),
+ nb::arg("context").none() = nb::none());
+ }
+};
+
+/// Python binding for the SMT dialect's unbounded integer type.
+struct IntType : PyConcreteType<IntType> {
+ static constexpr const char *pyClassName = "IntType";
+ static constexpr IsAFunctionTy isaFunction = mlirSMTTypeIsAInt;
+ using PyConcreteType::PyConcreteType;
+
+ /// Exposes `IntType.get(context=None)`.
+ static void bindDerived(ClassTy &c) {
+ auto factory = [](DefaultingPyMlirContext ctx) {
+ PyMlirContext *pyCtx = ctx.get();
+ MlirType intTy = mlirSMTTypeGetInt(pyCtx->get());
+ return IntType(pyCtx->getRef(), intTy);
+ };
+ c.def_static("get", factory, nb::arg("context").none() = nb::none());
+ }
+};
+
+static void populateDialectSMTSubmodule(nanobind::module_ &m) {
+ BoolType::bind(m);
+ BitVectorType::bind(m);
+ IntType::bind(m);
auto exportSMTLIB = [](MlirOperation module, bool inlineSingleUseValues,
bool indentLetBody) {
- mlir::python::CollectDiagnosticsToStringScope scope(
- mlirOperationGetContext(module));
+ CollectDiagnosticsToStringScope scope(mlirOperationGetContext(module));
PyPrintAccumulator printAccum;
MlirLogicalResult result = mlirTranslateOperationToSMTLIB(
module, printAccum.getCallback(), printAccum.getUserData(),
@@ -80,9 +113,13 @@ static void populateDialectSMTSubmodule(nanobind::module_ &m) {
"module"_a, "inline_single_use_values"_a = false,
"indent_let_body"_a = false);
}
+} // namespace smt
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
NB_MODULE(_mlirDialectsSMT, m) {
m.doc() = "MLIR SMT Dialect";
- populateDialectSMTSubmodule(m);
+ // Fully qualify through the domain macro (as the namespaces above do)
+ // instead of relying on `using namespace mlir` and a hard-coded domain.
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::smt::populateDialectSMTSubmodule(m);
}
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 00b65ee9745dc..9ee3fc461ef6a 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -12,15 +12,132 @@
#include "mlir-c/AffineMap.h"
#include "mlir-c/Dialect/SparseTensor.h"
#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace llvm;
-using namespace mlir;
using namespace mlir::python::nanobind_adaptors;
-static void populateDialectSparseTensorSubmodule(const nb::module_ &m) {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace sparse_tensor {
+
+/// Python binding for `sparse_tensor.encoding` attributes.
+struct EncodingAttr : PyConcreteAttribute<EncodingAttr> {
+ static constexpr IsAFunctionTy isaFunction =
+ mlirAttributeIsASparseTensorEncodingAttr;
+ static constexpr const char *pyClassName = "EncodingAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](std::vector<MlirSparseTensorLevelType> lvlTypes,
+ std::optional<MlirAffineMap> dimToLvl,
+ std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
+ std::optional<MlirAttribute> explicitVal,
+ std::optional<MlirAttribute> implicitVal,
+ DefaultingPyMlirContext context) {
+ // Absent optionals are forwarded as null handles.
+ return EncodingAttr(
+ context->getRef(),
+ mlirSparseTensorEncodingAttrGet(
+ context.get()->get(), lvlTypes.size(), lvlTypes.data(),
+ dimToLvl ? *dimToLvl : MlirAffineMap{nullptr},
+ lvlToDim ? *lvlToDim : MlirAffineMap{nullptr}, posWidth,
+ crdWidth, explicitVal ? *explicitVal : MlirAttribute{nullptr},
+ implicitVal ? *implicitVal : MlirAttribute{nullptr}));
+ },
+ nb::arg("lvl_types"), nb::arg("dim_to_lvl").none(),
+ nb::arg("lvl_to_dim").none(), nb::arg("pos_width"),
+ nb::arg("crd_width"), nb::arg("explicit_val") = nb::none(),
+ nb::arg("implicit_val") = nb::none(), nb::arg("context") = nb::none(),
+ "Gets a sparse_tensor.encoding from parameters.");
+
+ c.def_static(
+ "build_level_type",
+ [](MlirSparseTensorLevelFormat lvlFmt,
+ const std::vector<MlirSparseTensorLevelPropertyNondefault>
+ &properties,
+ unsigned n, unsigned m) {
+ return mlirSparseTensorEncodingAttrBuildLvlType(
+ lvlFmt, properties.data(), properties.size(), n, m);
+ },
+ nb::arg("lvl_fmt"),
+ nb::arg("properties") =
+ std::vector<MlirSparseTensorLevelPropertyNondefault>(),
+ nb::arg("n") = 0, nb::arg("m") = 0,
+ "Builds a sparse_tensor.encoding.level_type from parameters.");
+
+ c.def_prop_ro("lvl_types", [](MlirAttribute self) {
+ const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
+ std::vector<MlirSparseTensorLevelType> ret;
+ ret.reserve(lvlRank);
+ for (int l = 0; l < lvlRank; ++l)
+ ret.push_back(mlirSparseTensorEncodingAttrGetLvlType(self, l));
+ return ret;
+ });
+
+ // The optional-returning properties map a null C handle to Python None.
+ c.def_prop_ro(
+ "dim_to_lvl", [](MlirAttribute self) -> std::optional<MlirAffineMap> {
+ MlirAffineMap ret = mlirSparseTensorEncodingAttrGetDimToLvl(self);
+ if (mlirAffineMapIsNull(ret))
+ return {};
+ return ret;
+ });
+
+ c.def_prop_ro(
+ "lvl_to_dim", [](MlirAttribute self) -> std::optional<MlirAffineMap> {
+ MlirAffineMap ret = mlirSparseTensorEncodingAttrGetLvlToDim(self);
+ if (mlirAffineMapIsNull(ret))
+ return {};
+ return ret;
+ });
+
+ c.def_prop_ro("pos_width", mlirSparseTensorEncodingAttrGetPosWidth);
+ c.def_prop_ro("crd_width", mlirSparseTensorEncodingAttrGetCrdWidth);
+
+ c.def_prop_ro(
+ "explicit_val", [](MlirAttribute self) -> std::optional<MlirAttribute> {
+ MlirAttribute ret = mlirSparseTensorEncodingAttrGetExplicitVal(self);
+ if (mlirAttributeIsNull(ret))
+ return {};
+ return ret;
+ });
+
+ c.def_prop_ro(
+ "implicit_val", [](MlirAttribute self) -> std::optional<MlirAttribute> {
+ MlirAttribute ret = mlirSparseTensorEncodingAttrGetImplicitVal(self);
+ if (mlirAttributeIsNull(ret))
+ return {};
+ return ret;
+ });
+
+ // structured_n / structured_m read the last level. Without a guard, an
+ // encoding with no levels would query level index -1.
+ c.def_prop_ro("structured_n", [](MlirAttribute self) -> unsigned {
+ const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
+ if (lvlRank <= 0)
+ return 0;
+ return mlirSparseTensorEncodingAttrGetStructuredN(
+ mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1));
+ });
+
+ c.def_prop_ro("structured_m", [](MlirAttribute self) -> unsigned {
+ const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
+ if (lvlRank <= 0)
+ return 0;
+ return mlirSparseTensorEncodingAttrGetStructuredM(
+ mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1));
+ });
+
+ c.def_prop_ro("lvl_formats_enum", [](MlirAttribute self) {
+ const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
+ std::vector<MlirSparseTensorLevelFormat> ret;
+ ret.reserve(lvlRank);
+ for (int l = 0; l < lvlRank; l++)
+ ret.push_back(mlirSparseTensorEncodingAttrGetLvlFmt(self, l));
+ return ret;
+ });
+ }
+};
+
+static void populateDialectSparseTensorSubmodule(nb::module_ &m) {
nb::enum_<MlirSparseTensorLevelFormat>(m, "LevelFormat", nb::is_arithmetic(),
nb::is_flag())
.value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE)
@@ -34,115 +151,14 @@ static void populateDialectSparseTensorSubmodule(const nb::module_ &m) {
.value("non_unique", MLIR_SPARSE_PROPERTY_NON_UNIQUE)
.value("soa", MLIR_SPARSE_PROPERTY_SOA);
- mlir_attribute_subclass(m, "EncodingAttr",
- mlirAttributeIsASparseTensorEncodingAttr)
- .def_classmethod(
- "get",
- [](const nb::object &cls,
- std::vector<MlirSparseTensorLevelType> lvlTypes,
- std::optional<MlirAffineMap> dimToLvl,
- std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
- std::optional<MlirAttribute> explicitVal,
- std::optional<MlirAttribute> implicitVal, MlirContext context) {
- return cls(mlirSparseTensorEncodingAttrGet(
- context, lvlTypes.size(), lvlTypes.data(),
- dimToLvl ? *dimToLvl : MlirAffineMap{nullptr},
- lvlToDim ? *lvlToDim : MlirAffineMap{nullptr}, posWidth,
- crdWidth, explicitVal ? *explicitVal : MlirAttribute{nullptr},
- implicitVal ? *implicitVal : MlirAttribute{nullptr}));
- },
- nb::arg("cls"), nb::arg("lvl_types"), nb::arg("dim_to_lvl").none(),
- nb::arg("lvl_to_dim").none(), nb::arg("pos_width"),
- nb::arg("crd_width"), nb::arg("explicit_val") = nb::none(),
- nb::arg("implicit_val") = nb::none(), nb::arg("context") = nb::none(),
- "Gets a sparse_tensor.encoding from parameters.")
- .def_classmethod(
- "build_level_type",
- [](const nb::object &cls, MlirSparseTensorLevelFormat lvlFmt,
- const std::vector<MlirSparseTensorLevelPropertyNondefault>
- &properties,
- unsigned n, unsigned m) {
- return mlirSparseTensorEncodingAttrBuildLvlType(
- lvlFmt, properties.data(), properties.size(), n, m);
- },
- nb::arg("cls"), nb::arg("lvl_fmt"),
- nb::arg("properties") =
- std::vector<MlirSparseTensorLevelPropertyNondefault>(),
- nb::arg("n") = 0, nb::arg("m") = 0,
- "Builds a sparse_tensor.encoding.level_type from parameters.")
- .def_property_readonly(
- "lvl_types",
- [](MlirAttribute self) {
- const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
- std::vector<MlirSparseTensorLevelType> ret;
- ret.reserve(lvlRank);
- for (int l = 0; l < lvlRank; ++l)
- ret.push_back(mlirSparseTensorEncodingAttrGetLvlType(self, l));
- return ret;
- })
- .def_property_readonly(
- "dim_to_lvl",
- [](MlirAttribute self) -> std::optional<MlirAffineMap> {
- MlirAffineMap ret = mlirSparseTensorEncodingAttrGetDimToLvl(self);
- if (mlirAffineMapIsNull(ret))
- return {};
- return ret;
- })
- .def_property_readonly(
- "lvl_to_dim",
- [](MlirAttribute self) -> std::optional<MlirAffineMap> {
- MlirAffineMap ret = mlirSparseTensorEncodingAttrGetLvlToDim(self);
- if (mlirAffineMapIsNull(ret))
- return {};
- return ret;
- })
- .def_property_readonly("pos_width",
- mlirSparseTensorEncodingAttrGetPosWidth)
- .def_property_readonly("crd_width",
- mlirSparseTensorEncodingAttrGetCrdWidth)
- .def_property_readonly(
- "explicit_val",
- [](MlirAttribute self) -> std::optional<MlirAttribute> {
- MlirAttribute ret =
- mlirSparseTensorEncodingAttrGetExplicitVal(self);
- if (mlirAttributeIsNull(ret))
- return {};
- return ret;
- })
- .def_property_readonly(
- "implicit_val",
- [](MlirAttribute self) -> std::optional<MlirAttribute> {
- MlirAttribute ret =
- mlirSparseTensorEncodingAttrGetImplicitVal(self);
- if (mlirAttributeIsNull(ret))
- return {};
- return ret;
- })
- .def_property_readonly(
- "structured_n",
- [](MlirAttribute self) -> unsigned {
- const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
- return mlirSparseTensorEncodingAttrGetStructuredN(
- mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1));
- })
- .def_property_readonly(
- "structured_m",
- [](MlirAttribute self) -> unsigned {
- const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
- return mlirSparseTensorEncodingAttrGetStructuredM(
- mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1));
- })
- .def_property_readonly("lvl_formats_enum", [](MlirAttribute self) {
- const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
- std::vector<MlirSparseTensorLevelFormat> ret;
- ret.reserve(lvlRank);
- for (int l = 0; l < lvlRank; l++)
- ret.push_back(mlirSparseTensorEncodingAttrGetLvlFmt(self, l));
- return ret;
- });
+ EncodingAttr::bind(m);
}
+} // namespace sparse_tensor
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
NB_MODULE(_mlirDialectsSparseTensor, m) {
m.doc() = "MLIR SparseTensor dialect.";
- populateDialectSparseTensorSubmodule(m);
+ // Qualify through the domain macro so this matches the namespace above
+ // regardless of how MLIR_BINDINGS_PYTHON_DOMAIN is configured.
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::sparse_tensor::populateDialectSparseTensorSubmodule(m);
}
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 150c69953d960..6ef23b8ab686e 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -11,112 +11,164 @@
#include "mlir-c/Dialect/Transform.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
-using namespace mlir;
-using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
-static void populateDialectTransformSubmodule(const nb::module_ &m) {
- //===-------------------------------------------------------------------===//
- // AnyOpType
- //===-------------------------------------------------------------------===//
-
- auto anyOpType =
- mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType,
- mlirTransformAnyOpTypeGetTypeID);
- anyOpType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirTransformAnyOpTypeGet(ctx));
- },
- "Get an instance of AnyOpType in the given context.", nb::arg("cls"),
- nb::arg("context") = nb::none());
-
- //===-------------------------------------------------------------------===//
- // AnyParamType
- //===-------------------------------------------------------------------===//
-
- auto anyParamType =
- mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType,
- mlirTransformAnyParamTypeGetTypeID);
- anyParamType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirTransformAnyParamTypeGet(ctx));
- },
- "Get an instance of AnyParamType in the given context.", nb::arg("cls"),
- nb::arg("context") = nb::none());
-
- //===-------------------------------------------------------------------===//
- // AnyValueType
- //===-------------------------------------------------------------------===//
-
- auto anyValueType =
- mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType,
- mlirTransformAnyValueTypeGetTypeID);
- anyValueType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirTransformAnyValueTypeGet(ctx));
- },
- "Get an instance of AnyValueType in the given context.", nb::arg("cls"),
- nb::arg("context") = nb::none());
-
- //===-------------------------------------------------------------------===//
- // OperationType
- //===-------------------------------------------------------------------===//
-
- auto operationType =
- mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType,
- mlirTransformOperationTypeGetTypeID);
- operationType.def_classmethod(
- "get",
- [](const nb::object &cls, const std::string &operationName,
- MlirContext ctx) {
- MlirStringRef cOperationName =
- mlirStringRefCreate(operationName.data(), operationName.size());
- return cls(mlirTransformOperationTypeGet(ctx, cOperationName));
- },
- "Get an instance of OperationType for the given kind in the given "
- "context",
- nb::arg("cls"), nb::arg("operation_name"),
- nb::arg("context") = nb::none());
- operationType.def_property_readonly(
- "operation_name",
- [](MlirType type) {
- MlirStringRef operationName =
- mlirTransformOperationTypeGetOperationName(type);
- return nb::str(operationName.data, operationName.length);
- },
- "Get the name of the payload operation accepted by the handle.");
-
- //===-------------------------------------------------------------------===//
- // ParamType
- //===-------------------------------------------------------------------===//
-
- auto paramType =
- mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType,
- mlirTransformParamTypeGetTypeID);
- paramType.def_classmethod(
- "get",
- [](const nb::object &cls, MlirType type, MlirContext ctx) {
- return cls(mlirTransformParamTypeGet(ctx, type));
- },
- "Get an instance of ParamType for the given type in the given context.",
- nb::arg("cls"), nb::arg("type"), nb::arg("context") = nb::none());
- paramType.def_property_readonly(
- "type",
- [](MlirType type) {
- MlirType paramType = mlirTransformParamTypeGetType(type);
- return paramType;
- },
- "Get the type this ParamType is associated with.");
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace transform {
+//===-------------------------------------------------------------------===//
+// AnyOpType
+//===-------------------------------------------------------------------===//
+
+/// Python binding for the transform dialect's `AnyOpType` handle type.
+struct AnyOpType : PyConcreteType<AnyOpType> {
+ static constexpr const char *pyClassName = "AnyOpType";
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsATransformAnyOpType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirTransformAnyOpTypeGetTypeID;
+ using PyConcreteType::PyConcreteType;
+
+ /// Exposes `AnyOpType.get(context=None)`.
+ static void bindDerived(ClassTy &c) {
+ auto factory = [](DefaultingPyMlirContext ctx) {
+ PyMlirContext *pyCtx = ctx.get();
+ MlirType handleTy = mlirTransformAnyOpTypeGet(pyCtx->get());
+ return AnyOpType(pyCtx->getRef(), handleTy);
+ };
+ c.def_static("get", factory,
+ "Get an instance of AnyOpType in the given context.",
+ nb::arg("context").none() = nb::none());
+ }
+};
+
+//===-------------------------------------------------------------------===//
+// AnyParamType
+//===-------------------------------------------------------------------===//
+
+/// Python binding for the transform dialect's `AnyParamType` type.
+struct AnyParamType : PyConcreteType<AnyParamType> {
+ static constexpr const char *pyClassName = "AnyParamType";
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsATransformAnyParamType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirTransformAnyParamTypeGetTypeID;
+ using PyConcreteType::PyConcreteType;
+
+ /// Exposes `AnyParamType.get(context=None)`.
+ static void bindDerived(ClassTy &c) {
+ auto factory = [](DefaultingPyMlirContext ctx) {
+ PyMlirContext *pyCtx = ctx.get();
+ MlirType paramTy = mlirTransformAnyParamTypeGet(pyCtx->get());
+ return AnyParamType(pyCtx->getRef(), paramTy);
+ };
+ c.def_static("get", factory,
+ "Get an instance of AnyParamType in the given context.",
+ nb::arg("context").none() = nb::none());
+ }
+};
+
+//===-------------------------------------------------------------------===//
+// AnyValueType
+//===-------------------------------------------------------------------===//
+
+/// Python binding for the transform dialect's `AnyValueType` handle type.
+struct AnyValueType : PyConcreteType<AnyValueType> {
+ static constexpr const char *pyClassName = "AnyValueType";
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsATransformAnyValueType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirTransformAnyValueTypeGetTypeID;
+ using PyConcreteType::PyConcreteType;
+
+ /// Exposes `AnyValueType.get(context=None)`.
+ static void bindDerived(ClassTy &c) {
+ auto factory = [](DefaultingPyMlirContext ctx) {
+ PyMlirContext *pyCtx = ctx.get();
+ MlirType valueTy = mlirTransformAnyValueTypeGet(pyCtx->get());
+ return AnyValueType(pyCtx->getRef(), valueTy);
+ };
+ c.def_static("get", factory,
+ "Get an instance of AnyValueType in the given context.",
+ nb::arg("context").none() = nb::none());
+ }
+};
+
+//===-------------------------------------------------------------------===//
+// OperationType
+//===-------------------------------------------------------------------===//
+
+/// Python binding for the transform dialect's `OperationType`: a handle to
+/// payload operations of one specific kind.
+struct OperationType : PyConcreteType<OperationType> {
+ static constexpr const char *pyClassName = "OperationType";
+ static constexpr IsAFunctionTy isaFunction =
+ mlirTypeIsATransformOperationType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirTransformOperationTypeGetTypeID;
+ using PyConcreteType::PyConcreteType;
+
+ /// Exposes `OperationType.get(operation_name, context=None)` and the
+ /// `operation_name` read-only property.
+ static void bindDerived(ClassTy &c) {
+ auto factory = [](const std::string &operationName,
+ DefaultingPyMlirContext ctx) {
+ PyMlirContext *pyCtx = ctx.get();
+ MlirStringRef nameRef =
+ mlirStringRefCreate(operationName.data(), operationName.size());
+ MlirType opTy = mlirTransformOperationTypeGet(pyCtx->get(), nameRef);
+ return OperationType(pyCtx->getRef(), opTy);
+ };
+ c.def_static(
+ "get", factory,
+ "Get an instance of OperationType for the given kind in the given context",
+ nb::arg("operation_name"), nb::arg("context").none() = nb::none());
+
+ auto payloadName = [](const PyType &self) {
+ MlirStringRef nameRef = mlirTransformOperationTypeGetOperationName(self);
+ return nb::str(nameRef.data, nameRef.length);
+ };
+ c.def_prop_ro(
+ "operation_name", payloadName,
+ "Get the name of the payload operation accepted by the handle.");
+ }
+};
+
+//===-------------------------------------------------------------------===//
+// ParamType
+//===-------------------------------------------------------------------===//
+
+/// Python binding for the transform dialect's `ParamType`: a parameter
+/// handle associated with a specific type.
+struct ParamType : PyConcreteType<ParamType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsATransformParamType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirTransformParamTypeGetTypeID;
+ static constexpr const char *pyClassName = "ParamType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](const PyType &type, DefaultingPyMlirContext context) {
+ return ParamType(context->getRef(), mlirTransformParamTypeGet(
+ context.get()->get(), type));
+ },
+ "Get an instance of ParamType for the given type in the given context.",
+ nb::arg("type"), nb::arg("context").none() = nb::none());
+ c.def_prop_ro(
+ "type",
+ // Return the raw MlirType, as the pre-port binding did, so the
+ // NanobindAdaptors caster converts it. That caster is expected to
+ // downcast to the concrete Python class (e.g. IntegerType); wrapping it
+ // in a plain PyType would surface as generic `Type`. It also avoids
+ // copying the PyType argument by value.
+ [](const PyType &type) -> MlirType {
+ return mlirTransformParamTypeGetType(type);
+ },
+ "Get the type this ParamType is associated with.");
+ }
+};
+
+/// Calls `bind` on each listed type, left to right.
+template <typename... Types>
+static void bindTransformTypes(nb::module_ &m) {
+ (Types::bind(m), ...);
+}
+
+/// Registers every transform dialect type binding on module `m`.
+static void populateDialectTransformSubmodule(nb::module_ &m) {
+ bindTransformTypes<AnyOpType, AnyParamType, AnyValueType, OperationType,
+ ParamType>(m);
}
+} // namespace transform
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
NB_MODULE(_mlirDialectsTransform, m) {
m.doc() = "MLIR Transform dialect.";
- populateDialectTransformSubmodule(m);
+ // Qualify through the domain macro so this matches the namespace above
+ // regardless of how MLIR_BINDINGS_PYTHON_DOMAIN is configured.
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::transform::populateDialectTransformSubmodule(m);
}
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 8d045cad7a4a3..b4d19878056db 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -43,8 +43,9 @@ def __init__(
self.parent = parent
self.children = children if children is not None else []
- at ir.register_value_caster(AnyOpType.get_static_typeid())
- at ir.register_value_caster(OperationType.get_static_typeid())
+
+ at ir.register_value_caster(AnyOpType.static_typeid)
+ at ir.register_value_caster(OperationType.static_typeid)
class OpHandle(Handle):
"""
Wrapper around a transform operation handle with methods to chain further
@@ -132,8 +133,8 @@ def print(self, name: Optional[str] = None) -> "OpHandle":
return self
- at ir.register_value_caster(AnyParamType.get_static_typeid())
- at ir.register_value_caster(ParamType.get_static_typeid())
+ at ir.register_value_caster(AnyParamType.static_typeid)
+ at ir.register_value_caster(ParamType.static_typeid)
class ParamHandle(Handle):
"""Wrapper around a transform param handle."""
@@ -147,7 +148,7 @@ def __init__(
super().__init__(v, parent=parent, children=children)
- at ir.register_value_caster(AnyValueType.get_static_typeid())
+ at ir.register_value_caster(AnyValueType.static_typeid)
class ValueHandle(Handle):
"""
Wrapper around a transform value handle with methods to chain further
diff --git a/mlir/test/python/dialects/pdl_types.py b/mlir/test/python/dialects/pdl_types.py
index dfba2a36b8980..f75428d295c9c 100644
--- a/mlir/test/python/dialects/pdl_types.py
+++ b/mlir/test/python/dialects/pdl_types.py
@@ -5,149 +5,149 @@
+# Test decorator: prints a "TEST:" header with the function's name, calls the
+# function immediately, and returns it unchanged.
def run(f):
- print("\nTEST:", f.__name__)
- f()
- return f
+ print("\nTEST:", f.__name__)
+ f()
+ return f
# CHECK-LABEL: TEST: test_attribute_type
+# Parses !pdl.attribute and builds pdl.AttributeType.get(); both must be
+# recognized only by AttributeType among the PDL type classes, compare equal,
+# and print as !pdl.attribute.
@run
def test_attribute_type():
- with Context():
- parsedType = Type.parse("!pdl.attribute")
- constructedType = pdl.AttributeType.get()
+ with Context():
+ parsedType = Type.parse("!pdl.attribute")
+ constructedType = pdl.AttributeType.get()
- assert pdl.AttributeType.isinstance(parsedType)
- assert not pdl.OperationType.isinstance(parsedType)
- assert not pdl.RangeType.isinstance(parsedType)
- assert not pdl.TypeType.isinstance(parsedType)
- assert not pdl.ValueType.isinstance(parsedType)
+ assert pdl.AttributeType.isinstance(parsedType)
+ assert not pdl.OperationType.isinstance(parsedType)
+ assert not pdl.RangeType.isinstance(parsedType)
+ assert not pdl.TypeType.isinstance(parsedType)
+ assert not pdl.ValueType.isinstance(parsedType)
- assert pdl.AttributeType.isinstance(constructedType)
- assert not pdl.OperationType.isinstance(constructedType)
- assert not pdl.RangeType.isinstance(constructedType)
- assert not pdl.TypeType.isinstance(constructedType)
- assert not pdl.ValueType.isinstance(constructedType)
+ assert pdl.AttributeType.isinstance(constructedType)
+ assert not pdl.OperationType.isinstance(constructedType)
+ assert not pdl.RangeType.isinstance(constructedType)
+ assert not pdl.TypeType.isinstance(constructedType)
+ assert not pdl.ValueType.isinstance(constructedType)
- assert parsedType == constructedType
+ assert parsedType == constructedType
- # CHECK: !pdl.attribute
- print(parsedType)
- # CHECK: !pdl.attribute
- print(constructedType)
+ # CHECK: !pdl.attribute
+ print(parsedType)
+ # CHECK: !pdl.attribute
+ print(constructedType)
# CHECK-LABEL: TEST: test_operation_type
+# Parses !pdl.operation and builds pdl.OperationType.get(); both must be
+# recognized only by OperationType among the PDL type classes, compare equal,
+# and print as !pdl.operation.
@run
def test_operation_type():
- with Context():
- parsedType = Type.parse("!pdl.operation")
- constructedType = pdl.OperationType.get()
+ with Context():
+ parsedType = Type.parse("!pdl.operation")
+ constructedType = pdl.OperationType.get()
- assert not pdl.AttributeType.isinstance(parsedType)
- assert pdl.OperationType.isinstance(parsedType)
- assert not pdl.RangeType.isinstance(parsedType)
- assert not pdl.TypeType.isinstance(parsedType)
- assert not pdl.ValueType.isinstance(parsedType)
+ assert not pdl.AttributeType.isinstance(parsedType)
+ assert pdl.OperationType.isinstance(parsedType)
+ assert not pdl.RangeType.isinstance(parsedType)
+ assert not pdl.TypeType.isinstance(parsedType)
+ assert not pdl.ValueType.isinstance(parsedType)
- assert not pdl.AttributeType.isinstance(constructedType)
- assert pdl.OperationType.isinstance(constructedType)
- assert not pdl.RangeType.isinstance(constructedType)
- assert not pdl.TypeType.isinstance(constructedType)
- assert not pdl.ValueType.isinstance(constructedType)
+ assert not pdl.AttributeType.isinstance(constructedType)
+ assert pdl.OperationType.isinstance(constructedType)
+ assert not pdl.RangeType.isinstance(constructedType)
+ assert not pdl.TypeType.isinstance(constructedType)
+ assert not pdl.ValueType.isinstance(constructedType)
- assert parsedType == constructedType
+ assert parsedType == constructedType
- # CHECK: !pdl.operation
- print(parsedType)
- # CHECK: !pdl.operation
- print(constructedType)
+ # CHECK: !pdl.operation
+ print(parsedType)
+ # CHECK: !pdl.operation
+ print(constructedType)
# CHECK-LABEL: TEST: test_range_type
+# Parses !pdl.range<type> and builds pdl.RangeType.get() from a parsed
+# !pdl.type; both must be recognized only by RangeType among the PDL type
+# classes and compare equal, and the built type's element_type must equal the
+# parsed !pdl.type.
@run
def test_range_type():
- with Context():
- typeType = Type.parse("!pdl.type")
- parsedType = Type.parse("!pdl.range<type>")
- constructedType = pdl.RangeType.get(typeType)
- elementType = constructedType.element_type
-
- assert not pdl.AttributeType.isinstance(parsedType)
- assert not pdl.OperationType.isinstance(parsedType)
- assert pdl.RangeType.isinstance(parsedType)
- assert not pdl.TypeType.isinstance(parsedType)
- assert not pdl.ValueType.isinstance(parsedType)
-
- assert not pdl.AttributeType.isinstance(constructedType)
- assert not pdl.OperationType.isinstance(constructedType)
- assert pdl.RangeType.isinstance(constructedType)
- assert not pdl.TypeType.isinstance(constructedType)
- assert not pdl.ValueType.isinstance(constructedType)
-
- assert parsedType == constructedType
- assert elementType == typeType
-
- # CHECK: !pdl.range<type>
- print(parsedType)
- # CHECK: !pdl.range<type>
- print(constructedType)
- # CHECK: !pdl.type
- print(elementType)
+ with Context():
+ typeType = Type.parse("!pdl.type")
+ parsedType = Type.parse("!pdl.range<type>")
+ constructedType = pdl.RangeType.get(typeType)
+ elementType = constructedType.element_type
+
+ assert not pdl.AttributeType.isinstance(parsedType)
+ assert not pdl.OperationType.isinstance(parsedType)
+ assert pdl.RangeType.isinstance(parsedType)
+ assert not pdl.TypeType.isinstance(parsedType)
+ assert not pdl.ValueType.isinstance(parsedType)
+
+ assert not pdl.AttributeType.isinstance(constructedType)
+ assert not pdl.OperationType.isinstance(constructedType)
+ assert pdl.RangeType.isinstance(constructedType)
+ assert not pdl.TypeType.isinstance(constructedType)
+ assert not pdl.ValueType.isinstance(constructedType)
+
+ assert parsedType == constructedType
+ assert elementType == typeType
+
+ # CHECK: !pdl.range<type>
+ print(parsedType)
+ # CHECK: !pdl.range<type>
+ print(constructedType)
+ # CHECK: !pdl.type
+ print(elementType)
# CHECK-LABEL: TEST: test_type_type
+# Parses !pdl.type and builds pdl.TypeType.get(); both must be recognized only
+# by TypeType among the PDL type classes, compare equal, and print as
+# !pdl.type.
@run
def test_type_type():
- with Context():
- parsedType = Type.parse("!pdl.type")
- constructedType = pdl.TypeType.get()
+ with Context():
+ parsedType = Type.parse("!pdl.type")
+ constructedType = pdl.TypeType.get()
- assert not pdl.AttributeType.isinstance(parsedType)
- assert not pdl.OperationType.isinstance(parsedType)
- assert not pdl.RangeType.isinstance(parsedType)
- assert pdl.TypeType.isinstance(parsedType)
- assert not pdl.ValueType.isinstance(parsedType)
+ assert not pdl.AttributeType.isinstance(parsedType)
+ assert not pdl.OperationType.isinstance(parsedType)
+ assert not pdl.RangeType.isinstance(parsedType)
+ assert pdl.TypeType.isinstance(parsedType)
+ assert not pdl.ValueType.isinstance(parsedType)
- assert not pdl.AttributeType.isinstance(constructedType)
- assert not pdl.OperationType.isinstance(constructedType)
- assert not pdl.RangeType.isinstance(constructedType)
- assert pdl.TypeType.isinstance(constructedType)
- assert not pdl.ValueType.isinstance(constructedType)
+ assert not pdl.AttributeType.isinstance(constructedType)
+ assert not pdl.OperationType.isinstance(constructedType)
+ assert not pdl.RangeType.isinstance(constructedType)
+ assert pdl.TypeType.isinstance(constructedType)
+ assert not pdl.ValueType.isinstance(constructedType)
- assert parsedType == constructedType
+ assert parsedType == constructedType
- # CHECK: !pdl.type
- print(parsedType)
- # CHECK: !pdl.type
- print(constructedType)
+ # CHECK: !pdl.type
+ print(parsedType)
+ # CHECK: !pdl.type
+ print(constructedType)
# CHECK-LABEL: TEST: test_value_type
+# Parses !pdl.value and builds pdl.ValueType.get(); both must be recognized
+# only by ValueType among the PDL type classes, compare equal, and print as
+# !pdl.value.
@run
def test_value_type():
- with Context():
- parsedType = Type.parse("!pdl.value")
- constructedType = pdl.ValueType.get()
+ with Context():
+ parsedType = Type.parse("!pdl.value")
+ constructedType = pdl.ValueType.get()
- assert not pdl.AttributeType.isinstance(parsedType)
- assert not pdl.OperationType.isinstance(parsedType)
- assert not pdl.RangeType.isinstance(parsedType)
- assert not pdl.TypeType.isinstance(parsedType)
- assert pdl.ValueType.isinstance(parsedType)
+ assert not pdl.AttributeType.isinstance(parsedType)
+ assert not pdl.OperationType.isinstance(parsedType)
+ assert not pdl.RangeType.isinstance(parsedType)
+ assert not pdl.TypeType.isinstance(parsedType)
+ assert pdl.ValueType.isinstance(parsedType)
- assert not pdl.AttributeType.isinstance(constructedType)
- assert not pdl.OperationType.isinstance(constructedType)
- assert not pdl.RangeType.isinstance(constructedType)
- assert not pdl.TypeType.isinstance(constructedType)
- assert pdl.ValueType.isinstance(constructedType)
+ assert not pdl.AttributeType.isinstance(constructedType)
+ assert not pdl.OperationType.isinstance(constructedType)
+ assert not pdl.RangeType.isinstance(constructedType)
+ assert not pdl.TypeType.isinstance(constructedType)
+ assert pdl.ValueType.isinstance(constructedType)
- assert parsedType == constructedType
+ assert parsedType == constructedType
- # CHECK: !pdl.value
- print(parsedType)
- # CHECK: !pdl.value
- print(constructedType)
+ # CHECK: !pdl.value
+ print(parsedType)
+ # CHECK: !pdl.value
+ print(constructedType)
# CHECK-LABEL: TEST: test_type_without_context
@@ -157,7 +157,11 @@ def test_type_without_context():
# should raise an exception but not crash.
try:
constructedType = pdl.ValueType.get()
- except TypeError:
- pass
+ except RuntimeError as e:
+ assert (
+ "An MLIR function requires a Context but none was provided in the call or from the surrounding environment"
+ in e.args[0]
+ )
else:
- assert False, "Expected TypeError to be raised."
+ # Reached only if get() did not raise; the test expects a RuntimeError because no Context is available.
+ assert False, "Expected RuntimeError to be raised."
More information about the llvm-branch-commits
mailing list