[Mlir-commits] [llvm] [mlir] [MLIR][Python][NFC] move Py* types (PR #155719)
Maksim Levental
llvmlistbot at llvm.org
Thu Aug 28 00:13:24 PDT 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/155719
>From cbcfaee26deb2b83a09f5f635aad116baec673cd Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 27 Aug 2025 19:29:33 -0400
Subject: [PATCH 1/3] [MLIR][Python] make Py* types public API
---
.../mlir}/Bindings/Python/Globals.h | 2 +-
.../mlir/Bindings/Python/IRAttributes.h | 470 ++++
.../mlir}/Bindings/Python/IRModule.h | 54 +-
mlir/include/mlir/Bindings/Python/IRTypes.h | 370 ++-
.../mlir/Bindings/Python/NanobindAdaptors.h | 4 +-
.../mlir}/Bindings/Python/NanobindUtils.h | 0
mlir/lib/Bindings/Python/DialectSMT.cpp | 3 +-
mlir/lib/Bindings/Python/IRAffine.cpp | 48 +-
mlir/lib/Bindings/Python/IRAttributes.cpp | 2430 ++++++++---------
mlir/lib/Bindings/Python/IRCore.cpp | 164 +-
mlir/lib/Bindings/Python/IRInterfaces.cpp | 2 +-
mlir/lib/Bindings/Python/IRModule.cpp | 10 +-
mlir/lib/Bindings/Python/IRTypes.cpp | 1332 ++++-----
mlir/lib/Bindings/Python/MainModule.cpp | 6 +-
mlir/lib/Bindings/Python/Pass.cpp | 6 +-
mlir/lib/Bindings/Python/Pass.h | 2 +-
mlir/lib/Bindings/Python/Rewrite.cpp | 8 +-
mlir/lib/Bindings/Python/Rewrite.h | 2 +-
mlir/python/CMakeLists.txt | 5 -
19 files changed, 2574 insertions(+), 2344 deletions(-)
rename mlir/{lib => include/mlir}/Bindings/Python/Globals.h (99%)
create mode 100644 mlir/include/mlir/Bindings/Python/IRAttributes.h
rename mlir/{lib => include/mlir}/Bindings/Python/IRModule.h (96%)
rename mlir/{lib => include/mlir}/Bindings/Python/NanobindUtils.h (100%)
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h
similarity index 99%
rename from mlir/lib/Bindings/Python/Globals.h
rename to mlir/include/mlir/Bindings/Python/Globals.h
index 71a051cb3d9f5..9e3b48d7b2e68 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/include/mlir/Bindings/Python/Globals.h
@@ -15,8 +15,8 @@
#include <unordered_set>
#include <vector>
-#include "NanobindUtils.h"
#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir/CAPI/Support.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringExtras.h"
diff --git a/mlir/include/mlir/Bindings/Python/IRAttributes.h b/mlir/include/mlir/Bindings/Python/IRAttributes.h
new file mode 100644
index 0000000000000..8892437ac3f95
--- /dev/null
+++ b/mlir/include/mlir/Bindings/Python/IRAttributes.h
@@ -0,0 +1,470 @@
+//===- IRAttributes.h - Python bindings for attributes -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_IRATTRIBUTES_H
+#define MLIR_BINDINGS_PYTHON_IRATTRIBUTES_H
+
+#include "mlir-c/BuiltinAttributes.h"
+#include "mlir/Bindings/Python/IRModule.h"
+
+namespace mlir::python {
+
+/// Wraps builtin affine map attributes (recognized by
+/// `mlirAttributeIsAAffineMap`); exposed to Python as `AffineMapAttr`.
+class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
+ static constexpr const char *pyClassName = "AffineMapAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirAffineMapAttrGetTypeID;
+
+ /// Hook for the class-specific Python bindings; defined out of line.
+ static void bindDerived(ClassTy &c);
+};
+
+/// Wraps builtin integer set attributes (recognized by
+/// `mlirAttributeIsAIntegerSet`); exposed to Python as `IntegerSetAttr`.
+class PyIntegerSetAttribute
+ : public PyConcreteAttribute<PyIntegerSetAttribute> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet;
+ static constexpr const char *pyClassName = "IntegerSetAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirIntegerSetAttrGetTypeID;
+
+ /// Hook for the class-specific Python bindings; defined out of line.
+ static void bindDerived(ClassTy &c);
+};
+
+/// A python-wrapped dense array attribute with an element type and a derived
+/// implementation class.
+///
+/// \tparam EltTy    C++ element type returned by `getItem` and by the
+///                  iterator (bool, integer or floating point below).
+/// \tparam DerivedT CRTP subclass. The concrete subclasses in this header
+///                  supply `isaFunction`, `getAttribute`, `getElement`,
+///                  `pyClassName` and `pyIteratorName`.
+///
+/// NOTE(review): the member functions of this class template are only
+/// declared here. Instantiating it for a new `DerivedT` outside the
+/// translation unit that defines them will not link unless the definitions
+/// are made visible (or explicitly instantiated).
+template <typename EltTy, typename DerivedT>
+class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
+public:
+ using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;
+
+ /// Iterator over the elements (of type `EltTy`) of a dense array.
+ class PyDenseArrayIterator {
+ public:
+ PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
+
+ /// Return a copy of the iterator.
+ PyDenseArrayIterator dunderIter();
+
+ /// Return the next element.
+ EltTy dunderNext();
+
+ /// Bind the iterator class.
+ static void bind(nanobind::module_ &m);
+
+ private:
+ /// The referenced dense array attribute.
+ PyAttribute attr;
+ /// The next index to read.
+ int nextIndex = 0;
+ };
+
+ /// Get the element at the given index.
+ EltTy getItem(intptr_t i);
+
+ /// Bind the attribute class.
+ static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c);
+
+private:
+ /// Builds a `DerivedT` holding `values` in context `ctx`; defined out of
+ /// line. Each concrete subclass also declares a public static
+ /// `getAttribute` (its C-API constructor), which hides this helper inside
+ /// the subclass scope.
+ static DerivedT getAttribute(const std::vector<EltTy> &values,
+ PyMlirContextRef ctx);
+};
+
+/// Instantiate the python dense array classes.
+///
+/// Each subclass binds `PyDenseArrayAttribute` to one element type and the
+/// matching C API functions: `getAttribute` and `getElement` are the C-API
+/// constructor and element accessor, and `pyIteratorName` names the Python
+/// iterator class.
+///
+/// `DenseBoolArrayAttr`: dense array attribute with `bool` elements.
+struct PyDenseBoolArrayAttribute
+ : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> {
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
+ static constexpr auto getAttribute = mlirDenseBoolArrayGet;
+ static constexpr auto getElement = mlirDenseBoolArrayGetElement;
+ static constexpr const char *pyClassName = "DenseBoolArrayAttr";
+ static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
+ using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+
+/// `DenseI8ArrayAttr`: dense array attribute with `int8_t` elements.
+struct PyDenseI8ArrayAttribute
+ : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
+ static constexpr auto getAttribute = mlirDenseI8ArrayGet;
+ static constexpr auto getElement = mlirDenseI8ArrayGetElement;
+ static constexpr const char *pyClassName = "DenseI8ArrayAttr";
+ static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
+ using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+
+/// `DenseI16ArrayAttr`: dense array attribute with `int16_t` elements.
+struct PyDenseI16ArrayAttribute
+ : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
+ static constexpr auto getAttribute = mlirDenseI16ArrayGet;
+ static constexpr auto getElement = mlirDenseI16ArrayGetElement;
+ static constexpr const char *pyClassName = "DenseI16ArrayAttr";
+ static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
+ using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+
+/// `DenseI32ArrayAttr`: dense array attribute with `int32_t` elements.
+struct PyDenseI32ArrayAttribute
+ : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
+ static constexpr auto getAttribute = mlirDenseI32ArrayGet;
+ static constexpr auto getElement = mlirDenseI32ArrayGetElement;
+ static constexpr const char *pyClassName = "DenseI32ArrayAttr";
+ static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
+ using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+
+/// `DenseI64ArrayAttr`: dense array attribute with `int64_t` elements.
+struct PyDenseI64ArrayAttribute
+ : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
+ static constexpr auto getAttribute = mlirDenseI64ArrayGet;
+ static constexpr auto getElement = mlirDenseI64ArrayGetElement;
+ static constexpr const char *pyClassName = "DenseI64ArrayAttr";
+ static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
+ using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+
+/// `DenseF32ArrayAttr`: dense array attribute with `float` elements.
+struct PyDenseF32ArrayAttribute
+ : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
+ static constexpr auto getAttribute = mlirDenseF32ArrayGet;
+ static constexpr auto getElement = mlirDenseF32ArrayGetElement;
+ static constexpr const char *pyClassName = "DenseF32ArrayAttr";
+ static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
+ using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+
+/// `DenseF64ArrayAttr`: dense array attribute with `double` elements.
+struct PyDenseF64ArrayAttribute
+ : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
+ static constexpr auto getAttribute = mlirDenseF64ArrayGet;
+ static constexpr auto getElement = mlirDenseF64ArrayGetElement;
+ static constexpr const char *pyClassName = "DenseF64ArrayAttr";
+ static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
+ using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+
+/// Wraps builtin array attributes (recognized by `mlirAttributeIsAArray`);
+/// exposed to Python as `ArrayAttr`. Elements are handed out as raw
+/// `MlirAttribute` handles.
+class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
+ static constexpr const char *pyClassName = "ArrayAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirArrayAttrGetTypeID;
+
+ /// Python iterator over the elements of an array attribute.
+ class PyArrayAttributeIterator {
+ public:
+ PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
+
+ /// Returns a reference (presumably to `*this`), unlike
+ /// `PyDenseArrayIterator::dunderIter`, which returns a copy.
+ PyArrayAttributeIterator &dunderIter();
+
+ /// Returns the next element.
+ MlirAttribute dunderNext();
+
+ /// Binds the iterator class into module `m`.
+ static void bind(nanobind::module_ &m);
+
+ private:
+ /// The array attribute being iterated.
+ PyAttribute attr;
+ /// Index of the next element to return.
+ int nextIndex = 0;
+ };
+
+ /// Returns the element at position `i`.
+ MlirAttribute getItem(intptr_t i);
+
+ /// Hook for the class-specific Python bindings; defined out of line.
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating-point attribute subclass - FloatAttr (recognized by
+/// `mlirAttributeIsAFloat`).
+class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
+ static constexpr const char *pyClassName = "FloatAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloatAttrGetTypeID;
+
+ /// Hook for the class-specific Python bindings; defined out of line.
+ static void bindDerived(ClassTy &c);
+};
+
+/// Integer Attribute subclass - IntegerAttr (recognized by
+/// `mlirAttributeIsAInteger`).
+///
+/// Unlike most siblings, this class declares no `getTypeIdFunction`.
+class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
+ static constexpr const char *pyClassName = "IntegerAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+
+ /// Hook for the class-specific Python bindings; defined out of line.
+ static void bindDerived(ClassTy &c);
+
+private:
+ /// Presumably converts `self`'s value to an `int64_t` for the Python
+ /// integer protocol. Confirm against the out-of-line definition.
+ static int64_t toPyInt(PyIntegerAttribute &self);
+};
+
+/// Bool Attribute subclass - BoolAttr (recognized by `mlirAttributeIsABool`).
+///
+/// Declares no `getTypeIdFunction`.
+class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
+ static constexpr const char *pyClassName = "BoolAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+
+ /// Hook for the class-specific Python bindings; defined out of line.
+ static void bindDerived(ClassTy &c);
+};
+
+/// Wraps builtin symbol reference attributes (recognized by
+/// `mlirAttributeIsASymbolRef`); exposed to Python as `SymbolRefAttr`.
+class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
+ static constexpr const char *pyClassName = "SymbolRefAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+
+ /// Builds a symbol reference attribute in `context` from the names in
+ /// `symbols`. Presumably the first name is the root reference and the rest
+ /// are nested references; confirm against the out-of-line definition.
+ static MlirAttribute fromList(const std::vector<std::string> &symbols,
+ PyMlirContext &context);
+
+ /// Hook for the class-specific Python bindings; defined out of line.
+ static void bindDerived(ClassTy &c);
+};
+
+/// Wraps flat (non-nested) symbol reference attributes (recognized by
+/// `mlirAttributeIsAFlatSymbolRef`); exposed to Python as
+/// `FlatSymbolRefAttr`.
+class PyFlatSymbolRefAttribute
+ : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
+ static constexpr const char *pyClassName = "FlatSymbolRefAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+
+ /// Hook for the class-specific Python bindings; defined out of line.
+ static void bindDerived(ClassTy &c);
+};
+
+/// Wraps builtin opaque attributes (recognized by `mlirAttributeIsAOpaque`);
+/// exposed to Python as `OpaqueAttr`.
+class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
+ static constexpr const char *pyClassName = "OpaqueAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirOpaqueAttrGetTypeID;
+
+ /// Hook for the class-specific Python bindings; defined out of line.
+ static void bindDerived(ClassTy &c);
+};
+
+/// Wraps builtin string attributes (recognized by `mlirAttributeIsAString`);
+/// exposed to Python as `StringAttr`.
+class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
+ static constexpr const char *pyClassName = "StringAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirStringAttrGetTypeID;
+
+ /// Hook for the class-specific Python bindings; defined out of line.
+ static void bindDerived(ClassTy &c);
+};
+
+/// Description of a buffer obtained through the Python buffer protocol,
+/// mirroring the fields of `Py_buffer`. Move-only. When constructed from a
+/// `Py_buffer *`, it owns that view and calls `PyBuffer_Release` on it when
+/// destroyed.
+struct nb_buffer_info {
+ void *ptr = nullptr;
+ ssize_t itemsize = 0;
+ /// Number of elements: the product of `shape`, recomputed by the
+ /// constructor (1 when `ndim` is 0).
+ ssize_t size = 0;
+ const char *format = nullptr;
+ ssize_t ndim = 0;
+ SmallVector<ssize_t, 4> shape;
+ SmallVector<ssize_t, 4> strides;
+ bool readonly = false;
+
+ nb_buffer_info(
+ void *ptr, ssize_t itemsize, const char *format, ssize_t ndim,
+ SmallVector<ssize_t, 4> shape_in, SmallVector<ssize_t, 4> strides_in,
+ bool readonly = false,
+ std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view_in =
+ std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(nullptr, nullptr))
+ : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim),
+ shape(std::move(shape_in)), strides(std::move(strides_in)),
+ readonly(readonly), owned_view(std::move(owned_view_in)) {
+ size = 1;
+ for (ssize_t i = 0; i < ndim; ++i) {
+ size *= shape[i];
+ }
+ }
+
+ /// Takes ownership of `view`, which is released with `PyBuffer_Release`.
+ /// NOTE(review): `PyBuffer_Release` releases the exporter's buffer but does
+ /// not free the `Py_buffer` struct itself. If `view` was heap-allocated,
+ /// its storage must be freed elsewhere; confirm at the call site.
+ explicit nb_buffer_info(Py_buffer *view)
+ : nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim,
+ {view->shape, view->shape + view->ndim},
+ // TODO(phawkins): check for null strides
+ // Per the buffer protocol, `strides` may be NULL when the
+ // consumer did not request PyBUF_STRIDES; this range
+ // would then be invalid.
+ {view->strides, view->strides + view->ndim},
+ view->readonly != 0,
+ std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(
+ view, PyBuffer_Release)) {}
+
+ nb_buffer_info(const nb_buffer_info &) = delete;
+ nb_buffer_info(nb_buffer_info &&) = default;
+ nb_buffer_info &operator=(const nb_buffer_info &) = delete;
+ nb_buffer_info &operator=(nb_buffer_info &&) = default;
+
+private:
+ /// Owned buffer view, or null when built from raw fields without one.
+ std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view;
+};
+
+/// nanobind object type for Python objects that support the buffer protocol
+/// (checked with `PyObject_CheckBuffer`); named `buffer` in signatures.
+class nb_buffer : public nanobind::object {
+ NB_OBJECT_DEFAULT(nb_buffer, object, "buffer", PyObject_CheckBuffer);
+
+ /// Requests a buffer view from the wrapped object; defined out of line.
+ nb_buffer_info request() const;
+};
+
+// TODO: Support construction of string elements.
+/// Wraps builtin dense elements attributes (recognized by
+/// `mlirAttributeIsADenseElements`); exposed to Python as
+/// `DenseElementsAttr`.
+class PyDenseElementsAttribute
+ : public PyConcreteAttribute<PyDenseElementsAttribute> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
+ static constexpr const char *pyClassName = "DenseElementsAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+
+ /// Builds an attribute from a Python list of attributes, optionally with
+ /// an explicit type.
+ static PyDenseElementsAttribute
+ getFromList(const nanobind::list &attributes,
+ std::optional<PyType> explicitType,
+ DefaultingPyMlirContext contextWrapper);
+
+ /// Builds an attribute from an object exposing the buffer protocol.
+ /// `explicitType`/`explicitShape` presumably override what would otherwise
+ /// be inferred from the buffer, and `signless` presumably selects signless
+ /// integer element types; confirm against the out-of-line definition.
+ static PyDenseElementsAttribute
+ getFromBuffer(const nb_buffer &array, bool signless,
+ const std::optional<PyType> &explicitType,
+ std::optional<std::vector<int64_t>> explicitShape,
+ DefaultingPyMlirContext contextWrapper);
+
+ /// Builds a splat attribute of `shapedType` holding `elementAttr`.
+ static PyDenseElementsAttribute getSplat(const PyType &shapedType,
+ PyAttribute &elementAttr);
+
+ intptr_t dunderLen();
+
+ /// Describes the attribute's data as a buffer (see `nb_buffer_info`).
+ std::unique_ptr<nb_buffer_info> accessBuffer();
+
+ static void bindDerived(ClassTy &c);
+
+ /// Python type slots for this class; defined out of line. Presumably these
+ /// install `bf_getbuffer`/`bf_releasebuffer` so the class supports the
+ /// buffer protocol; confirm in the .cpp.
+ static PyType_Slot slots[];
+
+private:
+ static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags);
+ static void bf_releasebuffer(PyObject *, Py_buffer *buffer);
+
+ /// Classify a buffer-protocol format string (e.g. `Py_buffer::format`).
+ static bool isUnsignedIntegerFormat(std::string_view format);
+
+ static bool isSignedIntegerFormat(std::string_view format);
+
+ static MlirType
+ getShapedType(std::optional<MlirType> bulkLoadElementType,
+ std::optional<std::vector<int64_t>> explicitShape,
+ Py_buffer &view);
+
+ static MlirAttribute getAttributeFromBuffer(
+ Py_buffer &view, bool signless, std::optional<PyType> explicitType,
+ const std::optional<std::vector<int64_t>> &explicitShape,
+ MlirContext &context);
+
+ // There is a complication for boolean numpy arrays, as numpy represents
+ // them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8
+ // booleans per byte.
+ static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
+ Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
+ MlirContext &context);
+
+ // This does the opposite transformation of
+ // `getBitpackedAttributeFromBooleanBuffer`
+ std::unique_ptr<nb_buffer_info> getBooleanBufferFromBitpackedAttribute();
+
+ /// Describes the data as a buffer of C++ elements of type `Type`.
+ /// `explicitFormat` presumably overrides the default format string. This
+ /// member template is only declared here, so its definition must be
+ /// visible wherever it is instantiated.
+ template <typename Type>
+ std::unique_ptr<nb_buffer_info>
+ bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr);
+};
+
+/// Wraps dense resource elements attributes (recognized by
+/// `mlirAttributeIsADenseResourceElements`); exposed to Python as
+/// `DenseResourceElementsAttr`.
+class PyDenseResourceElementsAttribute
+ : public PyConcreteAttribute<PyDenseResourceElementsAttribute> {
+public:
+ static constexpr IsAFunctionTy isaFunction =
+ mlirAttributeIsADenseResourceElements;
+ static constexpr const char *pyClassName = "DenseResourceElementsAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+
+ /// Builds an attribute of type `type` backed by the data in `buffer`,
+ /// registered under resource name `name`. `alignment` and `isMutable`
+ /// presumably describe the resulting blob; confirm against the
+ /// out-of-line definition.
+ static PyDenseResourceElementsAttribute
+ getFromBuffer(const nb_buffer &buffer, const std::string &name,
+ const PyType &type, std::optional<size_t> alignment,
+ bool isMutable, DefaultingPyMlirContext contextWrapper);
+
+ /// Hook for the class-specific Python bindings; defined out of line.
+ static void bindDerived(ClassTy &c);
+};
+
+/// Wraps builtin dictionary attributes (recognized by
+/// `mlirAttributeIsADictionary`); exposed to Python as `DictAttr`.
+class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
+ static constexpr const char *pyClassName = "DictAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirDictionaryAttrGetTypeID;
+
+ /// Number of entries (presumably backs Python `len()`).
+ intptr_t dunderLen();
+
+ /// Whether an entry named `name` exists (presumably backs `in`).
+ bool dunderContains(const std::string &name);
+
+ /// Hook for the class-specific Python bindings; defined out of line.
+ static void bindDerived(ClassTy &c);
+};
+
+/// Refinement of the PyDenseElementsAttribute for attributes containing
+/// integer (and boolean) values. Supports element access.
+/// Exposed to Python as `DenseIntElementsAttr`; `PyDenseElementsAttribute`
+/// is passed as the second `PyConcreteAttribute` argument.
+class PyDenseIntElementsAttribute
+ : public PyConcreteAttribute<PyDenseIntElementsAttribute,
+ PyDenseElementsAttribute> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
+ static constexpr const char *pyClassName = "DenseIntElementsAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+
+ /// Returns the element at the given linear position. Asserts if the index
+ /// is out of range. Returned as a generic `nanobind::object`, presumably
+ /// because the element may be a Python int or bool.
+ nanobind::object dunderGetItem(intptr_t pos);
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Refinement of PyDenseElementsAttribute for attributes containing
+/// floating-point values. Supports element access.
+/// Exposed to Python as `DenseFPElementsAttr`; `PyDenseElementsAttribute`
+/// is passed as the second `PyConcreteAttribute` argument.
+class PyDenseFPElementsAttribute
+ : public PyConcreteAttribute<PyDenseFPElementsAttribute,
+ PyDenseElementsAttribute> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
+ static constexpr const char *pyClassName = "DenseFPElementsAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+
+ /// Returns the element at linear position `pos` as a Python float.
+ nanobind::float_ dunderGetItem(intptr_t pos);
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Wraps builtin type attributes (recognized by `mlirAttributeIsAType`);
+/// exposed to Python as `TypeAttr`.
+class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
+ static constexpr const char *pyClassName = "TypeAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirTypeAttrGetTypeID;
+
+ /// Hook for the class-specific Python bindings; defined out of line.
+ static void bindDerived(ClassTy &c);
+};
+
+/// Unit Attribute subclass. Unit attributes don't have values.
+/// Recognized by `mlirAttributeIsAUnit`; exposed to Python as `UnitAttr`.
+class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
+ static constexpr const char *pyClassName = "UnitAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirUnitAttrGetTypeID;
+
+ /// Hook for the class-specific Python bindings; defined out of line.
+ static void bindDerived(ClassTy &c);
+};
+
+/// Strided layout attribute subclass (recognized by
+/// `mlirAttributeIsAStridedLayout`); exposed to Python as
+/// `StridedLayoutAttr`.
+class PyStridedLayoutAttribute
+ : public PyConcreteAttribute<PyStridedLayoutAttribute> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
+ static constexpr const char *pyClassName = "StridedLayoutAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirStridedLayoutAttrGetTypeID;
+
+ /// Hook for the class-specific Python bindings; defined out of line.
+ static void bindDerived(ClassTy &c);
+};
+} // namespace mlir::python
+
+#endif // MLIR_BINDINGS_PYTHON_IRATTRIBUTES_H
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/include/mlir/Bindings/Python/IRModule.h
similarity index 96%
rename from mlir/lib/Bindings/Python/IRModule.h
rename to mlir/include/mlir/Bindings/Python/IRModule.h
index 6617b41cc916c..57ea36d74cf90 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/include/mlir/Bindings/Python/IRModule.h
@@ -15,16 +15,16 @@
#include <utility>
#include <vector>
-#include "Globals.h"
-#include "NanobindUtils.h"
#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Transforms.h"
+#include "mlir/Bindings/Python/Globals.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/ThreadPool.h"
@@ -1320,6 +1320,56 @@ class PySymbolTable {
MlirSymbolTable symbolTable;
};
+/// CRTP base class for Python MLIR values that subclass Value and should be
+/// castable from it. The value hierarchy is one level deep and is not supposed
+/// to accommodate other levels unless core MLIR changes.
+///
+/// NOTE(review): `castFrom`, `bind` and `bindDerived` are only declared in
+/// this header. Code outside the defining translation unit that instantiates
+/// a new `PyConcreteValue<T>` needs their definitions to be visible (or
+/// explicitly instantiated) to link.
+template <typename DerivedTy>
+class PyConcreteValue : public PyValue {
+public:
+ // Derived classes must define statics for:
+ // IsAFunctionTy isaFunction
+ // const char *pyClassName
+ // and redefine bindDerived.
+ using ClassTy = nanobind::class_<DerivedTy, PyValue>;
+ using IsAFunctionTy = bool (*)(MlirValue);
+
+ PyConcreteValue() = default;
+ PyConcreteValue(PyOperationRef operationRef, MlirValue value)
+ : PyValue(operationRef, value) {}
+ /// Converting constructor: keeps `orig`'s parent operation and checks the
+ /// value with `castFrom`, which throws on a mismatch. Not `explicit`, so a
+ /// `PyValue &` converts to the derived wrapper implicitly.
+ PyConcreteValue(PyValue &orig)
+ : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
+
+ /// Attempts to cast the original value to the derived type and throws on
+ /// type mismatches.
+ static MlirValue castFrom(PyValue &orig);
+
+ /// Binds the Python module objects to functions of this class.
+ static void bind(nanobind::module_ &m);
+
+ /// Implemented by derived classes to add methods to the Python subclass.
+ static void bindDerived(ClassTy &m);
+};
+
+/// Python wrapper for MlirOpResult.
+/// Values are recognized with `mlirValueIsAOpResult`; the Python class is
+/// `OpResult`.
+class PyOpResult : public PyConcreteValue<PyOpResult> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
+ static constexpr const char *pyClassName = "OpResult";
+ using PyConcreteValue::PyConcreteValue;
+
+ /// Adds the `OpResult`-specific Python API; defined out of line.
+ static void bindDerived(ClassTy &c);
+};
+
+/// Python wrapper for MlirBlockArgument.
+/// Values are recognized with `mlirValueIsABlockArgument`; the Python class
+/// is `BlockArgument`.
+class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
+ static constexpr const char *pyClassName = "BlockArgument";
+ using PyConcreteValue::PyConcreteValue;
+
+ /// Adds the `BlockArgument`-specific Python API; defined out of line.
+ static void bindDerived(ClassTy &c);
+};
+
/// Custom exception that allows access to error diagnostic information. This is
/// converted to the `ir.MLIRError` python exception when thrown.
struct MLIRError {
diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h
index ba9642cf2c6a2..60d21fd2f2fa0 100644
--- a/mlir/include/mlir/Bindings/Python/IRTypes.h
+++ b/mlir/include/mlir/Bindings/Python/IRTypes.h
@@ -9,12 +9,13 @@
#ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H
#define MLIR_BINDINGS_PYTHON_IRTYPES_H
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir-c/BuiltinTypes.h"
+#include "mlir/Bindings/Python/IRModule.h"
-namespace mlir {
+namespace mlir::python {
/// Shaped Type Interface - ShapedType
-class PyShapedType : public python::PyConcreteType<PyShapedType> {
+class PyShapedType : public PyConcreteType<PyShapedType> {
public:
static const IsAFunctionTy isaFunction;
static constexpr const char *pyClassName = "ShapedType";
@@ -26,6 +27,367 @@ class PyShapedType : public python::PyConcreteType<PyShapedType> {
void requireHasRank();
};
-} // namespace mlir
+/// Integer Type subclass - IntegerType (recognized by `mlirTypeIsAInteger`).
+class PyIntegerType : public PyConcreteType<PyIntegerType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirIntegerTypeGetTypeID;
+ static constexpr const char *pyClassName = "IntegerType";
+ using PyConcreteType::PyConcreteType;
+
+ /// Hook for the class-specific Python bindings; defined out of line.
+ static void bindDerived(ClassTy &c);
+};
+
+/// Index Type subclass - IndexType (recognized by `mlirTypeIsAIndex`).
+class PyIndexType : public PyConcreteType<PyIndexType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirIndexTypeGetTypeID;
+ static constexpr const char *pyClassName = "IndexType";
+ using PyConcreteType::PyConcreteType;
+
+ /// Hook for the class-specific Python bindings; defined out of line.
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type base - FloatType (recognized by `mlirTypeIsAFloat`).
+/// The specific floating-point wrappers in this header pass `PyFloatType` as
+/// the second `PyConcreteType` argument. It declares no `getTypeIdFunction`.
+class PyFloatType : public PyConcreteType<PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
+ static constexpr const char *pyClassName = "FloatType";
+ using PyConcreteType::PyConcreteType;
+
+ /// Hook for the class-specific Python bindings; defined out of line.
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float4E2M1FNType.
+/// Python class `Float4E2M1FNType`; base wrapper `PyFloatType`.
+class PyFloat4E2M1FNType
+ : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat4E2M1FNTypeGetTypeID;
+ static constexpr const char *pyClassName = "Float4E2M1FNType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float6E2M3FNType.
+/// Python class `Float6E2M3FNType`; base wrapper `PyFloatType`.
+class PyFloat6E2M3FNType
+ : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat6E2M3FNTypeGetTypeID;
+ static constexpr const char *pyClassName = "Float6E2M3FNType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float6E3M2FNType.
+/// Python class `Float6E3M2FNType`; base wrapper `PyFloatType`.
+class PyFloat6E3M2FNType
+ : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat6E3M2FNTypeGetTypeID;
+ static constexpr const char *pyClassName = "Float6E3M2FNType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E4M3FNType.
+/// Python class `Float8E4M3FNType`; base wrapper `PyFloatType`.
+class PyFloat8E4M3FNType
+ : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat8E4M3FNTypeGetTypeID;
+ static constexpr const char *pyClassName = "Float8E4M3FNType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E5M2Type.
+/// Python class `Float8E5M2Type`; base wrapper `PyFloatType`.
+class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat8E5M2TypeGetTypeID;
+ static constexpr const char *pyClassName = "Float8E5M2Type";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E4M3Type.
+/// Python class `Float8E4M3Type`; base wrapper `PyFloatType`.
+class PyFloat8E4M3Type : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat8E4M3TypeGetTypeID;
+ static constexpr const char *pyClassName = "Float8E4M3Type";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E4M3FNUZType.
+/// Python class `Float8E4M3FNUZType`; base wrapper `PyFloatType`.
+class PyFloat8E4M3FNUZType
+ : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat8E4M3FNUZTypeGetTypeID;
+ static constexpr const char *pyClassName = "Float8E4M3FNUZType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E4M3B11FNUZType.
+/// Python class `Float8E4M3B11FNUZType`; base wrapper `PyFloatType`.
+class PyFloat8E4M3B11FNUZType
+ : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat8E4M3B11FNUZTypeGetTypeID;
+ static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E5M2FNUZType.
+/// Python class `Float8E5M2FNUZType`; base wrapper `PyFloatType`.
+class PyFloat8E5M2FNUZType
+ : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat8E5M2FNUZTypeGetTypeID;
+ static constexpr const char *pyClassName = "Float8E5M2FNUZType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E3M4Type.
+/// Python class `Float8E3M4Type`; base wrapper `PyFloatType`.
+class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat8E3M4TypeGetTypeID;
+ static constexpr const char *pyClassName = "Float8E3M4Type";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E8M0FNUType.
+/// Python class `Float8E8M0FNUType`; base wrapper `PyFloatType`.
+class PyFloat8E8M0FNUType
+ : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat8E8M0FNUTypeGetTypeID;
+ static constexpr const char *pyClassName = "Float8E8M0FNUType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - BF16Type.
+/// Python class `BF16Type`; its TypeID comes from `mlirBFloat16TypeGetTypeID`.
+class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirBFloat16TypeGetTypeID;
+ static constexpr const char *pyClassName = "BF16Type";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - F16Type.
+/// Python class `F16Type`; base wrapper `PyFloatType`.
+class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat16TypeGetTypeID;
+ static constexpr const char *pyClassName = "F16Type";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - TF32Type.
+/// Note: the Python class name is `FloatTF32Type`, not `TF32Type`.
+class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloatTF32TypeGetTypeID;
+ static constexpr const char *pyClassName = "FloatTF32Type";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - F32Type.
+/// Python class `F32Type`; base wrapper `PyFloatType`.
+class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat32TypeGetTypeID;
+ static constexpr const char *pyClassName = "F32Type";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - F64Type.
+/// Python class `F64Type`; base wrapper `PyFloatType`.
+class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat64TypeGetTypeID;
+ static constexpr const char *pyClassName = "F64Type";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// None Type subclass - NoneType (recognized by `mlirTypeIsANone`).
+class PyNoneType : public PyConcreteType<PyNoneType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirNoneTypeGetTypeID;
+ static constexpr const char *pyClassName = "NoneType";
+ using PyConcreteType::PyConcreteType;
+
+ /// Hook for the class-specific Python bindings; defined out of line.
+ static void bindDerived(ClassTy &c);
+};
+
+/// Type subclass exposed to Python as `ComplexType`.
+class PyComplexType : public PyConcreteType<PyComplexType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirComplexTypeGetTypeID;
+ static constexpr const char *pyClassName = "ComplexType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Shaped type subclass exposed to Python as `VectorType` (may be scalable).
+class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirVectorTypeGetTypeID;
+ static constexpr const char *pyClassName = "VectorType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+
+private:
+ static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
+ std::optional<nanobind::list> scalable,
+ std::optional<std::vector<int64_t>> scalableDims,
+ DefaultingPyLocation loc);
+};
+
+/// Shaped type subclass exposed to Python as `RankedTensorType`.
+class PyRankedTensorType
+ : public PyConcreteType<PyRankedTensorType, PyShapedType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirRankedTensorTypeGetTypeID;
+ static constexpr const char *pyClassName = "RankedTensorType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Shaped type subclass exposed to Python as `UnrankedTensorType`.
+class PyUnrankedTensorType
+ : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirUnrankedTensorTypeGetTypeID;
+ static constexpr const char *pyClassName = "UnrankedTensorType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Ranked memref; shaped type subclass exposed to Python as `MemRefType`.
+class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirMemRefTypeGetTypeID;
+ static constexpr const char *pyClassName = "MemRefType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Shaped type subclass exposed to Python as `UnrankedMemRefType`.
+class PyUnrankedMemRefType
+ : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirUnrankedMemRefTypeGetTypeID;
+ static constexpr const char *pyClassName = "UnrankedMemRefType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Type subclass exposed to Python as `TupleType`.
+class PyTupleType : public PyConcreteType<PyTupleType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirTupleTypeGetTypeID;
+ static constexpr const char *pyClassName = "TupleType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Type subclass exposed to Python as `FunctionType`.
+class PyFunctionType : public PyConcreteType<PyFunctionType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFunctionTypeGetTypeID;
+ static constexpr const char *pyClassName = "FunctionType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Type subclass exposed to Python as `OpaqueType`.
+class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirOpaqueTypeGetTypeID;
+ static constexpr const char *pyClassName = "OpaqueType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+} // namespace mlir::python
#endif // MLIR_BINDINGS_PYTHON_IRTYPES_H
diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
index 1428d5ccf00f4..35cc52af3334f 100644
--- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
@@ -24,11 +24,11 @@
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
+#include "llvm/ADT/Twine.h"
// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+#include "mlir-c/Bindings/Python/Interop.h" // On Windows, this is expected after nanobind.
// clang-format on
-#include "llvm/ADT/Twine.h"
// Raw CAPI type casters need to be declared before use, so always include them
// first.
diff --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/include/mlir/Bindings/Python/NanobindUtils.h
similarity index 100%
rename from mlir/lib/Bindings/Python/NanobindUtils.h
rename to mlir/include/mlir/Bindings/Python/NanobindUtils.h
diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp
index cab4219fea72b..e0a2809d4a92c 100644
--- a/mlir/lib/Bindings/Python/DialectSMT.cpp
+++ b/mlir/lib/Bindings/Python/DialectSMT.cpp
@@ -6,8 +6,6 @@
//
//===----------------------------------------------------------------------===//
-#include "NanobindUtils.h"
-
#include "mlir-c/Dialect/SMT.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
@@ -15,6 +13,7 @@
#include "mlir/Bindings/Python/Diagnostics.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
namespace nb = nanobind;
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index a6499c952df6e..7b7bec4df3d00 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -13,19 +13,22 @@
#include <utility>
#include <vector>
-#include "IRModule.h"
-#include "NanobindUtils.h"
#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
-#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
#include "mlir-c/IntegerSet.h"
-#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/IRModule.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
+// clang-format off
+#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
+#include "mlir-c/Bindings/Python/Interop.h" // On Windows, this is expected after nanobind.
+// clang-format on
+
namespace nb = nanobind;
using namespace mlir;
using namespace mlir::python;
@@ -707,25 +710,24 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
[](PyAffineMap &self) {
return static_cast<size_t>(llvm::hash_value(self.get().ptr));
})
- .def_static("compress_unused_symbols",
- [](const nb::list &affineMaps,
- DefaultingPyMlirContext context) {
- SmallVector<MlirAffineMap> maps;
- pyListToVector<PyAffineMap, MlirAffineMap>(
- affineMaps, maps, "attempting to create an AffineMap");
- std::vector<MlirAffineMap> compressed(affineMaps.size());
- auto populate = [](void *result, intptr_t idx,
- MlirAffineMap m) {
- static_cast<MlirAffineMap *>(result)[idx] = (m);
- };
- mlirAffineMapCompressUnusedSymbols(
- maps.data(), maps.size(), compressed.data(), populate);
- std::vector<PyAffineMap> res;
- res.reserve(compressed.size());
- for (auto m : compressed)
- res.emplace_back(context->getRef(), m);
- return res;
- })
+ .def_static(
+ "compress_unused_symbols",
+ [](const nb::list &affineMaps, DefaultingPyMlirContext context) {
+ SmallVector<MlirAffineMap> maps;
+ pyListToVector<PyAffineMap, MlirAffineMap>(
+ affineMaps, maps, "attempting to create an AffineMap");
+ std::vector<MlirAffineMap> compressed(affineMaps.size());
+ auto populate = [](void *result, intptr_t idx, MlirAffineMap m) {
+ static_cast<MlirAffineMap *>(result)[idx] = (m);
+ };
+ mlirAffineMapCompressUnusedSymbols(maps.data(), maps.size(),
+ compressed.data(), populate);
+ std::vector<PyAffineMap> res;
+ res.reserve(compressed.size());
+ for (auto m : compressed)
+ res.emplace_back(context->getRef(), m);
+ return res;
+ })
.def_prop_ro(
"context",
[](PyAffineMap &self) { return self.getContext().getObject(); },
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index af950ce8114fb..648380ce7b6b6 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -12,12 +12,12 @@
#include <string_view>
#include <utility>
-#include "IRModule.h"
-#include "NanobindUtils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
+#include "mlir/Bindings/Python/IRAttributes.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/raw_ostream.h"
@@ -122,64 +122,6 @@ subsequent processing.
)";
namespace {
-
-struct nb_buffer_info {
- void *ptr = nullptr;
- ssize_t itemsize = 0;
- ssize_t size = 0;
- const char *format = nullptr;
- ssize_t ndim = 0;
- SmallVector<ssize_t, 4> shape;
- SmallVector<ssize_t, 4> strides;
- bool readonly = false;
-
- nb_buffer_info(
- void *ptr, ssize_t itemsize, const char *format, ssize_t ndim,
- SmallVector<ssize_t, 4> shape_in, SmallVector<ssize_t, 4> strides_in,
- bool readonly = false,
- std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view_in =
- std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(nullptr, nullptr))
- : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim),
- shape(std::move(shape_in)), strides(std::move(strides_in)),
- readonly(readonly), owned_view(std::move(owned_view_in)) {
- size = 1;
- for (ssize_t i = 0; i < ndim; ++i) {
- size *= shape[i];
- }
- }
-
- explicit nb_buffer_info(Py_buffer *view)
- : nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim,
- {view->shape, view->shape + view->ndim},
- // TODO(phawkins): check for null strides
- {view->strides, view->strides + view->ndim},
- view->readonly != 0,
- std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(
- view, PyBuffer_Release)) {}
-
- nb_buffer_info(const nb_buffer_info &) = delete;
- nb_buffer_info(nb_buffer_info &&) = default;
- nb_buffer_info &operator=(const nb_buffer_info &) = delete;
- nb_buffer_info &operator=(nb_buffer_info &&) = default;
-
-private:
- std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view;
-};
-
-class nb_buffer : public nb::object {
- NB_OBJECT_DEFAULT(nb_buffer, object, "buffer", PyObject_CheckBuffer);
-
- nb_buffer_info request() const {
- int flags = PyBUF_STRIDES | PyBUF_FORMAT;
- auto *view = new Py_buffer();
- if (PyObject_GetBuffer(ptr(), view, flags) != 0) {
- delete view;
- throw nb::python_error();
- }
- return nb_buffer_info(view);
- }
-};
-
template <typename T>
struct nb_format_descriptor {};
@@ -236,47 +178,6 @@ static MlirStringRef toMlirStringRef(const nb::bytes &s) {
return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
}
-class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
- static constexpr const char *pyClassName = "AffineMapAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirAffineMapAttrGetTypeID;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyAffineMap &affineMap) {
- MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
- return PyAffineMapAttribute(affineMap.getContext(), attr);
- },
- nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
- c.def_prop_ro("value", mlirAffineMapAttrGetValue,
- "Returns the value of the AffineMap attribute");
- }
-};
-
-class PyIntegerSetAttribute
- : public PyConcreteAttribute<PyIntegerSetAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet;
- static constexpr const char *pyClassName = "IntegerSetAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirIntegerSetAttrGetTypeID;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyIntegerSet &integerSet) {
- MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get());
- return PyIntegerSetAttribute(integerSet.getContext(), attr);
- },
- nb::arg("integer_set"), "Gets an attribute wrapping an IntegerSet.");
- }
-};
-
template <typename T>
static T pyTryCast(nb::handle object) {
try {
@@ -294,1012 +195,879 @@ static T pyTryCast(nb::handle object) {
}
}
-/// A python-wrapped dense array attribute with an element type and a derived
-/// implementation class.
-template <typename EltTy, typename DerivedT>
-class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
-public:
- using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;
-
- /// Iterator over the integer elements of a dense array.
- class PyDenseArrayIterator {
- public:
- PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
-
- /// Return a copy of the iterator.
- PyDenseArrayIterator dunderIter() { return *this; }
-
- /// Return the next element.
- EltTy dunderNext() {
- // Throw if the index has reached the end.
- if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
- throw nb::stop_iteration();
- return DerivedT::getElement(attr.get(), nextIndex++);
- }
-
- /// Bind the iterator class.
- static void bind(nb::module_ &m) {
- nb::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName)
- .def("__iter__", &PyDenseArrayIterator::dunderIter)
- .def("__next__", &PyDenseArrayIterator::dunderNext);
- }
-
- private:
- /// The referenced dense array attribute.
- PyAttribute attr;
- /// The next index to read.
- int nextIndex = 0;
- };
-
- /// Get the element at the given index.
- EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
-
- /// Bind the attribute class.
- static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
- // Bind the constructor.
- if constexpr (std::is_same_v<EltTy, bool>) {
- c.def_static(
- "get",
- [](const nb::sequence &py_values, DefaultingPyMlirContext ctx) {
- std::vector<bool> values;
- for (nb::handle py_value : py_values) {
- int is_true = PyObject_IsTrue(py_value.ptr());
- if (is_true < 0) {
- throw nb::python_error();
- }
- values.push_back(is_true);
- }
- return getAttribute(values, ctx->getRef());
- },
- nb::arg("values"), nb::arg("context").none() = nb::none(),
- "Gets a uniqued dense array attribute");
- } else {
- c.def_static(
- "get",
- [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
- return getAttribute(values, ctx->getRef());
- },
- nb::arg("values"), nb::arg("context").none() = nb::none(),
- "Gets a uniqued dense array attribute");
- }
- // Bind the array methods.
- c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
- if (i >= mlirDenseArrayGetNumElements(arr))
- throw nb::index_error("DenseArray index out of range");
- return arr.getItem(i);
- });
- c.def("__len__", [](const DerivedT &arr) {
- return mlirDenseArrayGetNumElements(arr);
- });
- c.def("__iter__",
- [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
- c.def("__add__", [](DerivedT &arr, const nb::list &extras) {
- std::vector<EltTy> values;
- intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
- values.reserve(numOldElements + nb::len(extras));
- for (intptr_t i = 0; i < numOldElements; ++i)
- values.push_back(arr.getItem(i));
- for (nb::handle attr : extras)
- values.push_back(pyTryCast<EltTy>(attr));
- return getAttribute(values, arr.getContext());
- });
- }
+} // namespace
-private:
- static DerivedT getAttribute(const std::vector<EltTy> &values,
- PyMlirContextRef ctx) {
- if constexpr (std::is_same_v<EltTy, bool>) {
- std::vector<int> intValues(values.begin(), values.end());
- MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(),
- intValues.data());
- return DerivedT(ctx, attr);
- } else {
- MlirAttribute attr =
- DerivedT::getAttribute(ctx->get(), values.size(), values.data());
- return DerivedT(ctx, attr);
- }
- }
-};
+namespace mlir::python {
+void PyAffineMapAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](PyAffineMap &affineMap) { // attr lives in the map's context
+        MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
+        return PyAffineMapAttribute(affineMap.getContext(), attr);
+      },
+      nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
+  c.def_prop_ro("value", mlirAffineMapAttrGetValue,
+                "Returns the value of the AffineMap attribute");
+}
-/// Instantiate the python dense array classes.
-struct PyDenseBoolArrayAttribute
- : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
- static constexpr auto getAttribute = mlirDenseBoolArrayGet;
- static constexpr auto getElement = mlirDenseBoolArrayGetElement;
- static constexpr const char *pyClassName = "DenseBoolArrayAttr";
- static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseI8ArrayAttribute
- : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
- static constexpr auto getAttribute = mlirDenseI8ArrayGet;
- static constexpr auto getElement = mlirDenseI8ArrayGetElement;
- static constexpr const char *pyClassName = "DenseI8ArrayAttr";
- static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseI16ArrayAttribute
- : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
- static constexpr auto getAttribute = mlirDenseI16ArrayGet;
- static constexpr auto getElement = mlirDenseI16ArrayGetElement;
- static constexpr const char *pyClassName = "DenseI16ArrayAttr";
- static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseI32ArrayAttribute
- : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
- static constexpr auto getAttribute = mlirDenseI32ArrayGet;
- static constexpr auto getElement = mlirDenseI32ArrayGetElement;
- static constexpr const char *pyClassName = "DenseI32ArrayAttr";
- static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseI64ArrayAttribute
- : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
- static constexpr auto getAttribute = mlirDenseI64ArrayGet;
- static constexpr auto getElement = mlirDenseI64ArrayGetElement;
- static constexpr const char *pyClassName = "DenseI64ArrayAttr";
- static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseF32ArrayAttribute
- : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
- static constexpr auto getAttribute = mlirDenseF32ArrayGet;
- static constexpr auto getElement = mlirDenseF32ArrayGetElement;
- static constexpr const char *pyClassName = "DenseF32ArrayAttr";
- static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseF64ArrayAttribute
- : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
- static constexpr auto getAttribute = mlirDenseF64ArrayGet;
- static constexpr auto getElement = mlirDenseF64ArrayGetElement;
- static constexpr const char *pyClassName = "DenseF64ArrayAttr";
- static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
+void PyIntegerSetAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](PyIntegerSet &integerSet) { // attr lives in the set's context
+        MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get());
+        return PyIntegerSetAttribute(integerSet.getContext(), attr);
+      },
+      nb::arg("integer_set"), "Gets an attribute wrapping an IntegerSet.");
+}
-class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
- static constexpr const char *pyClassName = "ArrayAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirArrayAttrGetTypeID;
-
- class PyArrayAttributeIterator {
- public:
- PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
-
- PyArrayAttributeIterator &dunderIter() { return *this; }
-
- MlirAttribute dunderNext() {
- // TODO: Throw is an inefficient way to stop iteration.
- if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
- throw nb::stop_iteration();
- return mlirArrayAttrGetElement(attr.get(), nextIndex++);
- }
+template <typename EltTy, typename DerivedT>
+typename PyDenseArrayAttribute<EltTy, DerivedT>::PyDenseArrayIterator
+PyDenseArrayAttribute<EltTy, DerivedT>::PyDenseArrayIterator::dunderIter() {
+  return *this; // by value: a copy at the current position
+}
-    static void bind(nb::module_ &m) {
-      nb::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
-          .def("__iter__", &PyArrayAttributeIterator::dunderIter)
-          .def("__next__", &PyArrayAttributeIterator::dunderNext);
-    }
+template <typename EltTy, typename DerivedT>
+EltTy PyDenseArrayAttribute<EltTy,
+                            DerivedT>::PyDenseArrayIterator::dunderNext() {
+  // Raise StopIteration once every element has been returned.
+  if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
+    throw nb::stop_iteration();
+  return DerivedT::getElement(attr.get(), nextIndex++);
+}
-  private:
-    PyAttribute attr;
-    int nextIndex = 0;
-  };
+template <typename EltTy, typename DerivedT>
+void PyDenseArrayAttribute<EltTy, DerivedT>::PyDenseArrayIterator::bind(
+    nb::module_ &m) {
+  nb::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName)
+      .def("__iter__", &PyDenseArrayIterator::dunderIter)
+      .def("__next__", &PyDenseArrayIterator::dunderNext);
+}
-  MlirAttribute getItem(intptr_t i) {
-    return mlirArrayAttrGetElement(*this, i);
-  }
+template <typename EltTy, typename DerivedT>
+EltTy PyDenseArrayAttribute<EltTy, DerivedT>::getItem(intptr_t i) {
+  return DerivedT::getElement(*this, i); // unchecked: i must be in range
+}
- static void bindDerived(ClassTy &c) {
+template <typename EltTy, typename DerivedT>
+void PyDenseArrayAttribute<EltTy, DerivedT>::bindDerived(
+    typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
+  // Bind `get`; for bool arrays each element's truthiness is used.
+  if constexpr (std::is_same_v<EltTy, bool>) {
     c.def_static(
         "get",
-        [](const nb::list &attributes, DefaultingPyMlirContext context) {
-          SmallVector<MlirAttribute> mlirAttributes;
-          mlirAttributes.reserve(nb::len(attributes));
-          for (auto attribute : attributes) {
-            mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
+        [](const nb::sequence &py_values, DefaultingPyMlirContext ctx) {
+          std::vector<bool> values;
+          for (nb::handle py_value : py_values) {
+            int is_true = PyObject_IsTrue(py_value.ptr());
+            if (is_true < 0) {
+              throw nb::python_error();
+            }
+            values.push_back(is_true);
           }
-          MlirAttribute attr = mlirArrayAttrGet(
-              context->get(), mlirAttributes.size(), mlirAttributes.data());
-          return PyArrayAttribute(context->getRef(), attr);
-        },
-        nb::arg("attributes"), nb::arg("context").none() = nb::none(),
-        "Gets a uniqued Array attribute");
-    c.def("__getitem__",
-          [](PyArrayAttribute &arr, intptr_t i) {
-            if (i >= mlirArrayAttrGetNumElements(arr))
-              throw nb::index_error("ArrayAttribute index out of range");
-            return arr.getItem(i);
-          })
-        .def("__len__",
-             [](const PyArrayAttribute &arr) {
-               return mlirArrayAttrGetNumElements(arr);
-             })
-        .def("__iter__", [](const PyArrayAttribute &arr) {
-          return PyArrayAttributeIterator(arr);
-        });
-    c.def("__add__", [](PyArrayAttribute arr, const nb::list &extras) {
-      std::vector<MlirAttribute> attributes;
-      intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
-      attributes.reserve(numOldElements + nb::len(extras));
-      for (intptr_t i = 0; i < numOldElements; ++i)
-        attributes.push_back(arr.getItem(i));
-      for (nb::handle attr : extras)
-        attributes.push_back(pyTryCast<PyAttribute>(attr));
-      MlirAttribute arrayAttr = mlirArrayAttrGet(
-          arr.getContext()->get(), attributes.size(), attributes.data());
-      return PyArrayAttribute(arr.getContext(), arrayAttr);
-    });
-  }
-};
-
-/// Float Point Attribute subclass - FloatAttr.
-class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
-  static constexpr const char *pyClassName = "FloatAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloatAttrGetTypeID;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](PyType &type, double value, DefaultingPyLocation loc) {
-          PyMlirContext::ErrorCapture errors(loc->getContext());
-          MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
-          if (mlirAttributeIsNull(attr))
-            throw MLIRError("Invalid attribute", errors.take());
-          return PyFloatAttribute(type.getContext(), attr);
-        },
-        nb::arg("type"), nb::arg("value"), nb::arg("loc").none() = nb::none(),
-        "Gets an uniqued float point attribute associated to a type");
-    c.def_static(
-        "get_f32",
-        [](double value, DefaultingPyMlirContext context) {
-          MlirAttribute attr = mlirFloatAttrDoubleGet(
-              context->get(), mlirF32TypeGet(context->get()), value);
-          return PyFloatAttribute(context->getRef(), attr);
+          return getAttribute(values, ctx->getRef());
         },
-        nb::arg("value"), nb::arg("context").none() = nb::none(),
-        "Gets an uniqued float point attribute associated to a f32 type");
-    c.def_static(
-        "get_f64",
-        [](double value, DefaultingPyMlirContext context) {
-          MlirAttribute attr = mlirFloatAttrDoubleGet(
-              context->get(), mlirF64TypeGet(context->get()), value);
-          return PyFloatAttribute(context->getRef(), attr);
-        },
-        nb::arg("value"), nb::arg("context").none() = nb::none(),
-        "Gets an uniqued float point attribute associated to a f64 type");
-    c.def_prop_ro("value", mlirFloatAttrGetValueDouble,
-                  "Returns the value of the float attribute");
-    c.def("__float__", mlirFloatAttrGetValueDouble,
-          "Converts the value of the float attribute to a Python float");
-  }
-};
-
-/// Integer Attribute subclass - IntegerAttr.
-class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
-  static constexpr const char *pyClassName = "IntegerAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-
-  static void bindDerived(ClassTy &c) {
+        nb::arg("values"), nb::arg("context").none() = nb::none(),
+        "Gets a uniqued dense array attribute");
+  } else {
     c.def_static(
         "get",
-        [](PyType &type, int64_t value) {
-          MlirAttribute attr = mlirIntegerAttrGet(type, value);
-          return PyIntegerAttribute(type.getContext(), attr);
+        [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
+          return getAttribute(values, ctx->getRef());
         },
-        nb::arg("type"), nb::arg("value"),
-        "Gets an uniqued integer attribute associated to a type");
-    c.def_prop_ro("value", toPyInt,
-                  "Returns the value of the integer attribute");
-    c.def("__int__", toPyInt,
-          "Converts the value of the integer attribute to a Python int");
-    c.def_prop_ro_static("static_typeid",
-                         [](nb::object & /*class*/) -> MlirTypeID {
-                           return mlirIntegerAttrGetTypeID();
-                         });
+        nb::arg("values"), nb::arg("context").none() = nb::none(),
+        "Gets a uniqued dense array attribute");
   }
+  // Bind sequence methods. getItem is unchecked: reject i outside [0, len).
+  c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
+    if (i < 0 || i >= mlirDenseArrayGetNumElements(arr))
+      throw nb::index_error("DenseArray index out of range");
+    return arr.getItem(i);
+  });
+  c.def("__len__",
+        [](const DerivedT &arr) { return mlirDenseArrayGetNumElements(arr); });
+  c.def("__iter__",
+        [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
+  c.def("__add__", [](DerivedT &arr, const nb::list &extras) {
+    std::vector<EltTy> values;
+    intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
+    values.reserve(numOldElements + nb::len(extras));
+    for (intptr_t i = 0; i < numOldElements; ++i)
+      values.push_back(arr.getItem(i));
+    for (nb::handle attr : extras)
+      values.push_back(pyTryCast<EltTy>(attr));
+    return getAttribute(values, arr.getContext());
+  });
+}
-private:
- static int64_t toPyInt(PyIntegerAttribute &self) {
- MlirType type = mlirAttributeGetType(self);
- if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
- return mlirIntegerAttrGetValueInt(self);
- if (mlirIntegerTypeIsSigned(type))
- return mlirIntegerAttrGetValueSInt(self);
- return mlirIntegerAttrGetValueUInt(self);
+template <typename EltTy, typename DerivedT>
+DerivedT PyDenseArrayAttribute<EltTy, DerivedT>::getAttribute(
+    const std::vector<EltTy> &values, PyMlirContextRef ctx) {
+  if constexpr (std::is_same_v<EltTy, bool>) { // vector<bool> has no data()
+    std::vector<int> intValues(values.begin(), values.end());
+    MlirAttribute attr =
+        DerivedT::getAttribute(ctx->get(), intValues.size(), intValues.data());
+    return DerivedT(ctx, attr);
+  } else {
+    MlirAttribute attr =
+        DerivedT::getAttribute(ctx->get(), values.size(), values.data());
+    return DerivedT(ctx, attr);
     }
-};
+}
-/// Bool Attribute subclass - BoolAttr.
-class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
- static constexpr const char *pyClassName = "BoolAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
+PyArrayAttribute::PyArrayAttributeIterator &
+PyArrayAttribute::PyArrayAttributeIterator::dunderIter() {
+  return *this; // by reference: this same iterator object
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](bool value, DefaultingPyMlirContext context) {
- MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
- return PyBoolAttribute(context->getRef(), attr);
- },
- nb::arg("value"), nb::arg("context").none() = nb::none(),
- "Gets an uniqued bool attribute");
- c.def_prop_ro("value", mlirBoolAttrGetValue,
- "Returns the value of the bool attribute");
- c.def("__bool__", mlirBoolAttrGetValue,
- "Converts the value of the bool attribute to a Python bool");
- }
-};
+MlirAttribute PyArrayAttribute::PyArrayAttributeIterator::dunderNext() {
+ // TODO: Throw is an inefficient way to stop iteration.
+ if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
+ throw nb::stop_iteration();
+ return mlirArrayAttrGetElement(attr.get(), nextIndex++);
+}
-class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
- static constexpr const char *pyClassName = "SymbolRefAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- static MlirAttribute fromList(const std::vector<std::string> &symbols,
- PyMlirContext &context) {
- if (symbols.empty())
- throw std::runtime_error("SymbolRefAttr must be composed of at least "
- "one symbol.");
- MlirStringRef rootSymbol = toMlirStringRef(symbols[0]);
- SmallVector<MlirAttribute, 3> referenceAttrs;
- for (size_t i = 1; i < symbols.size(); ++i) {
- referenceAttrs.push_back(
- mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
- }
- return mlirSymbolRefAttrGet(context.get(), rootSymbol,
- referenceAttrs.size(), referenceAttrs.data());
- }
+/// Registers the `ArrayAttributeIterator` Python class on module `m`, with
+/// `__iter__` and `__next__` forwarding to the C++ iterator methods.
+void PyArrayAttribute::PyArrayAttributeIterator::bind(nb::module_ &m) {
+  nb::class_<PyArrayAttributeIterator> iteratorClass(m,
+                                                     "ArrayAttributeIterator");
+  iteratorClass.def("__iter__", &PyArrayAttributeIterator::dunderIter);
+  iteratorClass.def("__next__", &PyArrayAttributeIterator::dunderNext);
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](const std::vector<std::string> &symbols,
- DefaultingPyMlirContext context) {
- return PySymbolRefAttribute::fromList(symbols, context.resolve());
- },
- nb::arg("symbols"), nb::arg("context").none() = nb::none(),
- "Gets a uniqued SymbolRef attribute from a list of symbol names");
- c.def_prop_ro(
- "value",
- [](PySymbolRefAttribute &self) {
- std::vector<std::string> symbols = {
- unwrap(mlirSymbolRefAttrGetRootReference(self)).str()};
- for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self);
- ++i)
- symbols.push_back(
- unwrap(mlirSymbolRefAttrGetRootReference(
- mlirSymbolRefAttrGetNestedReference(self, i)))
- .str());
- return symbols;
- },
- "Returns the value of the SymbolRef attribute as a list[str]");
- }
-};
+/// Returns element `i` of the underlying ArrayAttr. This helper performs no
+/// bounds checking of its own.
+MlirAttribute PyArrayAttribute::getItem(intptr_t i) {
+  MlirAttribute element = mlirArrayAttrGetElement(*this, i);
+  return element;
+}
-class PyFlatSymbolRefAttribute
- : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
- static constexpr const char *pyClassName = "FlatSymbolRefAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
+/// Binds the Python API of ArrayAttr:
+///  - `get`: builds a uniqued ArrayAttr from a list of attributes;
+///  - `__getitem__` / `__len__` / `__iter__`: read-only sequence access;
+///  - `__add__`: returns a new ArrayAttr with the elements of a Python list
+///    appended.
+void PyArrayAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](const nb::list &attributes, DefaultingPyMlirContext context) {
+        SmallVector<MlirAttribute> mlirAttributes;
+        mlirAttributes.reserve(nb::len(attributes));
+        for (auto attribute : attributes) {
+          mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
+        }
+        MlirAttribute attr = mlirArrayAttrGet(
+            context->get(), mlirAttributes.size(), mlirAttributes.data());
+        return PyArrayAttribute(context->getRef(), attr);
+      },
+      nb::arg("attributes"), nb::arg("context").none() = nb::none(),
+      "Gets a uniqued Array attribute");
+  c.def("__getitem__",
+        [](PyArrayAttribute &arr, intptr_t i) {
+          intptr_t numElements = mlirArrayAttrGetNumElements(arr);
+          // Follow Python sequence semantics: negative indices count from
+          // the end. Both bounds are checked because getItem forwards the
+          // index to mlirArrayAttrGetElement without validation.
+          if (i < 0)
+            i += numElements;
+          if (i < 0 || i >= numElements)
+            throw nb::index_error("ArrayAttribute index out of range");
+          return arr.getItem(i);
+        })
+      .def("__len__",
+           [](const PyArrayAttribute &arr) {
+             return mlirArrayAttrGetNumElements(arr);
+           })
+      .def("__iter__", [](const PyArrayAttribute &arr) {
+        return PyArrayAttributeIterator(arr);
+      });
+  c.def("__add__", [](PyArrayAttribute arr, const nb::list &extras) {
+    std::vector<MlirAttribute> attributes;
+    intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
+    attributes.reserve(numOldElements + nb::len(extras));
+    for (intptr_t i = 0; i < numOldElements; ++i)
+      attributes.push_back(arr.getItem(i));
+    for (nb::handle attr : extras)
+      attributes.push_back(pyTryCast<PyAttribute>(attr));
+    MlirAttribute arrayAttr = mlirArrayAttrGet(
+        arr.getContext()->get(), attributes.size(), attributes.data());
+    return PyArrayAttribute(arr.getContext(), arrayAttr);
+  });
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](const std::string &value, DefaultingPyMlirContext context) {
- MlirAttribute attr =
- mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
- return PyFlatSymbolRefAttribute(context->getRef(), attr);
- },
- nb::arg("value"), nb::arg("context").none() = nb::none(),
- "Gets a uniqued FlatSymbolRef attribute");
- c.def_prop_ro(
- "value",
- [](PyFlatSymbolRefAttribute &self) {
- MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
- return nb::str(stringRef.data, stringRef.length);
- },
- "Returns the value of the FlatSymbolRef attribute as a string");
- }
-};
+/// Binds FloatAttr constructors (`get`, `get_f32`, `get_f64`) and the
+/// `value` / `__float__` accessors, which read the value as a double.
+void PyFloatAttribute::bindDerived(ClassTy &c) {
+  // Checked construction: errors collected by ErrorCapture are attached to
+  // the MLIRError raised when the attribute cannot be built.
+  auto getTyped = [](PyType &type, double value, DefaultingPyLocation loc) {
+    PyMlirContext::ErrorCapture errors(loc->getContext());
+    MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
+    if (mlirAttributeIsNull(attr))
+      throw MLIRError("Invalid attribute", errors.take());
+    return PyFloatAttribute(type.getContext(), attr);
+  };
+  auto getF32 = [](double value, DefaultingPyMlirContext context) {
+    MlirContext ctx = context->get();
+    MlirAttribute attr = mlirFloatAttrDoubleGet(ctx, mlirF32TypeGet(ctx), value);
+    return PyFloatAttribute(context->getRef(), attr);
+  };
+  auto getF64 = [](double value, DefaultingPyMlirContext context) {
+    MlirContext ctx = context->get();
+    MlirAttribute attr = mlirFloatAttrDoubleGet(ctx, mlirF64TypeGet(ctx), value);
+    return PyFloatAttribute(context->getRef(), attr);
+  };
+
+  c.def_static("get", getTyped, nb::arg("type"), nb::arg("value"),
+               nb::arg("loc").none() = nb::none(),
+               "Gets an uniqued float point attribute associated to a type");
+  c.def_static("get_f32", getF32, nb::arg("value"),
+               nb::arg("context").none() = nb::none(),
+               "Gets an uniqued float point attribute associated to a f32 type");
+  c.def_static("get_f64", getF64, nb::arg("value"),
+               nb::arg("context").none() = nb::none(),
+               "Gets an uniqued float point attribute associated to a f64 type");
+  c.def_prop_ro("value", mlirFloatAttrGetValueDouble,
+                "Returns the value of the float attribute");
+  c.def("__float__", mlirFloatAttrGetValueDouble,
+        "Converts the value of the float attribute to a Python float");
+}
-class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
- static constexpr const char *pyClassName = "OpaqueAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirOpaqueAttrGetTypeID;
+/// Binds IntegerAttr: construction from a type and an int64 value, the
+/// `value` / `__int__` accessors (both via toPyInt), and `static_typeid`.
+void PyIntegerAttribute::bindDerived(ClassTy &c) {
+  auto getFromType = [](PyType &type, int64_t value) {
+    MlirAttribute attr = mlirIntegerAttrGet(type, value);
+    return PyIntegerAttribute(type.getContext(), attr);
+  };
+  auto staticTypeId = [](nb::object & /*class*/) -> MlirTypeID {
+    return mlirIntegerAttrGetTypeID();
+  };
+
+  c.def_static("get", getFromType, nb::arg("type"), nb::arg("value"),
+               "Gets an uniqued integer attribute associated to a type");
+  c.def_prop_ro("value", toPyInt, "Returns the value of the integer attribute");
+  c.def("__int__", toPyInt,
+        "Converts the value of the integer attribute to a Python int");
+  c.def_prop_ro_static("static_typeid", staticTypeId);
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](const std::string &dialectNamespace, const nb_buffer &buffer,
- PyType &type, DefaultingPyMlirContext context) {
- const nb_buffer_info bufferInfo = buffer.request();
- intptr_t bufferSize = bufferInfo.size;
- MlirAttribute attr = mlirOpaqueAttrGet(
- context->get(), toMlirStringRef(dialectNamespace), bufferSize,
- static_cast<char *>(bufferInfo.ptr), type);
- return PyOpaqueAttribute(context->getRef(), attr);
- },
- nb::arg("dialect_namespace"), nb::arg("buffer"), nb::arg("type"),
- nb::arg("context").none() = nb::none(), "Gets an Opaque attribute.");
- c.def_prop_ro(
- "dialect_namespace",
- [](PyOpaqueAttribute &self) {
- MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self);
- return nb::str(stringRef.data, stringRef.length);
- },
- "Returns the dialect namespace for the Opaque attribute as a string");
- c.def_prop_ro(
- "data",
- [](PyOpaqueAttribute &self) {
- MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
- return nb::bytes(stringRef.data, stringRef.length);
- },
- "Returns the data for the Opaqued attributes as `bytes`");
- }
-};
+/// Reads the attribute's integer value, choosing the C API accessor from the
+/// attribute type:
+///  - index and signless integers use the plain accessor;
+///  - signed integers use the signed accessor;
+///  - anything else uses the unsigned accessor.
+/// NOTE(review): the unsigned result is returned through int64_t.
+int64_t PyIntegerAttribute::toPyInt(PyIntegerAttribute &self) {
+  MlirType type = mlirAttributeGetType(self);
+  bool usePlainAccessor =
+      mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type);
+  if (!usePlainAccessor) {
+    if (mlirIntegerTypeIsSigned(type))
+      return mlirIntegerAttrGetValueSInt(self);
+    return mlirIntegerAttrGetValueUInt(self);
+  }
+  return mlirIntegerAttrGetValueInt(self);
+}
-class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
- static constexpr const char *pyClassName = "StringAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirStringAttrGetTypeID;
+/// Binds BoolAttr: a `get` constructor plus the `value` property and
+/// `__bool__`, both of which read the flag via mlirBoolAttrGetValue.
+void PyBoolAttribute::bindDerived(ClassTy &c) {
+  auto getBool = [](bool value, DefaultingPyMlirContext context) {
+    MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
+    return PyBoolAttribute(context->getRef(), attr);
+  };
+
+  c.def_static("get", getBool, nb::arg("value"),
+               nb::arg("context").none() = nb::none(),
+               "Gets an uniqued bool attribute");
+  c.def_prop_ro("value", mlirBoolAttrGetValue,
+                "Returns the value of the bool attribute");
+  c.def("__bool__", mlirBoolAttrGetValue,
+        "Converts the value of the bool attribute to a Python bool");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](const std::string &value, DefaultingPyMlirContext context) {
- MlirAttribute attr =
- mlirStringAttrGet(context->get(), toMlirStringRef(value));
- return PyStringAttribute(context->getRef(), attr);
- },
- nb::arg("value"), nb::arg("context").none() = nb::none(),
- "Gets a uniqued string attribute");
- c.def_static(
- "get",
- [](const nb::bytes &value, DefaultingPyMlirContext context) {
- MlirAttribute attr =
- mlirStringAttrGet(context->get(), toMlirStringRef(value));
- return PyStringAttribute(context->getRef(), attr);
- },
- nb::arg("value"), nb::arg("context").none() = nb::none(),
- "Gets a uniqued string attribute");
- c.def_static(
- "get_typed",
- [](PyType &type, const std::string &value) {
- MlirAttribute attr =
- mlirStringAttrTypedGet(type, toMlirStringRef(value));
- return PyStringAttribute(type.getContext(), attr);
- },
- nb::arg("type"), nb::arg("value"),
- "Gets a uniqued string attribute associated to a type");
- c.def_prop_ro(
- "value",
- [](PyStringAttribute &self) {
- MlirStringRef stringRef = mlirStringAttrGetValue(self);
- return nb::str(stringRef.data, stringRef.length);
- },
- "Returns the value of the string attribute");
- c.def_prop_ro(
- "value_bytes",
- [](PyStringAttribute &self) {
- MlirStringRef stringRef = mlirStringAttrGetValue(self);
- return nb::bytes(stringRef.data, stringRef.length);
- },
- "Returns the value of the string attribute as `bytes`");
+MlirAttribute
+PySymbolRefAttribute::fromList(const std::vector<std::string> &symbols,
+ PyMlirContext &context) {
+ if (symbols.empty())
+ throw std::runtime_error("SymbolRefAttr must be composed of at least "
+ "one symbol.");
+ MlirStringRef rootSymbol = toMlirStringRef(symbols[0]);
+ SmallVector<MlirAttribute, 3> referenceAttrs;
+ for (size_t i = 1; i < symbols.size(); ++i) {
+ referenceAttrs.push_back(
+ mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
}
-};
-
-// TODO: Support construction of string elements.
-class PyDenseElementsAttribute
- : public PyConcreteAttribute<PyDenseElementsAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
- static constexpr const char *pyClassName = "DenseElementsAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- static PyDenseElementsAttribute
- getFromList(const nb::list &attributes, std::optional<PyType> explicitType,
- DefaultingPyMlirContext contextWrapper) {
- const size_t numAttributes = nb::len(attributes);
- if (numAttributes == 0)
- throw nb::value_error("Attributes list must be non-empty.");
-
- MlirType shapedType;
- if (explicitType) {
- if ((!mlirTypeIsAShaped(*explicitType) ||
- !mlirShapedTypeHasStaticShape(*explicitType))) {
-
- std::string message;
- llvm::raw_string_ostream os(message);
- os << "Expected a static ShapedType for the shaped_type parameter: "
- << nb::cast<std::string>(nb::repr(nb::cast(*explicitType)));
- throw nb::value_error(message.c_str());
- }
- shapedType = *explicitType;
- } else {
- SmallVector<int64_t> shape = {static_cast<int64_t>(numAttributes)};
- shapedType = mlirRankedTensorTypeGet(
- shape.size(), shape.data(),
- mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
- mlirAttributeGetNull());
- }
+ return mlirSymbolRefAttrGet(context.get(), rootSymbol, referenceAttrs.size(),
+ referenceAttrs.data());
+}
- SmallVector<MlirAttribute> mlirAttributes;
- mlirAttributes.reserve(numAttributes);
- for (const nb::handle &attribute : attributes) {
- MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
- MlirType attrType = mlirAttributeGetType(mlirAttribute);
- mlirAttributes.push_back(mlirAttribute);
-
- if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
- std::string message;
- llvm::raw_string_ostream os(message);
- os << "All attributes must be of the same type and match "
- << "the type parameter: expected="
- << nb::cast<std::string>(nb::repr(nb::cast(shapedType)))
- << ", but got="
- << nb::cast<std::string>(nb::repr(nb::cast(attrType)));
- throw nb::value_error(message.c_str());
- }
- }
+/// Binds SymbolRefAttr:
+///  - `get` builds the attribute from a list of symbol names via fromList;
+///  - `value` returns the root reference followed by each nested
+///    reference's root name, as a list[str].
+void PySymbolRefAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](const std::vector<std::string> &symbols,
+         DefaultingPyMlirContext context) {
+        return PySymbolRefAttribute::fromList(symbols, context.resolve());
+      },
+      nb::arg("symbols"), nb::arg("context").none() = nb::none(),
+      "Gets a uniqued SymbolRef attribute from a list of symbol names");
+  c.def_prop_ro(
+      "value",
+      [](PySymbolRefAttribute &self) {
+        // The C API reports the nested-reference count as intptr_t; use the
+        // same type for the index instead of comparing an int against it.
+        intptr_t numNested = mlirSymbolRefAttrGetNumNestedReferences(self);
+        std::vector<std::string> symbols;
+        symbols.reserve(static_cast<size_t>(numNested) + 1);
+        symbols.push_back(
+            unwrap(mlirSymbolRefAttrGetRootReference(self)).str());
+        for (intptr_t i = 0; i < numNested; ++i)
+          symbols.push_back(
+              unwrap(mlirSymbolRefAttrGetRootReference(
+                         mlirSymbolRefAttrGetNestedReference(self, i)))
+                  .str());
+        return symbols;
+      },
+      "Returns the value of the SymbolRef attribute as a list[str]");
+}
- MlirAttribute elements = mlirDenseElementsAttrGet(
- shapedType, mlirAttributes.size(), mlirAttributes.data());
+/// Binds FlatSymbolRefAttr: construction from a single symbol name and a
+/// `value` property returning that name as a Python str.
+void PyFlatSymbolRefAttribute::bindDerived(ClassTy &c) {
+  auto getFromName = [](const std::string &value,
+                        DefaultingPyMlirContext context) {
+    MlirAttribute attr =
+        mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
+    return PyFlatSymbolRefAttribute(context->getRef(), attr);
+  };
+  auto symbolName = [](PyFlatSymbolRefAttribute &self) {
+    MlirStringRef ref = mlirFlatSymbolRefAttrGetValue(self);
+    return nb::str(ref.data, ref.length);
+  };
+
+  c.def_static("get", getFromName, nb::arg("value"),
+               nb::arg("context").none() = nb::none(),
+               "Gets a uniqued FlatSymbolRef attribute");
+  c.def_prop_ro("value", symbolName,
+                "Returns the value of the FlatSymbolRef attribute as a string");
+}
- return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
+/// Requests a buffer view of the wrapped Python object with stride and
+/// format information (PyBUF_STRIDES | PyBUF_FORMAT).
+/// NOTE(review): the heap-allocated Py_buffer is handed to nb_buffer_info,
+/// which presumably takes ownership and releases it -- confirm.
+nb_buffer_info nb_buffer::request() const {
+ int flags = PyBUF_STRIDES | PyBUF_FORMAT;
+ auto *view = new Py_buffer();
+ if (PyObject_GetBuffer(ptr(), view, flags) != 0) {
+ // On failure, free the allocation and propagate the Python exception
+ // raised by PyObject_GetBuffer.
+ delete view;
+ throw nb::python_error();
 }
+ return nb_buffer_info(view);
+}
- static PyDenseElementsAttribute
- getFromBuffer(const nb_buffer &array, bool signless,
- const std::optional<PyType> &explicitType,
- std::optional<std::vector<int64_t>> explicitShape,
- DefaultingPyMlirContext contextWrapper) {
- // Request a contiguous view. In exotic cases, this will cause a copy.
- int flags = PyBUF_ND;
- if (!explicitType) {
- flags |= PyBUF_FORMAT;
- }
- Py_buffer view;
- if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
- throw nb::python_error();
- }
- auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
+/// Binds OpaqueAttr:
+///  - `get` builds the attribute from a dialect namespace, a Python buffer
+///    holding the payload, and a type;
+///  - `dialect_namespace` returns the namespace as a str;
+///  - `data` returns the payload as bytes.
+void PyOpaqueAttribute::bindDerived(ClassTy &c) {
+  auto getOpaque = [](const std::string &dialectNamespace,
+                      const nb_buffer &buffer, PyType &type,
+                      DefaultingPyMlirContext context) {
+    const nb_buffer_info info = buffer.request();
+    intptr_t size = info.size;
+    MlirAttribute attr = mlirOpaqueAttrGet(
+        context->get(), toMlirStringRef(dialectNamespace), size,
+        static_cast<char *>(info.ptr), type);
+    return PyOpaqueAttribute(context->getRef(), attr);
+  };
+  auto namespaceAsStr = [](PyOpaqueAttribute &self) {
+    MlirStringRef ref = mlirOpaqueAttrGetDialectNamespace(self);
+    return nb::str(ref.data, ref.length);
+  };
+  auto dataAsBytes = [](PyOpaqueAttribute &self) {
+    MlirStringRef ref = mlirOpaqueAttrGetData(self);
+    return nb::bytes(ref.data, ref.length);
+  };
+
+  c.def_static("get", getOpaque, nb::arg("dialect_namespace"),
+               nb::arg("buffer"), nb::arg("type"),
+               nb::arg("context").none() = nb::none(),
+               "Gets an Opaque attribute.");
+  c.def_prop_ro(
+      "dialect_namespace", namespaceAsStr,
+      "Returns the dialect namespace for the Opaque attribute as a string");
+  c.def_prop_ro("data", dataAsBytes,
+                "Returns the data for the Opaqued attributes as `bytes`");
+}
- MlirContext context = contextWrapper->get();
- MlirAttribute attr = getAttributeFromBuffer(
- view, signless, explicitType, std::move(explicitShape), context);
- if (mlirAttributeIsNull(attr)) {
- throw std::invalid_argument(
- "DenseElementsAttr could not be constructed from the given buffer. "
- "This may mean that the Python buffer layout does not match that "
- "MLIR expected layout and is a bug.");
- }
- return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
- }
+/// Binds StringAttr:
+///  - `get` builds the attribute from a str or from bytes;
+///  - `get_typed` builds it with an explicit type;
+///  - `value` returns the payload as str and `value_bytes` returns it as
+///    bytes.
+void PyStringAttribute::bindDerived(ClassTy &c) {
+  auto getFromStr = [](const std::string &value,
+                       DefaultingPyMlirContext context) {
+    MlirAttribute attr =
+        mlirStringAttrGet(context->get(), toMlirStringRef(value));
+    return PyStringAttribute(context->getRef(), attr);
+  };
+  auto getFromBytes = [](const nb::bytes &value,
+                         DefaultingPyMlirContext context) {
+    MlirAttribute attr =
+        mlirStringAttrGet(context->get(), toMlirStringRef(value));
+    return PyStringAttribute(context->getRef(), attr);
+  };
+  auto getTyped = [](PyType &type, const std::string &value) {
+    MlirAttribute attr = mlirStringAttrTypedGet(type, toMlirStringRef(value));
+    return PyStringAttribute(type.getContext(), attr);
+  };
+  auto valueAsStr = [](PyStringAttribute &self) {
+    MlirStringRef ref = mlirStringAttrGetValue(self);
+    return nb::str(ref.data, ref.length);
+  };
+  auto valueAsBytes = [](PyStringAttribute &self) {
+    MlirStringRef ref = mlirStringAttrGetValue(self);
+    return nb::bytes(ref.data, ref.length);
+  };
+
+  // The two `get` overloads share one Python name. nanobind tries overloads
+  // in registration order, so the str overload is registered first.
+  c.def_static("get", getFromStr, nb::arg("value"),
+               nb::arg("context").none() = nb::none(),
+               "Gets a uniqued string attribute");
+  c.def_static("get", getFromBytes, nb::arg("value"),
+               nb::arg("context").none() = nb::none(),
+               "Gets a uniqued string attribute");
+  c.def_static("get_typed", getTyped, nb::arg("type"), nb::arg("value"),
+               "Gets a uniqued string attribute associated to a type");
+  c.def_prop_ro("value", valueAsStr,
+                "Returns the value of the string attribute");
+  c.def_prop_ro("value_bytes", valueAsBytes,
+                "Returns the value of the string attribute as `bytes`");
+}
- static PyDenseElementsAttribute getSplat(const PyType &shapedType,
- PyAttribute &elementAttr) {
- auto contextWrapper =
- PyMlirContext::forContext(mlirTypeGetContext(shapedType));
- if (!mlirAttributeIsAInteger(elementAttr) &&
- !mlirAttributeIsAFloat(elementAttr)) {
- std::string message = "Illegal element type for DenseElementsAttr: ";
- message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
- throw nb::value_error(message.c_str());
- }
- if (!mlirTypeIsAShaped(shapedType) ||
- !mlirShapedTypeHasStaticShape(shapedType)) {
- std::string message =
- "Expected a static ShapedType for the shaped_type parameter: ";
- message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
+PyDenseElementsAttribute
+PyDenseElementsAttribute::getFromList(const nb::list &attributes,
+ std::optional<PyType> explicitType,
+ DefaultingPyMlirContext contextWrapper) {
+ const size_t numAttributes = nb::len(attributes);
+ if (numAttributes == 0)
+ throw nb::value_error("Attributes list must be non-empty.");
+
+ MlirType shapedType;
+ if (explicitType) {
+ if ((!mlirTypeIsAShaped(*explicitType) ||
+ !mlirShapedTypeHasStaticShape(*explicitType))) {
+
+ std::string message;
+ llvm::raw_string_ostream os(message);
+ os << "Expected a static ShapedType for the shaped_type parameter: "
+ << nb::cast<std::string>(nb::repr(nb::cast(*explicitType)));
throw nb::value_error(message.c_str());
}
- MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
- MlirType attrType = mlirAttributeGetType(elementAttr);
- if (!mlirTypeEqual(shapedElementType, attrType)) {
- std::string message =
- "Shaped element type and attribute type must be equal: shaped=";
- message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
- message.append(", element=");
- message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
+ shapedType = *explicitType;
+ } else {
+ SmallVector<int64_t> shape = {static_cast<int64_t>(numAttributes)};
+ shapedType = mlirRankedTensorTypeGet(
+ shape.size(), shape.data(),
+ mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
+ mlirAttributeGetNull());
+ }
+
+ SmallVector<MlirAttribute> mlirAttributes;
+ mlirAttributes.reserve(numAttributes);
+ for (const nb::handle &attribute : attributes) {
+ MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
+ MlirType attrType = mlirAttributeGetType(mlirAttribute);
+ mlirAttributes.push_back(mlirAttribute);
+
+ if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
+ std::string message;
+ llvm::raw_string_ostream os(message);
+ os << "All attributes must be of the same type and match "
+ << "the type parameter: expected="
+ << nb::cast<std::string>(nb::repr(nb::cast(shapedType)))
+ << ", but got=" << nb::cast<std::string>(nb::repr(nb::cast(attrType)));
throw nb::value_error(message.c_str());
}
+ }
+
+ MlirAttribute elements = mlirDenseElementsAttrGet(
+ shapedType, mlirAttributes.size(), mlirAttributes.data());
+
+ return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
+}
- MlirAttribute elements =
- mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
- return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
+PyDenseElementsAttribute PyDenseElementsAttribute::getFromBuffer(
+ const nb_buffer &array, bool signless,
+ const std::optional<PyType> &explicitType,
+ std::optional<std::vector<int64_t>> explicitShape,
+ DefaultingPyMlirContext contextWrapper) {
+ // Request a contiguous view. In exotic cases, this will cause a copy.
+ int flags = PyBUF_ND;
+ if (!explicitType) {
+ flags |= PyBUF_FORMAT;
+ }
+ Py_buffer view;
+ if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
+ throw nb::python_error();
}
+ auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
- intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
+ MlirContext context = contextWrapper->get();
+ MlirAttribute attr = getAttributeFromBuffer(
+ view, signless, explicitType, std::move(explicitShape), context);
+ if (mlirAttributeIsNull(attr)) {
+ throw std::invalid_argument(
+ "DenseElementsAttr could not be constructed from the given buffer. "
+ "This may mean that the Python buffer layout does not match that "
+ "MLIR expected layout and is a bug.");
+ }
+ return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
+}
- std::unique_ptr<nb_buffer_info> accessBuffer() {
- MlirType shapedType = mlirAttributeGetType(*this);
- MlirType elementType = mlirShapedTypeGetElementType(shapedType);
- std::string format;
+PyDenseElementsAttribute
+PyDenseElementsAttribute::getSplat(const PyType &shapedType,
+ PyAttribute &elementAttr) {
+ auto contextWrapper =
+ PyMlirContext::forContext(mlirTypeGetContext(shapedType));
+ if (!mlirAttributeIsAInteger(elementAttr) &&
+ !mlirAttributeIsAFloat(elementAttr)) {
+ std::string message = "Illegal element type for DenseElementsAttr: ";
+ message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
+ throw nb::value_error(message.c_str());
+ }
+ if (!mlirTypeIsAShaped(shapedType) ||
+ !mlirShapedTypeHasStaticShape(shapedType)) {
+ std::string message =
+ "Expected a static ShapedType for the shaped_type parameter: ";
+ message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
+ throw nb::value_error(message.c_str());
+ }
+ MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
+ MlirType attrType = mlirAttributeGetType(elementAttr);
+ if (!mlirTypeEqual(shapedElementType, attrType)) {
+ std::string message =
+ "Shaped element type and attribute type must be equal: shaped=";
+ message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
+ message.append(", element=");
+ message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
+ throw nb::value_error(message.c_str());
+ }
- if (mlirTypeIsAF32(elementType)) {
- // f32
- return bufferInfo<float>(shapedType);
- }
- if (mlirTypeIsAF64(elementType)) {
- // f64
- return bufferInfo<double>(shapedType);
+ MlirAttribute elements =
+ mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
+ return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
+}
+
+/// Returns the total number of elements in the dense elements attribute
+/// (bound to Python as `__len__`).
+intptr_t PyDenseElementsAttribute::dunderLen() {
+  intptr_t count = mlirElementsAttrGetNumElements(*this);
+  return count;
+}
+
+std::unique_ptr<nb_buffer_info> PyDenseElementsAttribute::accessBuffer() {
+ MlirType shapedType = mlirAttributeGetType(*this);
+ MlirType elementType = mlirShapedTypeGetElementType(shapedType);
+ std::string format;
+
+ if (mlirTypeIsAF32(elementType)) {
+ // f32
+ return bufferInfo<float>(shapedType);
+ }
+ if (mlirTypeIsAF64(elementType)) {
+ // f64
+ return bufferInfo<double>(shapedType);
+ }
+ if (mlirTypeIsAF16(elementType)) {
+ // f16
+ return bufferInfo<uint16_t>(shapedType, "e");
+ }
+ if (mlirTypeIsAIndex(elementType)) {
+ // Same as IndexType::kInternalStorageBitWidth
+ return bufferInfo<int64_t>(shapedType);
+ }
+ if (mlirTypeIsAInteger(elementType) &&
+ mlirIntegerTypeGetWidth(elementType) == 32) {
+ if (mlirIntegerTypeIsSignless(elementType) ||
+ mlirIntegerTypeIsSigned(elementType)) {
+ // i32
+ return bufferInfo<int32_t>(shapedType);
}
- if (mlirTypeIsAF16(elementType)) {
- // f16
- return bufferInfo<uint16_t>(shapedType, "e");
+ if (mlirIntegerTypeIsUnsigned(elementType)) {
+ // unsigned i32
+ return bufferInfo<uint32_t>(shapedType);
}
- if (mlirTypeIsAIndex(elementType)) {
- // Same as IndexType::kInternalStorageBitWidth
+ } else if (mlirTypeIsAInteger(elementType) &&
+ mlirIntegerTypeGetWidth(elementType) == 64) {
+ if (mlirIntegerTypeIsSignless(elementType) ||
+ mlirIntegerTypeIsSigned(elementType)) {
+ // i64
return bufferInfo<int64_t>(shapedType);
}
- if (mlirTypeIsAInteger(elementType) &&
- mlirIntegerTypeGetWidth(elementType) == 32) {
- if (mlirIntegerTypeIsSignless(elementType) ||
- mlirIntegerTypeIsSigned(elementType)) {
- // i32
- return bufferInfo<int32_t>(shapedType);
- }
- if (mlirIntegerTypeIsUnsigned(elementType)) {
- // unsigned i32
- return bufferInfo<uint32_t>(shapedType);
- }
- } else if (mlirTypeIsAInteger(elementType) &&
- mlirIntegerTypeGetWidth(elementType) == 64) {
- if (mlirIntegerTypeIsSignless(elementType) ||
- mlirIntegerTypeIsSigned(elementType)) {
- // i64
- return bufferInfo<int64_t>(shapedType);
- }
- if (mlirIntegerTypeIsUnsigned(elementType)) {
- // unsigned i64
- return bufferInfo<uint64_t>(shapedType);
- }
- } else if (mlirTypeIsAInteger(elementType) &&
- mlirIntegerTypeGetWidth(elementType) == 8) {
- if (mlirIntegerTypeIsSignless(elementType) ||
- mlirIntegerTypeIsSigned(elementType)) {
- // i8
- return bufferInfo<int8_t>(shapedType);
- }
- if (mlirIntegerTypeIsUnsigned(elementType)) {
- // unsigned i8
- return bufferInfo<uint8_t>(shapedType);
- }
- } else if (mlirTypeIsAInteger(elementType) &&
- mlirIntegerTypeGetWidth(elementType) == 16) {
- if (mlirIntegerTypeIsSignless(elementType) ||
- mlirIntegerTypeIsSigned(elementType)) {
- // i16
- return bufferInfo<int16_t>(shapedType);
- }
- if (mlirIntegerTypeIsUnsigned(elementType)) {
- // unsigned i16
- return bufferInfo<uint16_t>(shapedType);
- }
- } else if (mlirTypeIsAInteger(elementType) &&
- mlirIntegerTypeGetWidth(elementType) == 1) {
- // i1 / bool
- // We can not send the buffer directly back to Python, because the i1
- // values are bitpacked within MLIR. We call numpy's unpackbits function
- // to convert the bytes.
- return getBooleanBufferFromBitpackedAttribute();
+ if (mlirIntegerTypeIsUnsigned(elementType)) {
+ // unsigned i64
+ return bufferInfo<uint64_t>(shapedType);
}
-
- // TODO: Currently crashes the program.
- // Reported as https://github.com/pybind/pybind11/issues/3336
- throw std::invalid_argument(
- "unsupported data type for conversion to Python buffer");
+ } else if (mlirTypeIsAInteger(elementType) &&
+ mlirIntegerTypeGetWidth(elementType) == 8) {
+ if (mlirIntegerTypeIsSignless(elementType) ||
+ mlirIntegerTypeIsSigned(elementType)) {
+ // i8
+ return bufferInfo<int8_t>(shapedType);
+ }
+ if (mlirIntegerTypeIsUnsigned(elementType)) {
+ // unsigned i8
+ return bufferInfo<uint8_t>(shapedType);
+ }
+ } else if (mlirTypeIsAInteger(elementType) &&
+ mlirIntegerTypeGetWidth(elementType) == 16) {
+ if (mlirIntegerTypeIsSignless(elementType) ||
+ mlirIntegerTypeIsSigned(elementType)) {
+ // i16
+ return bufferInfo<int16_t>(shapedType);
+ }
+ if (mlirIntegerTypeIsUnsigned(elementType)) {
+ // unsigned i16
+ return bufferInfo<uint16_t>(shapedType);
+ }
+ } else if (mlirTypeIsAInteger(elementType) &&
+ mlirIntegerTypeGetWidth(elementType) == 1) {
+ // i1 / bool
+ // We can not send the buffer directly back to Python, because the i1
+ // values are bitpacked within MLIR. We call numpy's unpackbits function
+ // to convert the bytes.
+ return getBooleanBufferFromBitpackedAttribute();
}
- static void bindDerived(ClassTy &c) {
+ // TODO: Currently crashes the program.
+ // Reported as https://github.com/pybind/pybind11/issues/3336
+ throw std::invalid_argument(
+ "unsupported data type for conversion to Python buffer");
+}
+
+// Registers the Python API of DenseElementsAttr: `__len__`, two `get`
+// overloads (from a buffer: array/signless/type/shape/context; from a list of
+// attributes: attrs/type/context), `get_splat`, the `is_splat` property and
+// `get_splat_value` (which raises ValueError on a non-splat attribute).
+void PyDenseElementsAttribute::bindDerived(ClassTy &c) {
#if PY_VERSION_HEX < 0x03090000
- PyTypeObject *tp = reinterpret_cast<PyTypeObject *>(c.ptr());
- tp->tp_as_buffer->bf_getbuffer = PyDenseElementsAttribute::bf_getbuffer;
- tp->tp_as_buffer->bf_releasebuffer =
- PyDenseElementsAttribute::bf_releasebuffer;
+ // Python 3.8 doesn't allow setting the buffer protocol slots from a type
+ // spec, so install the handlers directly on the type object instead.
+ PyTypeObject *tp = reinterpret_cast<PyTypeObject *>(c.ptr());
+ tp->tp_as_buffer->bf_getbuffer = PyDenseElementsAttribute::bf_getbuffer;
+ tp->tp_as_buffer->bf_releasebuffer =
+ PyDenseElementsAttribute::bf_releasebuffer;
#endif
- c.def("__len__", &PyDenseElementsAttribute::dunderLen)
- .def_static("get", PyDenseElementsAttribute::getFromBuffer,
- nb::arg("array"), nb::arg("signless") = true,
- nb::arg("type").none() = nb::none(),
- nb::arg("shape").none() = nb::none(),
- nb::arg("context").none() = nb::none(),
- kDenseElementsAttrGetDocstring)
- .def_static("get", PyDenseElementsAttribute::getFromList,
- nb::arg("attrs"), nb::arg("type").none() = nb::none(),
- nb::arg("context").none() = nb::none(),
- kDenseElementsAttrGetFromListDocstring)
- .def_static("get_splat", PyDenseElementsAttribute::getSplat,
- nb::arg("shaped_type"), nb::arg("element_attr"),
- "Gets a DenseElementsAttr where all values are the same")
- .def_prop_ro("is_splat",
- [](PyDenseElementsAttribute &self) -> bool {
- return mlirDenseElementsAttrIsSplat(self);
- })
- .def("get_splat_value", [](PyDenseElementsAttribute &self) {
- if (!mlirDenseElementsAttrIsSplat(self))
- throw nb::value_error(
- "get_splat_value called on a non-splat attribute");
- return mlirDenseElementsAttrGetSplatValue(self);
- });
- }
-
- static PyType_Slot slots[];
+ c.def("__len__", &PyDenseElementsAttribute::dunderLen)
+ .def_static("get", PyDenseElementsAttribute::getFromBuffer,
+ nb::arg("array"), nb::arg("signless") = true,
+ nb::arg("type").none() = nb::none(),
+ nb::arg("shape").none() = nb::none(),
+ nb::arg("context").none() = nb::none(),
+ kDenseElementsAttrGetDocstring)
+ .def_static("get", PyDenseElementsAttribute::getFromList,
+ nb::arg("attrs"), nb::arg("type").none() = nb::none(),
+ nb::arg("context").none() = nb::none(),
+ kDenseElementsAttrGetFromListDocstring)
+ .def_static("get_splat", PyDenseElementsAttribute::getSplat,
+ nb::arg("shaped_type"), nb::arg("element_attr"),
+ "Gets a DenseElementsAttr where all values are the same")
+ .def_prop_ro("is_splat",
+ [](PyDenseElementsAttribute &self) -> bool {
+ return mlirDenseElementsAttrIsSplat(self);
+ })
+ .def("get_splat_value", [](PyDenseElementsAttribute &self) {
+ if (!mlirDenseElementsAttrIsSplat(self))
+ throw nb::value_error(
+ "get_splat_value called on a non-splat attribute");
+ return mlirDenseElementsAttrGetSplatValue(self);
+ });
+}
-private:
- static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags);
- static void bf_releasebuffer(PyObject *, Py_buffer *buffer);
+// Returns true when the buffer-protocol format string starts with one of the
+// `struct` unsigned integer codes 'B', 'H', 'I', 'L' or 'Q'. Only the first
+// character is inspected, so a format carrying a byte-order/size prefix
+// (e.g. '<', '>', '=') is not recognized and yields false.
+bool PyDenseElementsAttribute::isUnsignedIntegerFormat(
+ std::string_view format) {
+ if (format.empty())
+ return false;
+ char code = format[0];
+ return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
+ code == 'Q';
+}
- static bool isUnsignedIntegerFormat(std::string_view format) {
- if (format.empty())
- return false;
- char code = format[0];
- return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
- code == 'Q';
- }
+// Returns true when the buffer-protocol format string starts with one of the
+// `struct` signed integer codes 'b', 'h', 'i', 'l' or 'q'. Only the first
+// character is inspected, so a format carrying a byte-order/size prefix
+// (e.g. '<', '>', '=') is not recognized and yields false.
+bool PyDenseElementsAttribute::isSignedIntegerFormat(std::string_view format) {
+ if (format.empty())
+ return false;
+ char code = format[0];
+ return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
+ code == 'q';
+}
- static bool isSignedIntegerFormat(std::string_view format) {
- if (format.empty())
- return false;
- char code = format[0];
- return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
- code == 'q';
+// Returns the type used to bulk-load a buffer.
+// - If `bulkLoadElementType` is already a shaped type, it is returned
+//   unchanged. Supplying `explicitShape` in that case throws
+//   std::invalid_argument.
+// - Otherwise, a ranked tensor type with a null encoding is built around
+//   that element type. Its shape is `explicitShape` when given, else the
+//   buffer's own view.shape[0..ndim).
+// `bulkLoadElementType` is dereferenced unconditionally, so callers must pass
+// an engaged optional.
+MlirType PyDenseElementsAttribute::getShapedType(
+ std::optional<MlirType> bulkLoadElementType,
+ std::optional<std::vector<int64_t>> explicitShape, Py_buffer &view) {
+ SmallVector<int64_t> shape;
+ if (explicitShape) {
+ shape.append(explicitShape->begin(), explicitShape->end());
+ } else {
+ shape.append(view.shape, view.shape + view.ndim);
}
- static MlirType
- getShapedType(std::optional<MlirType> bulkLoadElementType,
- std::optional<std::vector<int64_t>> explicitShape,
- Py_buffer &view) {
- SmallVector<int64_t> shape;
+ if (mlirTypeIsAShaped(*bulkLoadElementType)) {
if (explicitShape) {
- shape.append(explicitShape->begin(), explicitShape->end());
- } else {
- shape.append(view.shape, view.shape + view.ndim);
+ throw std::invalid_argument("Shape can only be specified explicitly "
+ "when the type is not a shaped type.");
}
-
- if (mlirTypeIsAShaped(*bulkLoadElementType)) {
- if (explicitShape) {
- throw std::invalid_argument("Shape can only be specified explicitly "
- "when the type is not a shaped type.");
- }
- return *bulkLoadElementType;
- }
- MlirAttribute encodingAttr = mlirAttributeGetNull();
- return mlirRankedTensorTypeGet(shape.size(), shape.data(),
- *bulkLoadElementType, encodingAttr);
+ return *bulkLoadElementType;
}
+ MlirAttribute encodingAttr = mlirAttributeGetNull();
+ return mlirRankedTensorTypeGet(shape.size(), shape.data(),
+ *bulkLoadElementType, encodingAttr);
+}
- static MlirAttribute getAttributeFromBuffer(
- Py_buffer &view, bool signless, std::optional<PyType> explicitType,
- const std::optional<std::vector<int64_t>> &explicitShape,
- MlirContext &context) {
- // Detect format codes that are suitable for bulk loading. This includes
- // all byte aligned integer and floating point types up to 8 bytes.
- // Notably, this excludes exotics types which do not have a direct
- // representation in the buffer protocol (i.e. complex, etc).
- std::optional<MlirType> bulkLoadElementType;
- if (explicitType) {
- bulkLoadElementType = *explicitType;
- } else {
- std::string_view format(view.format);
- if (format == "f") {
- // f32
- assert(view.itemsize == 4 && "mismatched array itemsize");
- bulkLoadElementType = mlirF32TypeGet(context);
- } else if (format == "d") {
- // f64
- assert(view.itemsize == 8 && "mismatched array itemsize");
- bulkLoadElementType = mlirF64TypeGet(context);
- } else if (format == "e") {
- // f16
- assert(view.itemsize == 2 && "mismatched array itemsize");
- bulkLoadElementType = mlirF16TypeGet(context);
- } else if (format == "?") {
- // i1
- // The i1 type needs to be bit-packed, so we will handle it seperately
- return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
- context);
- } else if (isSignedIntegerFormat(format)) {
- if (view.itemsize == 4) {
- // i32
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 32)
- : mlirIntegerTypeSignedGet(context, 32);
- } else if (view.itemsize == 8) {
- // i64
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 64)
- : mlirIntegerTypeSignedGet(context, 64);
- } else if (view.itemsize == 1) {
- // i8
- bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
- : mlirIntegerTypeSignedGet(context, 8);
- } else if (view.itemsize == 2) {
- // i16
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 16)
- : mlirIntegerTypeSignedGet(context, 16);
- }
- } else if (isUnsignedIntegerFormat(format)) {
- if (view.itemsize == 4) {
- // unsigned i32
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 32)
- : mlirIntegerTypeUnsignedGet(context, 32);
- } else if (view.itemsize == 8) {
- // unsigned i64
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 64)
- : mlirIntegerTypeUnsignedGet(context, 64);
- } else if (view.itemsize == 1) {
- // i8
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 8)
- : mlirIntegerTypeUnsignedGet(context, 8);
- } else if (view.itemsize == 2) {
- // i16
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 16)
- : mlirIntegerTypeUnsignedGet(context, 16);
- }
+// Builds a DenseElementsAttr from the raw bytes of `view`.
+// - With `explicitType`, that type is used as-is and the buffer format is
+//   not inspected.
+// - Otherwise the element type is inferred from view.format and
+//   view.itemsize; `signless` selects signless integer types over
+//   signed/unsigned ones.
+// - Boolean ('?') buffers are diverted to the bit-packing path.
+// Unrecognized formats throw std::invalid_argument.
+MlirAttribute PyDenseElementsAttribute::getAttributeFromBuffer(
+ Py_buffer &view, bool signless, std::optional<PyType> explicitType,
+ const std::optional<std::vector<int64_t>> &explicitShape,
+ MlirContext &context) {
+ // Detect format codes that are suitable for bulk loading. This includes
+ // all byte aligned integer and floating point types up to 8 bytes.
+ // Notably, this excludes exotic types which do not have a direct
+ // representation in the buffer protocol (e.g. complex, etc.).
+ std::optional<MlirType> bulkLoadElementType;
+ if (explicitType) {
+ bulkLoadElementType = *explicitType;
+ } else {
+ std::string_view format(view.format);
+ if (format == "f") {
+ // f32
+ assert(view.itemsize == 4 && "mismatched array itemsize");
+ bulkLoadElementType = mlirF32TypeGet(context);
+ } else if (format == "d") {
+ // f64
+ assert(view.itemsize == 8 && "mismatched array itemsize");
+ bulkLoadElementType = mlirF64TypeGet(context);
+ } else if (format == "e") {
+ // f16
+ assert(view.itemsize == 2 && "mismatched array itemsize");
+ bulkLoadElementType = mlirF16TypeGet(context);
+ } else if (format == "?") {
+ // i1
+ // The i1 type needs to be bit-packed, so we will handle it separately
+ return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
+ context);
+ } else if (isSignedIntegerFormat(format)) {
+ if (view.itemsize == 4) {
+ // i32
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
+ : mlirIntegerTypeSignedGet(context, 32);
+ } else if (view.itemsize == 8) {
+ // i64
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
+ : mlirIntegerTypeSignedGet(context, 64);
+ } else if (view.itemsize == 1) {
+ // i8
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
+ : mlirIntegerTypeSignedGet(context, 8);
+ } else if (view.itemsize == 2) {
+ // i16
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
+ : mlirIntegerTypeSignedGet(context, 16);
}
- if (!bulkLoadElementType) {
- throw std::invalid_argument(
- std::string("unimplemented array format conversion from format: ") +
- std::string(format));
+ } else if (isUnsignedIntegerFormat(format)) {
+ if (view.itemsize == 4) {
+ // unsigned i32
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 32)
+ : mlirIntegerTypeUnsignedGet(context, 32);
+ } else if (view.itemsize == 8) {
+ // unsigned i64
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 64)
+ : mlirIntegerTypeUnsignedGet(context, 64);
+ } else if (view.itemsize == 1) {
+ // unsigned i8
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
+ : mlirIntegerTypeUnsignedGet(context, 8);
+ } else if (view.itemsize == 2) {
+ // unsigned i16
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 16)
+ : mlirIntegerTypeUnsignedGet(context, 16);
}
}
-
- MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
- return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
- }
-
- // There is a complication for boolean numpy arrays, as numpy represents
- // them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8
- // booleans per byte.
- static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
- Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
- MlirContext &context) {
- if (llvm::endianness::native != llvm::endianness::little) {
- // Given we have no good way of testing the behavior on big-endian
- // systems we will throw
- throw nb::type_error("Constructing a bit-packed MLIR attribute is "
- "unsupported on big-endian systems");
+ // Reached when the format matched no known code, or matched an integer
+ // code with an itemsize other than 1, 2, 4 or 8.
+ if (!bulkLoadElementType) {
+ throw std::invalid_argument(
+ std::string("unimplemented array format conversion from format: ") +
+ std::string(format));
}
- nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> unpackedArray(
- /*data=*/static_cast<uint8_t *>(view.buf),
- /*shape=*/{static_cast<size_t>(view.len)});
-
- nb::module_ numpy = nb::module_::import_("numpy");
- nb::object packbitsFunc = numpy.attr("packbits");
- nb::object packedBooleans =
- packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little");
- nb_buffer_info pythonBuffer = nb::cast<nb_buffer>(packedBooleans).request();
-
- MlirType bitpackedType = getShapedType(mlirIntegerTypeGet(context, 1),
- std::move(explicitShape), view);
- assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8");
- // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of
- // packedBooleans, hence the MlirAttribute will remain valid even when
- // packedBooleans get reclaimed by the end of the function.
- return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
- pythonBuffer.ptr);
}
- // This does the opposite transformation of
- // `getBitpackedAttributeFromBooleanBuffer`
- std::unique_ptr<nb_buffer_info> getBooleanBufferFromBitpackedAttribute() {
- if (llvm::endianness::native != llvm::endianness::little) {
- // Given we have no good way of testing the behavior on big-endian
- // systems we will throw
- throw nb::type_error("Constructing a numpy array from a MLIR attribute "
- "is unsupported on big-endian systems");
- }
+ MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
+ return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
+}
- int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
- int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8);
- uint8_t *bitpackedData = static_cast<uint8_t *>(
- const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
- nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> packedArray(
- /*data=*/bitpackedData,
- /*shape=*/{static_cast<size_t>(numBitpackedBytes)});
-
- nb::module_ numpy = nb::module_::import_("numpy");
- nb::object unpackbitsFunc = numpy.attr("unpackbits");
- nb::object equalFunc = numpy.attr("equal");
- nb::object reshapeFunc = numpy.attr("reshape");
- nb::object unpackedBooleans =
- unpackbitsFunc(nb::cast(packedArray), "bitorder"_a = "little");
-
- // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array.
- // We need to:
- // 1. Slice away the padded bits
- // 2. Make the boolean array have the correct shape
- // 3. Convert the array to a boolean array
- unpackedBooleans = unpackedBooleans[nb::slice(
- nb::int_(0), nb::int_(numBooleans), nb::int_(1))];
- unpackedBooleans = equalFunc(unpackedBooleans, 1);
-
- MlirType shapedType = mlirAttributeGetType(*this);
- intptr_t rank = mlirShapedTypeGetRank(shapedType);
- std::vector<intptr_t> shape(rank);
- for (intptr_t i = 0; i < rank; ++i) {
- shape[i] = mlirShapedTypeGetDimSize(shapedType, i);
- }
- unpackedBooleans = reshapeFunc(unpackedBooleans, shape);
+// There is a complication for boolean numpy arrays, as numpy represents
+// them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8
+// booleans per byte.
+// This packs the bytes of `view` with numpy.packbits (little bit order) and
+// returns them as an i1 DenseElementsAttr whose type comes from
+// getShapedType(). Throws nb::type_error on big-endian hosts.
+// NOTE(review): all view.len bytes are treated as one boolean each, which
+// assumes a contiguous buffer with 1-byte items - confirm at call sites.
+MlirAttribute PyDenseElementsAttribute::getBitpackedAttributeFromBooleanBuffer(
+ Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
+ MlirContext &context) {
+ if (llvm::endianness::native != llvm::endianness::little) {
+ // Given we have no good way of testing the behavior on big-endian
+ // systems we will throw
+ throw nb::type_error("Constructing a bit-packed MLIR attribute is "
+ "unsupported on big-endian systems");
+ }
+ nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> unpackedArray(
+ /*data=*/static_cast<uint8_t *>(view.buf),
+ /*shape=*/{static_cast<size_t>(view.len)});
+
+ nb::module_ numpy = nb::module_::import_("numpy");
+ nb::object packbitsFunc = numpy.attr("packbits");
+ nb::object packedBooleans =
+ packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little");
+ nb_buffer_info pythonBuffer = nb::cast<nb_buffer>(packedBooleans).request();
+
+ MlirType bitpackedType = getShapedType(mlirIntegerTypeGet(context, 1),
+ std::move(explicitShape), view);
+ assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8");
+ // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of
+ // packedBooleans, hence the MlirAttribute will remain valid even when
+ // packedBooleans get reclaimed by the end of the function.
+ return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
+ pythonBuffer.ptr);
+}
- // Make sure the returned nb::buffer_view claims ownership of the data in
- // `pythonBuffer` so it remains valid when Python reads it
- nb_buffer pythonBuffer = nb::cast<nb_buffer>(unpackedBooleans);
- return std::make_unique<nb_buffer_info>(pythonBuffer.request());
+// This does the opposite transformation of
+// `getBitpackedAttributeFromBooleanBuffer`: it unpacks the attribute's raw
+// i1 storage with numpy.unpackbits (little bit order), drops the padding
+// bits, compares with 1 to obtain booleans, and reshapes the result to the
+// attribute's shape. Throws nb::type_error on big-endian hosts.
+// NOTE(review): bufferInfo() special-cases splats because only the single
+// value is stored, but this path always reads ceil(numElements / 8) bytes of
+// raw data. For a splat i1 attribute that may read past the stored data
+// unless callers exclude splats - confirm.
+std::unique_ptr<nb_buffer_info>
+PyDenseElementsAttribute::getBooleanBufferFromBitpackedAttribute() {
+ if (llvm::endianness::native != llvm::endianness::little) {
+ // Given we have no good way of testing the behavior on big-endian
+ // systems we will throw
+ throw nb::type_error("Constructing a numpy array from a MLIR attribute "
+ "is unsupported on big-endian systems");
}
- template <typename Type>
- std::unique_ptr<nb_buffer_info>
- bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) {
- intptr_t rank = mlirShapedTypeGetRank(shapedType);
- // Prepare the data for the buffer_info.
- // Buffer is configured for read-only access below.
- Type *data = static_cast<Type *>(
- const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
- // Prepare the shape for the buffer_info.
- SmallVector<intptr_t, 4> shape;
- for (intptr_t i = 0; i < rank; ++i)
- shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
- // Prepare the strides for the buffer_info.
- SmallVector<intptr_t, 4> strides;
- if (mlirDenseElementsAttrIsSplat(*this)) {
- // Splats are special, only the single value is stored.
- strides.assign(rank, 0);
- } else {
- for (intptr_t i = 1; i < rank; ++i) {
- intptr_t strideFactor = 1;
- for (intptr_t j = i; j < rank; ++j)
- strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
- strides.push_back(sizeof(Type) * strideFactor);
- }
- strides.push_back(sizeof(Type));
- }
- const char *format;
- if (explicitFormat) {
- format = explicitFormat;
- } else {
- format = nb_format_descriptor<Type>::format();
+ int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
+ int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8);
+ uint8_t *bitpackedData = static_cast<uint8_t *>(
+ const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
+ nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> packedArray(
+ /*data=*/bitpackedData,
+ /*shape=*/{static_cast<size_t>(numBitpackedBytes)});
+
+ nb::module_ numpy = nb::module_::import_("numpy");
+ nb::object unpackbitsFunc = numpy.attr("unpackbits");
+ nb::object equalFunc = numpy.attr("equal");
+ nb::object reshapeFunc = numpy.attr("reshape");
+ nb::object unpackedBooleans =
+ unpackbitsFunc(nb::cast(packedArray), "bitorder"_a = "little");
+
+ // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array.
+ // We need to:
+ // 1. Slice away the padded bits
+ // 2. Make the boolean array have the correct shape
+ // 3. Convert the array to a boolean array
+ unpackedBooleans = unpackedBooleans[nb::slice(
+ nb::int_(0), nb::int_(numBooleans), nb::int_(1))];
+ unpackedBooleans = equalFunc(unpackedBooleans, 1);
+
+ MlirType shapedType = mlirAttributeGetType(*this);
+ intptr_t rank = mlirShapedTypeGetRank(shapedType);
+ std::vector<intptr_t> shape(rank);
+ for (intptr_t i = 0; i < rank; ++i) {
+ shape[i] = mlirShapedTypeGetDimSize(shapedType, i);
+ }
+ unpackedBooleans = reshapeFunc(unpackedBooleans, shape);
+
+ // Make sure the returned nb::buffer_view claims ownership of the data in
+ // `pythonBuffer` so it remains valid when Python reads it
+ nb_buffer pythonBuffer = nb::cast<nb_buffer>(unpackedBooleans);
+ return std::make_unique<nb_buffer_info>(pythonBuffer.request());
+}
+
+// Describes the attribute's raw storage as a read-only nb_buffer_info of
+// `Type` elements shaped like `shapedType`.
+// - Non-splat data gets row-major (C-contiguous) strides.
+// - A splat stores only a single value, so every stride is 0.
+// When `explicitFormat` is non-null, it replaces the format string derived
+// from nb_format_descriptor<Type>.
+template <typename Type>
+std::unique_ptr<nb_buffer_info>
+PyDenseElementsAttribute::bufferInfo(MlirType shapedType,
+ const char *explicitFormat) {
+ intptr_t rank = mlirShapedTypeGetRank(shapedType);
+ // Prepare the data for the buffer_info.
+ // Buffer is configured for read-only access below.
+ Type *data = static_cast<Type *>(
+ const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
+ // Prepare the shape for the buffer_info.
+ SmallVector<intptr_t, 4> shape;
+ for (intptr_t i = 0; i < rank; ++i)
+ shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
+ // Prepare the strides for the buffer_info.
+ SmallVector<intptr_t, 4> strides;
+ if (mlirDenseElementsAttrIsSplat(*this)) {
+ // Splats are special, only the single value is stored.
+ strides.assign(rank, 0);
+ } else {
+ // Stride of dim i is sizeof(Type) times the product of all later dims.
+ for (intptr_t i = 1; i < rank; ++i) {
+ intptr_t strideFactor = 1;
+ for (intptr_t j = i; j < rank; ++j)
+ strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
+ strides.push_back(sizeof(Type) * strideFactor);
}
- return std::make_unique<nb_buffer_info>(
- data, sizeof(Type), format, rank, std::move(shape), std::move(strides),
- /*readonly=*/true);
+ strides.push_back(sizeof(Type));
}
-}; // namespace
+ const char *format;
+ if (explicitFormat) {
+ format = explicitFormat;
+ } else {
+ format = nb_format_descriptor<Type>::format();
+ }
+ return std::make_unique<nb_buffer_info>(data, sizeof(Type), format, rank,
+ std::move(shape), std::move(strides),
+ /*readonly=*/true);
+}
PyType_Slot PyDenseElementsAttribute::slots[] = {
// Python 3.8 doesn't allow setting the buffer protocol slots from a type spec.
@@ -1312,9 +1080,8 @@ PyType_Slot PyDenseElementsAttribute::slots[] = {
{0, nullptr},
};
-/*static*/ int PyDenseElementsAttribute::bf_getbuffer(PyObject *obj,
- Py_buffer *view,
- int flags) {
+int PyDenseElementsAttribute::bf_getbuffer(PyObject *obj, Py_buffer *view,
+ int flags) {
view->obj = nullptr;
std::unique_ptr<nb_buffer_info> info;
try {
@@ -1348,85 +1115,71 @@ PyType_Slot PyDenseElementsAttribute::slots[] = {
return 0;
}
-/*static*/ void PyDenseElementsAttribute::bf_releasebuffer(PyObject *,
- Py_buffer *view) {
+// Buffer-protocol release hook: deletes the nb_buffer_info stored in
+// view->internal (presumably placed there by bf_getbuffer, whose full body
+// is not shown in this hunk - confirm).
+void PyDenseElementsAttribute::bf_releasebuffer(PyObject *, Py_buffer *view) {
delete reinterpret_cast<nb_buffer_info *>(view->internal);
}
-/// Refinement of the PyDenseElementsAttribute for attributes containing
-/// integer (and boolean) values. Supports element access.
-class PyDenseIntElementsAttribute
- : public PyConcreteAttribute<PyDenseIntElementsAttribute,
- PyDenseElementsAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
- static constexpr const char *pyClassName = "DenseIntElementsAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- /// Returns the element at the given linear position. Asserts if the index
- /// is out of range.
- nb::object dunderGetItem(intptr_t pos) {
- if (pos < 0 || pos >= dunderLen()) {
- throw nb::index_error("attempt to access out of bounds element");
- }
+/// Returns the element at linear position `pos` as a Python int.
+/// - Index-typed elements are read as index values.
+/// - Integer elements are read with the accessor matching their width
+///   (1, 8, 16, 32 or 64) and signedness.
+/// Throws nb::index_error when `pos` is outside [0, len), and nb::type_error
+/// for any other integer width.
+nb::object PyDenseIntElementsAttribute::dunderGetItem(intptr_t pos) {
+ if (pos < 0 || pos >= dunderLen()) {
+ throw nb::index_error("attempt to access out of bounds element");
+ }
- MlirType type = mlirAttributeGetType(*this);
- type = mlirShapedTypeGetElementType(type);
- // Index type can also appear as a DenseIntElementsAttr and therefore can be
- // casted to integer.
- assert(mlirTypeIsAInteger(type) ||
- mlirTypeIsAIndex(type) && "expected integer/index element type in "
- "dense int elements attribute");
- // Dispatch element extraction to an appropriate C function based on the
- // elemental type of the attribute. nb::int_ is implicitly constructible
- // from any C++ integral type and handles bitwidth correctly.
- // TODO: consider caching the type properties in the constructor to avoid
- // querying them on each element access.
- if (mlirTypeIsAIndex(type)) {
- return nb::int_(mlirDenseElementsAttrGetIndexValue(*this, pos));
+ MlirType type = mlirAttributeGetType(*this);
+ type = mlirShapedTypeGetElementType(type);
+ // Index type can also appear as a DenseIntElementsAttr and therefore can be
+ // cast to integer.
+ // Parenthesize the disjunction so the message string is attached to the
+ // whole condition (`&&` binds tighter than `||`; avoids -Wparentheses).
+ assert((mlirTypeIsAInteger(type) || mlirTypeIsAIndex(type)) &&
+ "expected integer/index element type in "
+ "dense int elements attribute");
+ // Dispatch element extraction to an appropriate C function based on the
+ // elemental type of the attribute. nb::int_ is implicitly constructible
+ // from any C++ integral type and handles bitwidth correctly.
+ // TODO: consider caching the type properties in the constructor to avoid
+ // querying them on each element access.
+ if (mlirTypeIsAIndex(type)) {
+ return nb::int_(mlirDenseElementsAttrGetIndexValue(*this, pos));
+ }
+ unsigned width = mlirIntegerTypeGetWidth(type);
+ bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
+ if (isUnsigned) {
+ if (width == 1) {
+ return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
}
- unsigned width = mlirIntegerTypeGetWidth(type);
- bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
- if (isUnsigned) {
- if (width == 1) {
- return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
- }
- if (width == 8) {
- return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos));
- }
- if (width == 16) {
- return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos));
- }
- if (width == 32) {
- return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos));
- }
- if (width == 64) {
- return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos));
- }
- } else {
- if (width == 1) {
- return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
- }
- if (width == 8) {
- return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos));
- }
- if (width == 16) {
- return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos));
- }
- if (width == 32) {
- return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos));
- }
- if (width == 64) {
- return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos));
- }
+ if (width == 8) {
+ return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos));
+ }
+ if (width == 16) {
+ return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos));
+ }
+ if (width == 32) {
+ return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos));
+ }
+ if (width == 64) {
+ return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos));
+ }
+ } else {
+ if (width == 1) {
+ return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
+ }
+ if (width == 8) {
+ return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos));
+ }
+ if (width == 16) {
+ return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos));
+ }
+ if (width == 32) {
+ return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos));
+ }
+ if (width == 64) {
+ return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos));
}
- throw nb::type_error("Unsupported integer type");
}
+ throw nb::type_error("Unsupported integer type");
+}
- static void bindDerived(ClassTy &c) {
- c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
- }
-};
+// Adds `__getitem__` (element access by linear position via dunderGetItem)
+// to the DenseIntElementsAttr Python class.
+void PyDenseIntElementsAttribute::bindDerived(ClassTy &c) {
+ c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
+}
// Check if the python version is less than 3.13. Py_IsFinalizing is a part
// of stable ABI since 3.13 and before it was available as _Py_IsFinalizing.
@@ -1434,279 +1187,223 @@ class PyDenseIntElementsAttribute
#define Py_IsFinalizing _Py_IsFinalizing
#endif
-class PyDenseResourceElementsAttribute
- : public PyConcreteAttribute<PyDenseResourceElementsAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction =
- mlirAttributeIsADenseResourceElements;
- static constexpr const char *pyClassName = "DenseResourceElementsAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- static PyDenseResourceElementsAttribute
- getFromBuffer(const nb_buffer &buffer, const std::string &name,
- const PyType &type, std::optional<size_t> alignment,
- bool isMutable, DefaultingPyMlirContext contextWrapper) {
- if (!mlirTypeIsAShaped(type)) {
- throw std::invalid_argument(
- "Constructing a DenseResourceElementsAttr requires a ShapedType.");
- }
-
- // Do not request any conversions as we must ensure to use caller
- // managed memory.
- int flags = PyBUF_STRIDES;
- std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>();
- if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
- throw nb::python_error();
- }
+PyDenseResourceElementsAttribute
+PyDenseResourceElementsAttribute::getFromBuffer( // Wraps the caller's buffer memory (view->buf) without copying it.
+ const nb_buffer &buffer, const std::string &name, const PyType &type,
+ std::optional<size_t> alignment, bool isMutable,
+ DefaultingPyMlirContext contextWrapper) {
+ if (!mlirTypeIsAShaped(type)) {
+ throw std::invalid_argument(
+ "Constructing a DenseResourceElementsAttr requires a ShapedType.");
+ }
- // This scope releaser will only release if we haven't yet transferred
- // ownership.
- auto freeBuffer = llvm::make_scope_exit([&]() {
- if (view)
- PyBuffer_Release(view.get());
- });
+ // Do not request any conversions as we must ensure to use caller
+ // managed memory.
+ int flags = PyBUF_STRIDES;
+ std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>();
+ if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
+ throw nb::python_error();
+ }
- if (!PyBuffer_IsContiguous(view.get(), 'A')) {
- throw std::invalid_argument("Contiguous buffer is required.");
- }
+ // This scope releaser will only release if we haven't yet transferred
+ // ownership.
+ auto freeBuffer = llvm::make_scope_exit([&]() {
+ if (view)
+ PyBuffer_Release(view.get());
+ });
- // Infer alignment to be the stride of one element if not explicit.
- size_t inferredAlignment;
- if (alignment)
- inferredAlignment = *alignment;
- else
- inferredAlignment = view->strides[view->ndim - 1];
-
- // The userData is a Py_buffer* that the deleter owns.
- auto deleter = [](void *userData, const void *data, size_t size,
- size_t align) {
- if (Py_IsFinalizing())
- return;
- assert(Py_IsInitialized() && "expected interpreter to be initialized");
- Py_buffer *ownedView = static_cast<Py_buffer *>(userData);
- nb::gil_scoped_acquire gil;
- PyBuffer_Release(ownedView);
- delete ownedView;
- };
-
- size_t rawBufferSize = view->len;
- MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
- type, toMlirStringRef(name), view->buf, rawBufferSize,
- inferredAlignment, isMutable, deleter, static_cast<void *>(view.get()));
- if (mlirAttributeIsNull(attr)) {
- throw std::invalid_argument(
- "DenseResourceElementsAttr could not be constructed from the given "
- "buffer. "
- "This may mean that the Python buffer layout does not match that "
- "MLIR expected layout and is a bug.");
- }
- view.release();
- return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
+ if (!PyBuffer_IsContiguous(view.get(), 'A')) {
+ throw std::invalid_argument("Contiguous buffer is required.");
}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get_from_buffer", PyDenseResourceElementsAttribute::getFromBuffer,
- nb::arg("array"), nb::arg("name"), nb::arg("type"),
- nb::arg("alignment").none() = nb::none(), nb::arg("is_mutable") = false,
- nb::arg("context").none() = nb::none(),
- kDenseResourceElementsAttrGetFromBufferDocstring);
- }
-};
+ // Infer alignment from the innermost stride if not explicit; 0-d buffers have NULL strides, so fall back to the item size.
+ size_t inferredAlignment;
+ if (alignment)
+ inferredAlignment = *alignment;
+ else
+ inferredAlignment = view->ndim > 0 ? view->strides[view->ndim - 1] : view->itemsize;
+
+ // The userData is a Py_buffer* that the deleter owns.
+ auto deleter = [](void *userData, const void *data, size_t size,
+ size_t align) {
+ if (Py_IsFinalizing()) // Interpreter shutting down: skip release and leak ownedView.
+ return;
+ assert(Py_IsInitialized() && "expected interpreter to be initialized");
+ Py_buffer *ownedView = static_cast<Py_buffer *>(userData);
+ nb::gil_scoped_acquire gil;
+ PyBuffer_Release(ownedView);
+ delete ownedView;
+ };
-class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
- static constexpr const char *pyClassName = "DictAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirDictionaryAttrGetTypeID;
+ size_t rawBufferSize = view->len;
+ MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
+ type, toMlirStringRef(name), view->buf, rawBufferSize, inferredAlignment,
+ isMutable, deleter, static_cast<void *>(view.get()));
+ if (mlirAttributeIsNull(attr)) {
+ throw std::invalid_argument(
+ "DenseResourceElementsAttr could not be constructed from the given "
+ "buffer. "
+ "This may mean that the Python buffer layout does not match that "
+ "MLIR expected layout and is a bug.");
+ }
+ view.release(); // Ownership of the Py_buffer now belongs to the attribute's deleter.
+ return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
+}
- intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
+void PyDenseResourceElementsAttribute::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get_from_buffer", PyDenseResourceElementsAttribute::getFromBuffer,
+ nb::arg("array"), nb::arg("name"), nb::arg("type"),
+ nb::arg("alignment").none() = nb::none(), nb::arg("is_mutable") = false,
+ nb::arg("context").none() = nb::none(),
+ kDenseResourceElementsAttrGetFromBufferDocstring);
+}
- bool dunderContains(const std::string &name) {
- return !mlirAttributeIsNull(
- mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
- }
+intptr_t PyDictAttribute::dunderLen() {
+ return mlirDictionaryAttrGetNumElements(*this);
+}
- static void bindDerived(ClassTy &c) {
- c.def("__contains__", &PyDictAttribute::dunderContains);
- c.def("__len__", &PyDictAttribute::dunderLen);
- c.def_static(
- "get",
- [](const nb::dict &attributes, DefaultingPyMlirContext context) {
- SmallVector<MlirNamedAttribute> mlirNamedAttributes;
- mlirNamedAttributes.reserve(attributes.size());
- for (std::pair<nb::handle, nb::handle> it : attributes) {
- auto &mlirAttr = nb::cast<PyAttribute &>(it.second);
- auto name = nb::cast<std::string>(it.first);
- mlirNamedAttributes.push_back(mlirNamedAttributeGet(
- mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
- toMlirStringRef(name)),
- mlirAttr));
- }
- MlirAttribute attr =
- mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
- mlirNamedAttributes.data());
- return PyDictAttribute(context->getRef(), attr);
- },
- nb::arg("value") = nb::dict(), nb::arg("context").none() = nb::none(),
- "Gets an uniqued dict attribute");
- c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
- MlirAttribute attr =
- mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
- if (mlirAttributeIsNull(attr))
- throw nb::key_error("attempt to access a non-existent attribute");
- return attr;
- });
- c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
- if (index < 0 || index >= self.dunderLen()) {
- throw nb::index_error("attempt to access out of bounds attribute");
- }
- MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
- return PyNamedAttribute(
- namedAttr.attribute,
- std::string(mlirIdentifierStr(namedAttr.name).data));
- });
- }
-};
+bool PyDictAttribute::dunderContains(const std::string &name) {
+ return !mlirAttributeIsNull(
+ mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
+}
-/// Refinement of PyDenseElementsAttribute for attributes containing
-/// floating-point values. Supports element access.
-class PyDenseFPElementsAttribute
- : public PyConcreteAttribute<PyDenseFPElementsAttribute,
- PyDenseElementsAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
- static constexpr const char *pyClassName = "DenseFPElementsAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- nb::float_ dunderGetItem(intptr_t pos) {
- if (pos < 0 || pos >= dunderLen()) {
- throw nb::index_error("attempt to access out of bounds element");
+void PyDictAttribute::bindDerived(ClassTy &c) { // Exposes __contains__, __len__, get() and __getitem__ by name or by index.
+ c.def("__contains__", &PyDictAttribute::dunderContains);
+ c.def("__len__", &PyDictAttribute::dunderLen);
+ c.def_static(
+ "get",
+ [](const nb::dict &attributes, DefaultingPyMlirContext context) {
+ SmallVector<MlirNamedAttribute> mlirNamedAttributes;
+ mlirNamedAttributes.reserve(attributes.size());
+ for (std::pair<nb::handle, nb::handle> it : attributes) {
+ auto &mlirAttr = nb::cast<PyAttribute &>(it.second);
+ auto name = nb::cast<std::string>(it.first);
+ mlirNamedAttributes.push_back(mlirNamedAttributeGet(
+ mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
+ toMlirStringRef(name)),
+ mlirAttr));
+ }
+ MlirAttribute attr =
+ mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
+ mlirNamedAttributes.data());
+ return PyDictAttribute(context->getRef(), attr);
+ },
+ nb::arg("value") = nb::dict(), nb::arg("context").none() = nb::none(),
+ "Gets an uniqued dict attribute");
+ c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
+ MlirAttribute attr =
+ mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
+ if (mlirAttributeIsNull(attr))
+ throw nb::key_error("attempt to access a non-existent attribute");
+ return attr;
+ });
+ c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
+ if (index < 0 || index >= self.dunderLen()) {
+ throw nb::index_error("attempt to access out of bounds attribute");
}
+ MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
+ MlirStringRef name = mlirIdentifierStr(namedAttr.name); // Not necessarily NUL-terminated: copy by explicit length.
+ return PyNamedAttribute(namedAttr.attribute,
+ std::string(name.data, name.length));
+ });
+}
- MlirType type = mlirAttributeGetType(*this);
- type = mlirShapedTypeGetElementType(type);
- // Dispatch element extraction to an appropriate C function based on the
- // elemental type of the attribute. nb::float_ is implicitly constructible
- // from float and double.
- // TODO: consider caching the type properties in the constructor to avoid
- // querying them on each element access.
- if (mlirTypeIsAF32(type)) {
- return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos));
- }
- if (mlirTypeIsAF64(type)) {
- return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos));
- }
- throw nb::type_error("Unsupported floating-point type");
+nb::float_ PyDenseFPElementsAttribute::dunderGetItem(intptr_t pos) { // Bounds-checked element read; only f32 and f64 element types are supported.
+ if (pos < 0 || pos >= dunderLen()) {
+ throw nb::index_error("attempt to access out of bounds element");
}
- static void bindDerived(ClassTy &c) {
- c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
+ MlirType type = mlirAttributeGetType(*this);
+ type = mlirShapedTypeGetElementType(type);
+ // Dispatch element extraction to an appropriate C function based on the
+ // elemental type of the attribute. nb::float_ is implicitly constructible
+ // from float and double.
+ // TODO: consider caching the type properties in the constructor to avoid
+ // querying them on each element access.
+ if (mlirTypeIsAF32(type)) {
+ return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos));
}
-};
-
-class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
- static constexpr const char *pyClassName = "TypeAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirTypeAttrGetTypeID;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](const PyType &value, DefaultingPyMlirContext context) {
- MlirAttribute attr = mlirTypeAttrGet(value.get());
- return PyTypeAttribute(context->getRef(), attr);
- },
- nb::arg("value"), nb::arg("context").none() = nb::none(),
- "Gets a uniqued Type attribute");
- c.def_prop_ro("value", [](PyTypeAttribute &self) {
- return mlirTypeAttrGetValue(self.get());
- });
+ if (mlirTypeIsAF64(type)) {
+ return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos));
}
-};
+ throw nb::type_error("Unsupported floating-point type"); // e.g. f16/bf16 elements end up here.
+}
-/// Unit Attribute subclass. Unit attributes don't have values.
-class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
- static constexpr const char *pyClassName = "UnitAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirUnitAttrGetTypeID;
+void PyDenseFPElementsAttribute::bindDerived(ClassTy &c) { // Exposes element access to Python as __getitem__.
+ c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- return PyUnitAttribute(context->getRef(),
- mlirUnitAttrGet(context->get()));
- },
- nb::arg("context").none() = nb::none(), "Create a Unit attribute.");
- }
-};
+void PyTypeAttribute::bindDerived(ClassTy &c) { // Exposes TypeAttr.get(value) and the read-only 'value' property.
+ c.def_static(
+ "get",
+ [](const PyType &value, DefaultingPyMlirContext context) {
+ MlirAttribute attr = mlirTypeAttrGet(value.get()); // NOTE(review): attr is created from value's context; wrapping it with 'context' assumes both match.
+ return PyTypeAttribute(context->getRef(), attr);
+ },
+ nb::arg("value"), nb::arg("context").none() = nb::none(),
+ "Gets a uniqued Type attribute");
+ c.def_prop_ro("value", [](PyTypeAttribute &self) {
+ return mlirTypeAttrGetValue(self.get());
+ });
+}
-/// Strided layout attribute subclass.
-class PyStridedLayoutAttribute
- : public PyConcreteAttribute<PyStridedLayoutAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
- static constexpr const char *pyClassName = "StridedLayoutAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirStridedLayoutAttrGetTypeID;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](int64_t offset, const std::vector<int64_t> &strides,
- DefaultingPyMlirContext ctx) {
- MlirAttribute attr = mlirStridedLayoutAttrGet(
- ctx->get(), offset, strides.size(), strides.data());
- return PyStridedLayoutAttribute(ctx->getRef(), attr);
- },
- nb::arg("offset"), nb::arg("strides"),
- nb::arg("context").none() = nb::none(),
- "Gets a strided layout attribute.");
- c.def_static(
- "get_fully_dynamic",
- [](int64_t rank, DefaultingPyMlirContext ctx) {
- auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
- std::vector<int64_t> strides(rank);
- llvm::fill(strides, dynamic);
- MlirAttribute attr = mlirStridedLayoutAttrGet(
- ctx->get(), dynamic, strides.size(), strides.data());
- return PyStridedLayoutAttribute(ctx->getRef(), attr);
- },
- nb::arg("rank"), nb::arg("context").none() = nb::none(),
- "Gets a strided layout attribute with dynamic offset and strides of "
- "a "
- "given rank.");
- c.def_prop_ro(
- "offset",
- [](PyStridedLayoutAttribute &self) {
- return mlirStridedLayoutAttrGetOffset(self);
- },
- "Returns the value of the float point attribute");
- c.def_prop_ro(
- "strides",
- [](PyStridedLayoutAttribute &self) {
- intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
- std::vector<int64_t> strides(size);
- for (intptr_t i = 0; i < size; i++) {
- strides[i] = mlirStridedLayoutAttrGetStride(self, i);
- }
- return strides;
- },
- "Returns the value of the float point attribute");
- }
-};
+void PyUnitAttribute::bindDerived(ClassTy &c) { // Exposes UnitAttr.get(); unit attributes carry no value.
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ return PyUnitAttribute(context->getRef(),
+ mlirUnitAttrGet(context->get()));
+ },
+ nb::arg("context").none() = nb::none(), "Create a Unit attribute.");
+}
-nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
+void PyStridedLayoutAttribute::bindDerived(ClassTy &c) { // Exposes get/get_fully_dynamic and the offset/strides properties.
+ c.def_static(
+ "get",
+ [](int64_t offset, const std::vector<int64_t> &strides,
+ DefaultingPyMlirContext ctx) {
+ MlirAttribute attr = mlirStridedLayoutAttrGet(
+ ctx->get(), offset, strides.size(), strides.data());
+ return PyStridedLayoutAttribute(ctx->getRef(), attr);
+ },
+ nb::arg("offset"), nb::arg("strides"),
+ nb::arg("context").none() = nb::none(),
+ "Gets a strided layout attribute.");
+ c.def_static(
+ "get_fully_dynamic",
+ [](int64_t rank, DefaultingPyMlirContext ctx) {
+ auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
+ std::vector<int64_t> strides(rank);
+ llvm::fill(strides, dynamic);
+ MlirAttribute attr = mlirStridedLayoutAttrGet(
+ ctx->get(), dynamic, strides.size(), strides.data());
+ return PyStridedLayoutAttribute(ctx->getRef(), attr);
+ },
+ nb::arg("rank"), nb::arg("context").none() = nb::none(),
+ "Gets a strided layout attribute with dynamic offset and strides of "
+ "a "
+ "given rank.");
+ c.def_prop_ro(
+ "offset",
+ [](PyStridedLayoutAttribute &self) {
+ return mlirStridedLayoutAttrGetOffset(self);
+ },
+ "Returns the offset of the strided layout attribute");
+ c.def_prop_ro(
+ "strides",
+ [](PyStridedLayoutAttribute &self) {
+ intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
+ std::vector<int64_t> strides(size);
+ for (intptr_t i = 0; i < size; i++) {
+ strides[i] = mlirStridedLayoutAttrGetStride(self, i);
+ }
+ return strides;
+ },
+ "Returns the strides of the strided layout attribute");
+}
+
+static nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute))
return nb::cast(PyDenseBoolArrayAttribute(pyAttribute));
if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute))
@@ -1727,7 +1424,8 @@ nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
throw nb::type_error(msg.c_str());
}
-nb::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
+static nb::object
+denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
if (PyDenseFPElementsAttribute::isaFunction(pyAttribute))
return nb::cast(PyDenseFPElementsAttribute(pyAttribute));
if (PyDenseIntElementsAttribute::isaFunction(pyAttribute))
@@ -1739,7 +1437,7 @@ nb::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
throw nb::type_error(msg.c_str());
}
-nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
+static nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
if (PyBoolAttribute::isaFunction(pyAttribute))
return nb::cast(PyBoolAttribute(pyAttribute));
if (PyIntegerAttribute::isaFunction(pyAttribute))
@@ -1750,7 +1448,8 @@ nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
throw nb::type_error(msg.c_str());
}
-nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
+static nb::object
+symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute))
return nb::cast(PyFlatSymbolRefAttribute(pyAttribute));
if (PySymbolRefAttribute::isaFunction(pyAttribute))
@@ -1761,9 +1460,7 @@ nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
throw nb::type_error(msg.c_str());
}
-} // namespace
-
-void mlir::python::populateIRAttributes(nb::module_ &m) {
+void populateIRAttributes(nb::module_ &m) {
PyAffineMapAttribute::bind(m);
PyDenseBoolArrayAttribute::bind(m);
PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
@@ -1816,3 +1513,4 @@ void mlir::python::populateIRAttributes(nb::module_ &m) {
PyStridedLayoutAttribute::bind(m);
}
+} // namespace mlir::python
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 15889ddabd2c4..cbc0ea5039935 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -6,21 +6,24 @@
//
//===----------------------------------------------------------------------===//
-#include "Globals.h"
-#include "IRModule.h"
-#include "NanobindUtils.h"
-#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/Debug.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
-#include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir/Bindings/Python/Globals.h"
+#include "mlir/Bindings/Python/IRModule.h"
#include "nanobind/nanobind.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
+// clang-format off
+#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
+#include "mlir-c/Bindings/Python/Interop.h" // ON WINDOWS This is expected after nanobind.
+// clang-format on
+
#include <optional>
namespace nb = nanobind;
@@ -1656,81 +1659,47 @@ void PyOperation::erase() {
mlirOperationDestroy(operation);
}
-namespace {
-/// CRTP base class for Python MLIR values that subclass Value and should be
-/// castable from it. The value hierarchy is one level deep and is not supposed
-/// to accommodate other levels unless core MLIR changes.
template <typename DerivedTy>
-class PyConcreteValue : public PyValue {
-public:
- // Derived classes must define statics for:
- // IsAFunctionTy isaFunction
- // const char *pyClassName
- // and redefine bindDerived.
- using ClassTy = nb::class_<DerivedTy, PyValue>;
- using IsAFunctionTy = bool (*)(MlirValue);
-
- PyConcreteValue() = default;
- PyConcreteValue(PyOperationRef operationRef, MlirValue value)
- : PyValue(operationRef, value) {}
- PyConcreteValue(PyValue &orig)
- : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
-
- /// Attempts to cast the original value to the derived type and throws on
- /// type mismatches.
- static MlirValue castFrom(PyValue &orig) {
- if (!DerivedTy::isaFunction(orig.get())) {
- auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
- throw nb::value_error((Twine("Cannot cast value to ") +
- DerivedTy::pyClassName + " (from " + origRepr +
- ")")
- .str()
- .c_str());
- }
- return orig.get();
- }
-
- /// Binds the Python module objects to functions of this class.
- static void bind(nb::module_ &m) {
- auto cls = ClassTy(m, DerivedTy::pyClassName);
- cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
- cls.def_static(
- "isinstance",
- [](PyValue &otherValue) -> bool {
- return DerivedTy::isaFunction(otherValue);
- },
- nb::arg("other_value"));
- cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](DerivedTy &self) { return self.maybeDownCast(); });
- DerivedTy::bindDerived(cls);
+MlirValue PyConcreteValue<DerivedTy>::castFrom(PyValue &orig) { // Returns orig's MlirValue; throws nb::value_error if DerivedTy::isaFunction rejects it.
+ if (!DerivedTy::isaFunction(orig.get())) {
+ auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
+ throw nb::value_error((Twine("Cannot cast value to ") +
+ DerivedTy::pyClassName + " (from " + origRepr + ")")
+ .str()
+ .c_str());
}
+ return orig.get();
+}
- /// Implemented by derived classes to add methods to the Python subclass.
- static void bindDerived(ClassTy &m) {}
-};
-
-} // namespace
-
-/// Python wrapper for MlirOpResult.
-class PyOpResult : public PyConcreteValue<PyOpResult> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
- static constexpr const char *pyClassName = "OpResult";
- using PyConcreteValue::PyConcreteValue;
+template <typename DerivedTy>
+void PyConcreteValue<DerivedTy>::bind(nb::module_ &m) { // NOTE(review): template members are defined only in this .cpp; if PyConcreteValue is declared in a public header (presumably IRModule.h), other TUs instantiating it will not link. Consider moving these definitions to the header.
+ auto cls = ClassTy(m, DerivedTy::pyClassName);
+ cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value")); // keep_alive<0, 1>: the new object keeps the source value alive.
+ cls.def_static(
+ "isinstance",
+ [](PyValue &otherValue) -> bool {
+ return DerivedTy::isaFunction(otherValue);
+ },
+ nb::arg("other_value"));
+ cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](DerivedTy &self) { return self.maybeDownCast(); });
+ DerivedTy::bindDerived(cls);
+}
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro("owner", [](PyOpResult &self) {
- assert(
- mlirOperationEqual(self.getParentOperation()->get(),
- mlirOpResultGetOwner(self.get())) &&
- "expected the owner of the value in Python to match that in the IR");
- return self.getParentOperation().getObject();
- });
- c.def_prop_ro("result_number", [](PyOpResult &self) {
- return mlirOpResultGetResultNumber(self.get());
- });
- }
-};
+template <typename DerivedTy>
+void PyConcreteValue<DerivedTy>::bindDerived(ClassTy &m) {} // Default: no extra bindings; derived classes provide their own bindDerived.
+
+void PyOpResult::bindDerived(ClassTy &c) { // Adds the 'owner' and 'result_number' read-only properties.
+ c.def_prop_ro("owner", [](PyOpResult &self) {
+ assert(mlirOperationEqual(self.getParentOperation()->get(),
+ mlirOpResultGetOwner(self.get())) &&
+ "expected the owner of the value in Python to match that in the IR");
+ return self.getParentOperation().getObject();
+ });
+ c.def_prop_ro("result_number", [](PyOpResult &self) {
+ return mlirOpResultGetResultNumber(self.get());
+ });
+}
/// Returns the list of types of the values held by container.
template <typename Container>
@@ -2460,32 +2429,23 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from,
}
}
-namespace {
-
-/// Python wrapper for MlirBlockArgument.
-class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
- static constexpr const char *pyClassName = "BlockArgument";
- using PyConcreteValue::PyConcreteValue;
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro("owner", [](PyBlockArgument &self) {
- return PyBlock(self.getParentOperation(),
- mlirBlockArgumentGetOwner(self.get()));
- });
- c.def_prop_ro("arg_number", [](PyBlockArgument &self) {
- return mlirBlockArgumentGetArgNumber(self.get());
- });
- c.def(
- "set_type",
- [](PyBlockArgument &self, PyType type) {
- return mlirBlockArgumentSetType(self.get(), type);
- },
- nb::arg("type"));
- }
-};
+void PyBlockArgument::bindDerived(ClassTy &c) { // Adds 'owner' (the owning PyBlock), 'arg_number' and set_type(type).
+ c.def_prop_ro("owner", [](PyBlockArgument &self) {
+ return PyBlock(self.getParentOperation(),
+ mlirBlockArgumentGetOwner(self.get()));
+ });
+ c.def_prop_ro("arg_number", [](PyBlockArgument &self) {
+ return mlirBlockArgumentGetArgNumber(self.get());
+ });
+ c.def(
+ "set_type",
+ [](PyBlockArgument &self, PyType type) {
+ return mlirBlockArgumentSetType(self.get(), type); // Mutates the argument's type in place in the IR.
+ },
+ nb::arg("type"));
+}
+namespace {
/// A list of block arguments. Internally, these are stored as consecutive
/// elements, random access is cheap. The argument list is associated with the
/// operation that contains the block (detached blocks are not allowed in
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index 9e1fedaab5235..2b11513885e32 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -12,11 +12,11 @@
#include <utility>
#include <vector>
-#include "IRModule.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
#include "mlir-c/Interfaces.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/IRModule.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index 0de2f1711829b..c77e37da3ffd4 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -6,16 +6,18 @@
//
//===----------------------------------------------------------------------===//
-#include "IRModule.h"
+#include "mlir/Bindings/Python/IRModule.h"
#include <optional>
#include <vector>
-#include "Globals.h"
-#include "NanobindUtils.h"
-#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/Globals.h"
+// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
+#include "mlir-c/Bindings/Python/Interop.h" // ON WINDOWS This is expected after nanobind.
+// clang-format on
namespace nb = nanobind;
using namespace mlir;
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index a9b12590188f8..df34bfd6f8ab5 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -7,513 +7,290 @@
//===----------------------------------------------------------------------===//
// clang-format off
-#include "IRModule.h"
+#include "mlir/Bindings/Python/IRModule.h"
#include "mlir/Bindings/Python/IRTypes.h"
// clang-format on
#include <optional>
-#include "IRModule.h"
-#include "NanobindUtils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
namespace nb = nanobind;
-using namespace mlir;
-using namespace mlir::python;
using llvm::SmallVector;
using llvm::Twine;
namespace {
-
/// Checks whether the given type is an integer or float type.
-static int mlirTypeIsAIntegerOrFloat(MlirType type) {
+/// Only the BF16, F16, F32 and F64 float types are accepted; any other float
+/// type (e.g. TF32 or the 8-bit formats) makes this return false.
+int mlirTypeIsAIntegerOrFloat(MlirType type) {
return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
}
+} // namespace
-class PyIntegerType : public PyConcreteType<PyIntegerType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirIntegerTypeGetTypeID;
- static constexpr const char *pyClassName = "IntegerType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get_signless",
- [](unsigned width, DefaultingPyMlirContext context) {
- MlirType t = mlirIntegerTypeGet(context->get(), width);
- return PyIntegerType(context->getRef(), t);
- },
- nb::arg("width"), nb::arg("context").none() = nb::none(),
- "Create a signless integer type");
- c.def_static(
- "get_signed",
- [](unsigned width, DefaultingPyMlirContext context) {
- MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
- return PyIntegerType(context->getRef(), t);
- },
- nb::arg("width"), nb::arg("context").none() = nb::none(),
- "Create a signed integer type");
- c.def_static(
- "get_unsigned",
- [](unsigned width, DefaultingPyMlirContext context) {
- MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
- return PyIntegerType(context->getRef(), t);
- },
- nb::arg("width"), nb::arg("context").none() = nb::none(),
- "Create an unsigned integer type");
- c.def_prop_ro(
- "width",
- [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
- "Returns the width of the integer type");
- c.def_prop_ro(
- "is_signless",
- [](PyIntegerType &self) -> bool {
- return mlirIntegerTypeIsSignless(self);
- },
- "Returns whether this is a signless integer");
- c.def_prop_ro(
- "is_signed",
- [](PyIntegerType &self) -> bool {
- return mlirIntegerTypeIsSigned(self);
- },
- "Returns whether this is a signed integer");
- c.def_prop_ro(
- "is_unsigned",
- [](PyIntegerType &self) -> bool {
- return mlirIntegerTypeIsUnsigned(self);
- },
- "Returns whether this is an unsigned integer");
- }
-};
-
-/// Index Type subclass - IndexType.
-class PyIndexType : public PyConcreteType<PyIndexType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirIndexTypeGetTypeID;
- static constexpr const char *pyClassName = "IndexType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirIndexTypeGet(context->get());
- return PyIndexType(context->getRef(), t);
- },
- nb::arg("context").none() = nb::none(), "Create a index type.");
- }
-};
-
-class PyFloatType : public PyConcreteType<PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
- static constexpr const char *pyClassName = "FloatType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
- "Returns the width of the floating-point type");
- }
-};
-
-/// Floating Point Type subclass - Float4E2M1FNType.
-class PyFloat4E2M1FNType
- : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat4E2M1FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float4E2M1FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat4E2M1FNTypeGet(context->get());
- return PyFloat4E2M1FNType(context->getRef(), t);
- },
- nb::arg("context").none() = nb::none(), "Create a float4_e2m1fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float6E2M3FNType.
-class PyFloat6E2M3FNType
- : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat6E2M3FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float6E2M3FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
- return PyFloat6E2M3FNType(context->getRef(), t);
- },
- nb::arg("context").none() = nb::none(), "Create a float6_e2m3fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float6E3M2FNType.
-class PyFloat6E3M2FNType
- : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat6E3M2FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float6E3M2FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
- return PyFloat6E3M2FNType(context->getRef(), t);
- },
- nb::arg("context").none() = nb::none(), "Create a float6_e3m2fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3FNType.
-class PyFloat8E4M3FNType
- : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
- return PyFloat8E4M3FNType(context->getRef(), t);
- },
- nb::arg("context").none() = nb::none(), "Create a float8_e4m3fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E5M2Type.
-class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E5M2TypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E5M2Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E5M2TypeGet(context->get());
- return PyFloat8E5M2Type(context->getRef(), t);
- },
- nb::arg("context").none() = nb::none(), "Create a float8_e5m2 type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3Type.
-class PyFloat8E4M3Type : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3TypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3TypeGet(context->get());
- return PyFloat8E4M3Type(context->getRef(), t);
- },
- nb::arg("context").none() = nb::none(), "Create a float8_e4m3 type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3FNUZ.
-class PyFloat8E4M3FNUZType
- : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3FNUZTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3FNUZType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
- return PyFloat8E4M3FNUZType(context->getRef(), t);
- },
- nb::arg("context").none() = nb::none(),
- "Create a float8_e4m3fnuz type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3B11FNUZ.
-class PyFloat8E4M3B11FNUZType
- : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3B11FNUZTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
- return PyFloat8E4M3B11FNUZType(context->getRef(), t);
- },
- nb::arg("context").none() = nb::none(),
- "Create a float8_e4m3b11fnuz type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E5M2FNUZ.
-class PyFloat8E5M2FNUZType
- : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E5M2FNUZTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E5M2FNUZType";
- using PyConcreteType::PyConcreteType;
+namespace mlir::python {
+/// Binds the get_signless / get_signed / get_unsigned static constructors
+/// (width, context=None) and the read-only width, is_signless, is_signed and
+/// is_unsigned properties.
+void PyIntegerType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get_signless",
+ [](unsigned width, DefaultingPyMlirContext context) {
+ MlirType t = mlirIntegerTypeGet(context->get(), width);
+ return PyIntegerType(context->getRef(), t);
+ },
+ nb::arg("width"), nb::arg("context").none() = nb::none(),
+ "Create a signless integer type");
+ c.def_static(
+ "get_signed",
+ [](unsigned width, DefaultingPyMlirContext context) {
+ MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
+ return PyIntegerType(context->getRef(), t);
+ },
+ nb::arg("width"), nb::arg("context").none() = nb::none(),
+ "Create a signed integer type");
+ c.def_static(
+ "get_unsigned",
+ [](unsigned width, DefaultingPyMlirContext context) {
+ MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
+ return PyIntegerType(context->getRef(), t);
+ },
+ nb::arg("width"), nb::arg("context").none() = nb::none(),
+ "Create an unsigned integer type");
+ c.def_prop_ro(
+ "width",
+ [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
+ "Returns the width of the integer type");
+ c.def_prop_ro(
+ "is_signless",
+ [](PyIntegerType &self) -> bool {
+ return mlirIntegerTypeIsSignless(self);
+ },
+ "Returns whether this is a signless integer");
+ c.def_prop_ro(
+ "is_signed",
+ [](PyIntegerType &self) -> bool { return mlirIntegerTypeIsSigned(self); },
+ "Returns whether this is a signed integer");
+ c.def_prop_ro(
+ "is_unsigned",
+ [](PyIntegerType &self) -> bool {
+ return mlirIntegerTypeIsUnsigned(self);
+ },
+ "Returns whether this is an unsigned integer");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
- return PyFloat8E5M2FNUZType(context->getRef(), t);
- },
- nb::arg("context").none() = nb::none(),
- "Create a float8_e5m2fnuz type.");
- }
-};
+/// Binds the static get(context=None) constructor via mlirIndexTypeGet.
+void PyIndexType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirIndexTypeGet(context->get());
+ return PyIndexType(context->getRef(), t);
+ },
+ nb::arg("context").none() = nb::none(), "Create a index type.");
+}
-/// Floating Point Type subclass - Float8E3M4Type.
-class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E3M4TypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E3M4Type";
- using PyConcreteType::PyConcreteType;
+/// Binds the read-only `width` property (via mlirFloatTypeGetWidth).
+void PyFloatType::bindDerived(ClassTy &c) {
+ c.def_prop_ro(
+ "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
+ "Returns the width of the floating-point type");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E3M4TypeGet(context->get());
- return PyFloat8E3M4Type(context->getRef(), t);
- },
- nb::arg("context").none() = nb::none(), "Create a float8_e3m4 type.");
- }
-};
+/// Binds the static get(context=None) constructor via mlirFloat4E2M1FNTypeGet.
+void PyFloat4E2M1FNType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat4E2M1FNTypeGet(context->get());
+ return PyFloat4E2M1FNType(context->getRef(), t);
+ },
+ nb::arg("context").none() = nb::none(), "Create a float4_e2m1fn type.");
+}
-/// Floating Point Type subclass - Float8E8M0FNUType.
-class PyFloat8E8M0FNUType
- : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E8M0FNUTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E8M0FNUType";
- using PyConcreteType::PyConcreteType;
+/// Binds the static get(context=None) constructor via mlirFloat6E2M3FNTypeGet.
+void PyFloat6E2M3FNType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
+ return PyFloat6E2M3FNType(context->getRef(), t);
+ },
+ nb::arg("context").none() = nb::none(), "Create a float6_e2m3fn type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
- return PyFloat8E8M0FNUType(context->getRef(), t);
- },
- nb::arg("context").none() = nb::none(),
- "Create a float8_e8m0fnu type.");
- }
-};
+/// Binds the static get(context=None) constructor via mlirFloat6E3M2FNTypeGet.
+void PyFloat6E3M2FNType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
+ return PyFloat6E3M2FNType(context->getRef(), t);
+ },
+ nb::arg("context").none() = nb::none(), "Create a float6_e3m2fn type.");
+}
-/// Floating Point Type subclass - BF16Type.
-class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirBFloat16TypeGetTypeID;
- static constexpr const char *pyClassName = "BF16Type";
- using PyConcreteType::PyConcreteType;
+/// Binds the static get(context=None) constructor via mlirFloat8E4M3FNTypeGet.
+void PyFloat8E4M3FNType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
+ return PyFloat8E4M3FNType(context->getRef(), t);
+ },
+ nb::arg("context").none() = nb::none(), "Create a float8_e4m3fn type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirBF16TypeGet(context->get());
- return PyBF16Type(context->getRef(), t);
- },
- nb::arg("context").none() = nb::none(), "Create a bf16 type.");
- }
-};
+/// Binds the static get(context=None) constructor via mlirFloat8E5M2TypeGet.
+void PyFloat8E5M2Type::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E5M2TypeGet(context->get());
+ return PyFloat8E5M2Type(context->getRef(), t);
+ },
+ nb::arg("context").none() = nb::none(), "Create a float8_e5m2 type.");
+}
-/// Floating Point Type subclass - F16Type.
-class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat16TypeGetTypeID;
- static constexpr const char *pyClassName = "F16Type";
- using PyConcreteType::PyConcreteType;
+/// Binds the static get(context=None) constructor via mlirFloat8E4M3TypeGet.
+void PyFloat8E4M3Type::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E4M3TypeGet(context->get());
+ return PyFloat8E4M3Type(context->getRef(), t);
+ },
+ nb::arg("context").none() = nb::none(), "Create a float8_e4m3 type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirF16TypeGet(context->get());
- return PyF16Type(context->getRef(), t);
- },
- nb::arg("context").none() = nb::none(), "Create a f16 type.");
- }
-};
+/// Binds the static get(context=None) constructor via
+/// mlirFloat8E4M3FNUZTypeGet.
+void PyFloat8E4M3FNUZType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
+ return PyFloat8E4M3FNUZType(context->getRef(), t);
+ },
+ nb::arg("context").none() = nb::none(), "Create a float8_e4m3fnuz type.");
+}
-/// Floating Point Type subclass - TF32Type.
-class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloatTF32TypeGetTypeID;
- static constexpr const char *pyClassName = "FloatTF32Type";
- using PyConcreteType::PyConcreteType;
+/// Binds the static get(context=None) constructor via
+/// mlirFloat8E4M3B11FNUZTypeGet.
+void PyFloat8E4M3B11FNUZType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
+ return PyFloat8E4M3B11FNUZType(context->getRef(), t);
+ },
+ nb::arg("context").none() = nb::none(),
+ "Create a float8_e4m3b11fnuz type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirTF32TypeGet(context->get());
- return PyTF32Type(context->getRef(), t);
- },
- nb::arg("context").none() = nb::none(), "Create a tf32 type.");
- }
-};
+/// Binds the static get(context=None) constructor via
+/// mlirFloat8E5M2FNUZTypeGet.
+void PyFloat8E5M2FNUZType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
+ return PyFloat8E5M2FNUZType(context->getRef(), t);
+ },
+ nb::arg("context").none() = nb::none(), "Create a float8_e5m2fnuz type.");
+}
-/// Floating Point Type subclass - F32Type.
-class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat32TypeGetTypeID;
- static constexpr const char *pyClassName = "F32Type";
- using PyConcreteType::PyConcreteType;
+/// Binds the static get(context=None) constructor via mlirFloat8E3M4TypeGet.
+void PyFloat8E3M4Type::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E3M4TypeGet(context->get());
+ return PyFloat8E3M4Type(context->getRef(), t);
+ },
+ nb::arg("context").none() = nb::none(), "Create a float8_e3m4 type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirF32TypeGet(context->get());
- return PyF32Type(context->getRef(), t);
- },
- nb::arg("context").none() = nb::none(), "Create a f32 type.");
- }
-};
+/// Binds the static get(context=None) constructor via
+/// mlirFloat8E8M0FNUTypeGet.
+void PyFloat8E8M0FNUType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
+ return PyFloat8E8M0FNUType(context->getRef(), t);
+ },
+ nb::arg("context").none() = nb::none(), "Create a float8_e8m0fnu type.");
+}
-/// Floating Point Type subclass - F64Type.
-class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat64TypeGetTypeID;
- static constexpr const char *pyClassName = "F64Type";
- using PyConcreteType::PyConcreteType;
+/// Binds the static get(context=None) constructor via mlirBF16TypeGet.
+void PyBF16Type::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirBF16TypeGet(context->get());
+ return PyBF16Type(context->getRef(), t);
+ },
+ nb::arg("context").none() = nb::none(), "Create a bf16 type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirF64TypeGet(context->get());
- return PyF64Type(context->getRef(), t);
- },
- nb::arg("context").none() = nb::none(), "Create a f64 type.");
- }
-};
+/// Binds the static get(context=None) constructor via mlirF16TypeGet.
+void PyF16Type::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirF16TypeGet(context->get());
+ return PyF16Type(context->getRef(), t);
+ },
+ nb::arg("context").none() = nb::none(), "Create a f16 type.");
+}
-/// None Type subclass - NoneType.
-class PyNoneType : public PyConcreteType<PyNoneType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirNoneTypeGetTypeID;
- static constexpr const char *pyClassName = "NoneType";
- using PyConcreteType::PyConcreteType;
+/// Binds the static get(context=None) constructor via mlirTF32TypeGet.
+void PyTF32Type::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirTF32TypeGet(context->get());
+ return PyTF32Type(context->getRef(), t);
+ },
+ nb::arg("context").none() = nb::none(), "Create a tf32 type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirNoneTypeGet(context->get());
- return PyNoneType(context->getRef(), t);
- },
- nb::arg("context").none() = nb::none(), "Create a none type.");
- }
-};
+/// Binds the static get(context=None) constructor via mlirF32TypeGet.
+void PyF32Type::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirF32TypeGet(context->get());
+ return PyF32Type(context->getRef(), t);
+ },
+ nb::arg("context").none() = nb::none(), "Create a f32 type.");
+}
-/// Complex Type subclass - ComplexType.
-class PyComplexType : public PyConcreteType<PyComplexType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirComplexTypeGetTypeID;
- static constexpr const char *pyClassName = "ComplexType";
- using PyConcreteType::PyConcreteType;
+/// Binds the static get(context=None) constructor via mlirF64TypeGet.
+void PyF64Type::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirF64TypeGet(context->get());
+ return PyF64Type(context->getRef(), t);
+ },
+ nb::arg("context").none() = nb::none(), "Create a f64 type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyType &elementType) {
- // The element must be a floating point or integer scalar type.
- if (mlirTypeIsAIntegerOrFloat(elementType)) {
- MlirType t = mlirComplexTypeGet(elementType);
- return PyComplexType(elementType.getContext(), t);
- }
- throw nb::value_error(
- (Twine("invalid '") +
- nb::cast<std::string>(nb::repr(nb::cast(elementType))) +
- "' and expected floating point or integer type.")
- .str()
- .c_str());
- },
- "Create a complex type");
- c.def_prop_ro(
- "element_type",
- [](PyComplexType &self) { return mlirComplexTypeGetElementType(self); },
- "Returns element type.");
- }
-};
+/// Binds the static get(context=None) constructor via mlirNoneTypeGet.
+void PyNoneType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirNoneTypeGet(context->get());
+ return PyNoneType(context->getRef(), t);
+ },
+ nb::arg("context").none() = nb::none(), "Create a none type.");
+}
-} // namespace
+/// Binds get(element_type) and the read-only element_type property. get()
+/// throws nb::value_error when mlirTypeIsAIntegerOrFloat rejects the element
+/// type; the new type is created in the element type's context.
+void PyComplexType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](PyType &elementType) {
+ // The element must be a floating point or integer scalar type.
+ if (mlirTypeIsAIntegerOrFloat(elementType)) {
+ MlirType t = mlirComplexTypeGet(elementType);
+ return PyComplexType(elementType.getContext(), t);
+ }
+ throw nb::value_error(
+ (Twine("invalid '") +
+ nb::cast<std::string>(nb::repr(nb::cast(elementType))) +
+ "' and expected floating point or integer type.")
+ .str()
+ .c_str());
+ },
+ "Create a complex type");
+ c.def_prop_ro(
+ "element_type",
+ [](PyComplexType &self) { return mlirComplexTypeGetElementType(self); },
+ "Returns element type.");
+}
// Shaped Type Interface - ShapedType
-void mlir::PyShapedType::bindDerived(ClassTy &c) {
+/// Binds the Python members shared by every shaped type (e.g. element_type,
+/// is_dynamic_dim, is_static_size, is_dynamic_stride_or_offset).
+void PyShapedType::bindDerived(ClassTy &c) {
c.def_prop_ro(
"element_type",
[](PyShapedType &self) { return mlirShapedTypeGetElementType(self); },
@@ -534,7 +311,7 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
[](PyShapedType &self) -> bool {
return mlirShapedTypeHasStaticShape(self);
},
"Returns whether the given shaped type has a static shape.");
c.def(
"is_dynamic_dim",
[](PyShapedType &self, intptr_t dim) -> bool {
@@ -571,7 +348,7 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
"is_static_size",
[](int64_t size) -> bool { return mlirShapedTypeIsStaticSize(size); },
nb::arg("dim_size"),
"Returns whether the given dimension size indicates a static "
"dimension.");
c.def(
"is_dynamic_stride_or_offset",
@@ -615,383 +392,294 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
"shaped types.");
}
-void mlir::PyShapedType::requireHasRank() {
+/// Throws nb::value_error unless mlirShapedTypeHasRank reports this type as
+/// ranked; call it before any rank-dependent query.
+void PyShapedType::requireHasRank() {
if (!mlirShapedTypeHasRank(*this)) {
throw nb::value_error(
"calling this method requires that the type has a rank.");
}
}
-const mlir::PyShapedType::IsAFunctionTy mlir::PyShapedType::isaFunction =
- mlirTypeIsAShaped;
-
-namespace {
-
-/// Vector Type subclass - VectorType.
-class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirVectorTypeGetTypeID;
- static constexpr const char *pyClassName = "VectorType";
- using PyConcreteType::PyConcreteType;
+/// Any type for which mlirTypeIsAShaped holds is treated as a ShapedType.
+const PyShapedType::IsAFunctionTy PyShapedType::isaFunction = mlirTypeIsAShaped;
+
+/// Binds VectorType.get (forwarding to PyVectorType::get; scalable,
+/// scalable_dims and loc are keyword-only) plus the read-only scalable and
+/// scalable_dims properties; scalable_dims has one bool per dimension.
+void PyVectorType::bindDerived(ClassTy &c) {
+ c.def_static("get", &PyVectorType::get, nb::arg("shape"),
+ nb::arg("element_type"), nb::kw_only(),
+ nb::arg("scalable").none() = nb::none(),
+ nb::arg("scalable_dims").none() = nb::none(),
+ nb::arg("loc").none() = nb::none(), "Create a vector type")
+ .def_prop_ro("scalable",
+ [](MlirType self) { return mlirVectorTypeIsScalable(self); })
+ .def_prop_ro("scalable_dims", [](MlirType self) {
+ std::vector<bool> scalableDims;
+ size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
+ scalableDims.reserve(rank);
+ for (size_t i = 0; i < rank; ++i)
+ scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
+ return scalableDims;
+ });
+}
- static void bindDerived(ClassTy &c) {
- c.def_static("get", &PyVectorType::get, nb::arg("shape"),
- nb::arg("element_type"), nb::kw_only(),
- nb::arg("scalable").none() = nb::none(),
- nb::arg("scalable_dims").none() = nb::none(),
- nb::arg("loc").none() = nb::none(), "Create a vector type")
- .def_prop_ro(
- "scalable",
- [](MlirType self) { return mlirVectorTypeIsScalable(self); })
- .def_prop_ro("scalable_dims", [](MlirType self) {
- std::vector<bool> scalableDims;
- size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
- scalableDims.reserve(rank);
- for (size_t i = 0; i < rank; ++i)
- scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
- return scalableDims;
- });
+PyVectorType PyVectorType::get(std::vector<int64_t> shape, PyType &elementType,
+ std::optional<nb::list> scalable,
+ std::optional<std::vector<int64_t>> scalableDims,
+ DefaultingPyLocation loc) {
+ if (scalable && scalableDims) {
+ throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
+ "are mutually exclusive.");
}
-private:
- static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
- std::optional<nb::list> scalable,
- std::optional<std::vector<int64_t>> scalableDims,
- DefaultingPyLocation loc) {
- if (scalable && scalableDims) {
- throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
- "are mutually exclusive.");
+ PyMlirContext::ErrorCapture errors(loc->getContext());
+ MlirType type;
+ if (scalable) {
+ if (scalable->size() != shape.size())
+ throw nb::value_error("Expected len(scalable) == len(shape).");
+
+ SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
+ *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
+ type = mlirVectorTypeGetScalableChecked(
+ loc, shape.size(), shape.data(), scalableDimFlags.data(), elementType);
+ } else if (scalableDims) {
+ SmallVector<bool> scalableDimFlags(shape.size(), false);
+ for (int64_t dim : *scalableDims) {
+ if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
+ throw nb::value_error("Scalable dimension index out of bounds.");
+ scalableDimFlags[dim] = true;
}
-
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirType type;
- if (scalable) {
- if (scalable->size() != shape.size())
- throw nb::value_error("Expected len(scalable) == len(shape).");
-
- SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
- *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
- type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
- scalableDimFlags.data(),
- elementType);
- } else if (scalableDims) {
- SmallVector<bool> scalableDimFlags(shape.size(), false);
- for (int64_t dim : *scalableDims) {
- if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
- throw nb::value_error("Scalable dimension index out of bounds.");
- scalableDimFlags[dim] = true;
- }
- type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
- scalableDimFlags.data(),
- elementType);
- } else {
- type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
- elementType);
- }
- if (mlirTypeIsNull(type))
- throw MLIRError("Invalid type", errors.take());
- return PyVectorType(elementType.getContext(), type);
- }
-};
-
-/// Ranked Tensor Type subclass - RankedTensorType.
-class PyRankedTensorType
- : public PyConcreteType<PyRankedTensorType, PyShapedType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirRankedTensorTypeGetTypeID;
- static constexpr const char *pyClassName = "RankedTensorType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](std::vector<int64_t> shape, PyType &elementType,
- std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirType t = mlirRankedTensorTypeGetChecked(
- loc, shape.size(), shape.data(), elementType,
- encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyRankedTensorType(elementType.getContext(), t);
- },
- nb::arg("shape"), nb::arg("element_type"),
- nb::arg("encoding").none() = nb::none(),
- nb::arg("loc").none() = nb::none(), "Create a ranked tensor type");
- c.def_prop_ro("encoding",
- [](PyRankedTensorType &self) -> std::optional<MlirAttribute> {
- MlirAttribute encoding =
- mlirRankedTensorTypeGetEncoding(self.get());
- if (mlirAttributeIsNull(encoding))
- return std::nullopt;
- return encoding;
- });
+ type = mlirVectorTypeGetScalableChecked(
+ loc, shape.size(), shape.data(), scalableDimFlags.data(), elementType);
+ } else {
+ type =
+ mlirVectorTypeGetChecked(loc, shape.size(), shape.data(), elementType);
}
-};
-
-/// Unranked Tensor Type subclass - UnrankedTensorType.
-class PyUnrankedTensorType
- : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirUnrankedTensorTypeGetTypeID;
- static constexpr const char *pyClassName = "UnrankedTensorType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyType &elementType, DefaultingPyLocation loc) {
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyUnrankedTensorType(elementType.getContext(), t);
- },
- nb::arg("element_type"), nb::arg("loc").none() = nb::none(),
- "Create a unranked tensor type");
- }
-};
-
-/// Ranked MemRef Type subclass - MemRefType.
-class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirMemRefTypeGetTypeID;
- static constexpr const char *pyClassName = "MemRefType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](std::vector<int64_t> shape, PyType &elementType,
- PyAttribute *layout, PyAttribute *memorySpace,
- DefaultingPyLocation loc) {
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull();
- MlirAttribute memSpaceAttr =
- memorySpace ? *memorySpace : mlirAttributeGetNull();
- MlirType t =
- mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
- shape.data(), layoutAttr, memSpaceAttr);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyMemRefType(elementType.getContext(), t);
- },
- nb::arg("shape"), nb::arg("element_type"),
- nb::arg("layout").none() = nb::none(),
- nb::arg("memory_space").none() = nb::none(),
- nb::arg("loc").none() = nb::none(), "Create a memref type")
- .def_prop_ro(
- "layout",
- [](PyMemRefType &self) -> MlirAttribute {
- return mlirMemRefTypeGetLayout(self);
- },
- "The layout of the MemRef type.")
- .def(
- "get_strides_and_offset",
- [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
- std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
- int64_t offset;
- if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset(
- self, strides.data(), &offset)))
- throw std::runtime_error(
- "Failed to extract strides and offset from memref.");
- return {strides, offset};
- },
- "The strides and offset of the MemRef type.")
- .def_prop_ro(
- "affine_map",
- [](PyMemRefType &self) -> PyAffineMap {
- MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
- return PyAffineMap(self.getContext(), map);
- },
- "The layout of the MemRef type as an affine map.")
- .def_prop_ro(
- "memory_space",
- [](PyMemRefType &self) -> std::optional<MlirAttribute> {
- MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
- if (mlirAttributeIsNull(a))
- return std::nullopt;
- return a;
- },
- "Returns the memory space of the given MemRef type.");
- }
-};
-
-/// Unranked MemRef Type subclass - UnrankedMemRefType.
-class PyUnrankedMemRefType
- : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirUnrankedMemRefTypeGetTypeID;
- static constexpr const char *pyClassName = "UnrankedMemRefType";
- using PyConcreteType::PyConcreteType;
+ if (mlirTypeIsNull(type))
+ throw MLIRError("Invalid type", errors.take());
+ return PyVectorType(elementType.getContext(), type);
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyType &elementType, PyAttribute *memorySpace,
- DefaultingPyLocation loc) {
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirAttribute memSpaceAttr = {};
- if (memorySpace)
- memSpaceAttr = *memorySpace;
+/// Binds RankedTensorType.get(shape, element_type, encoding=None, loc=None)
+/// and the read-only `encoding` property (None when the type has no encoding).
+void PyRankedTensorType::bindDerived(ClassTy &c) {
+  auto getRanked = [](std::vector<int64_t> shape, PyType &elementType,
+                      std::optional<PyAttribute> &encodingAttr,
+                      DefaultingPyLocation loc) {
+    // Collect diagnostics emitted by the checked builder so they can be
+    // attached to the raised MLIRError.
+    PyMlirContext::ErrorCapture errors(loc->getContext());
+    MlirAttribute encoding = encodingAttr.has_value() ? encodingAttr->get()
+                                                      : mlirAttributeGetNull();
+    MlirType rankedType = mlirRankedTensorTypeGetChecked(
+        loc, shape.size(), shape.data(), elementType, encoding);
+    if (mlirTypeIsNull(rankedType))
+      throw MLIRError("Invalid type", errors.take());
+    return PyRankedTensorType(elementType.getContext(), rankedType);
+  };
+  c.def_static("get", getRanked, nb::arg("shape"), nb::arg("element_type"),
+               nb::arg("encoding").none() = nb::none(),
+               nb::arg("loc").none() = nb::none(),
+               "Create a ranked tensor type");
+
+  c.def_prop_ro("encoding",
+                [](PyRankedTensorType &self) -> std::optional<MlirAttribute> {
+                  MlirAttribute enc = mlirRankedTensorTypeGetEncoding(self.get());
+                  if (!mlirAttributeIsNull(enc))
+                    return enc;
+                  return std::nullopt;
+                });
+}
- MlirType t =
- mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyUnrankedMemRefType(elementType.getContext(), t);
- },
- nb::arg("element_type"), nb::arg("memory_space").none(),
- nb::arg("loc").none() = nb::none(), "Create a unranked memref type")
- .def_prop_ro(
- "memory_space",
- [](PyUnrankedMemRefType &self) -> std::optional<MlirAttribute> {
- MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
- if (mlirAttributeIsNull(a))
- return std::nullopt;
- return a;
- },
- "Returns the memory space of the given Unranked MemRef type.");
- }
-};
+/// Binds UnrankedTensorType.get(element_type, loc=None).
+void PyUnrankedTensorType::bindDerived(ClassTy &c) {
+  auto getUnranked = [](PyType &elementType, DefaultingPyLocation loc) {
+    // Diagnostics from the checked builder are surfaced via MLIRError.
+    PyMlirContext::ErrorCapture errors(loc->getContext());
+    MlirType unrankedType = mlirUnrankedTensorTypeGetChecked(loc, elementType);
+    if (mlirTypeIsNull(unrankedType))
+      throw MLIRError("Invalid type", errors.take());
+    return PyUnrankedTensorType(elementType.getContext(), unrankedType);
+  };
+  c.def_static("get", getUnranked, nb::arg("element_type"),
+               nb::arg("loc").none() = nb::none(),
+               "Create a unranked tensor type");
+}
-/// Tuple Type subclass - TupleType.
-class PyTupleType : public PyConcreteType<PyTupleType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirTupleTypeGetTypeID;
- static constexpr const char *pyClassName = "TupleType";
- using PyConcreteType::PyConcreteType;
+/// Binds MemRefType.get(...) together with the `layout`, `affine_map` and
+/// `memory_space` properties and the get_strides_and_offset() method.
+void PyMemRefType::bindDerived(ClassTy &c) {
+  auto getMemRef = [](std::vector<int64_t> shape, PyType &elementType,
+                      PyAttribute *layout, PyAttribute *memorySpace,
+                      DefaultingPyLocation loc) {
+    PyMlirContext::ErrorCapture errors(loc->getContext());
+    // Omitted (None) layout / memory space are forwarded as null attributes.
+    MlirAttribute layoutAttr = mlirAttributeGetNull();
+    if (layout)
+      layoutAttr = *layout;
+    MlirAttribute memSpaceAttr = mlirAttributeGetNull();
+    if (memorySpace)
+      memSpaceAttr = *memorySpace;
+    MlirType memrefType =
+        mlirMemRefTypeGetChecked(loc, elementType, shape.size(), shape.data(),
+                                 layoutAttr, memSpaceAttr);
+    if (mlirTypeIsNull(memrefType))
+      throw MLIRError("Invalid type", errors.take());
+    return PyMemRefType(elementType.getContext(), memrefType);
+  };
+
+  // One stride per dimension; the C API fills `strides` and `offset` in place.
+  auto stridesAndOffset =
+      [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
+    std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
+    int64_t offset;
+    const bool failed = mlirLogicalResultIsFailure(
+        mlirMemRefTypeGetStridesAndOffset(self, strides.data(), &offset));
+    if (failed)
+      throw std::runtime_error(
+          "Failed to extract strides and offset from memref.");
+    return {strides, offset};
+  };
+
+  c.def_static("get", getMemRef, nb::arg("shape"), nb::arg("element_type"),
+               nb::arg("layout").none() = nb::none(),
+               nb::arg("memory_space").none() = nb::none(),
+               nb::arg("loc").none() = nb::none(), "Create a memref type");
+  c.def_prop_ro(
+      "layout",
+      [](PyMemRefType &self) -> MlirAttribute {
+        return mlirMemRefTypeGetLayout(self);
+      },
+      "The layout of the MemRef type.");
+  c.def("get_strides_and_offset", stridesAndOffset,
+        "The strides and offset of the MemRef type.");
+  c.def_prop_ro(
+      "affine_map",
+      [](PyMemRefType &self) -> PyAffineMap {
+        MlirAffineMap layoutMap = mlirMemRefTypeGetAffineMap(self);
+        return PyAffineMap(self.getContext(), layoutMap);
+      },
+      "The layout of the MemRef type as an affine map.");
+  c.def_prop_ro(
+      "memory_space",
+      [](PyMemRefType &self) -> std::optional<MlirAttribute> {
+        MlirAttribute space = mlirMemRefTypeGetMemorySpace(self);
+        if (!mlirAttributeIsNull(space))
+          return space;
+        return std::nullopt;
+      },
+      "Returns the memory space of the given MemRef type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get_tuple",
- [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
- MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
- elements.data());
- return PyTupleType(context->getRef(), t);
- },
- nb::arg("elements"), nb::arg("context").none() = nb::none(),
- "Create a tuple type");
- c.def(
- "get_type",
- [](PyTupleType &self, intptr_t pos) {
- return mlirTupleTypeGetType(self, pos);
- },
- nb::arg("pos"), "Returns the pos-th type in the tuple type.");
- c.def_prop_ro(
- "num_types",
- [](PyTupleType &self) -> intptr_t {
- return mlirTupleTypeGetNumTypes(self);
- },
- "Returns the number of types contained in a tuple.");
- }
-};
+/// Binds UnrankedMemRefType.get(element_type, memory_space=None, loc=None)
+/// and the read-only `memory_space` property (None when absent).
+void PyUnrankedMemRefType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](PyType &elementType, PyAttribute *memorySpace,
+         DefaultingPyLocation loc) {
+        PyMlirContext::ErrorCapture errors(loc->getContext());
+        // An omitted (None) memory space is forwarded as a null attribute,
+        // using the same idiom as MemRefType.get.
+        MlirAttribute memSpaceAttr =
+            memorySpace ? *memorySpace : mlirAttributeGetNull();
+        MlirType t =
+            mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
+        if (mlirTypeIsNull(t))
+          throw MLIRError("Invalid type", errors.take());
+        return PyUnrankedMemRefType(elementType.getContext(), t);
+      },
+      // memory_space is optional (defaults to None), matching MemRefType.get.
+      nb::arg("element_type"), nb::arg("memory_space").none() = nb::none(),
+      nb::arg("loc").none() = nb::none(), "Create a unranked memref type")
+      .def_prop_ro(
+          "memory_space",
+          [](PyUnrankedMemRefType &self) -> std::optional<MlirAttribute> {
+            MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
+            if (mlirAttributeIsNull(a))
+              return std::nullopt;
+            return a;
+          },
+          "Returns the memory space of the given Unranked MemRef type.");
+}
-/// Function type.
-class PyFunctionType : public PyConcreteType<PyFunctionType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFunctionTypeGetTypeID;
- static constexpr const char *pyClassName = "FunctionType";
- using PyConcreteType::PyConcreteType;
+/// Binds TupleType.get_tuple(elements, context=None), get_type(pos) and the
+/// read-only `num_types` property.
+void PyTupleType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get_tuple",
+      [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
+        MlirType t =
+            mlirTupleTypeGet(context->get(), elements.size(), elements.data());
+        return PyTupleType(context->getRef(), t);
+      },
+      nb::arg("elements"), nb::arg("context").none() = nb::none(),
+      "Create a tuple type");
+  c.def(
+      "get_type",
+      [](PyTupleType &self, intptr_t pos) {
+        // `pos` comes straight from Python; validate it here rather than
+        // passing an out-of-range index to the C API, and report it as a
+        // Python IndexError.
+        if (pos < 0 || pos >= mlirTupleTypeGetNumTypes(self))
+          throw nb::index_error("tuple type index out of range");
+        return mlirTupleTypeGetType(self, pos);
+      },
+      nb::arg("pos"), "Returns the pos-th type in the tuple type.");
+  c.def_prop_ro(
+      "num_types",
+      [](PyTupleType &self) -> intptr_t {
+        return mlirTupleTypeGetNumTypes(self);
+      },
+      "Returns the number of types contained in a tuple.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](std::vector<MlirType> inputs, std::vector<MlirType> results,
- DefaultingPyMlirContext context) {
- MlirType t =
- mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
- results.size(), results.data());
- return PyFunctionType(context->getRef(), t);
- },
- nb::arg("inputs"), nb::arg("results"),
- nb::arg("context").none() = nb::none(),
- "Gets a FunctionType from a list of input and result types");
- c.def_prop_ro(
- "inputs",
- [](PyFunctionType &self) {
- MlirType t = self;
- nb::list types;
- for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
- ++i) {
- types.append(mlirFunctionTypeGetInput(t, i));
- }
- return types;
- },
- "Returns the list of input types in the FunctionType.");
- c.def_prop_ro(
- "results",
- [](PyFunctionType &self) {
- nb::list types;
- for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
- ++i) {
- types.append(mlirFunctionTypeGetResult(self, i));
- }
- return types;
- },
- "Returns the list of result types in the FunctionType.");
- }
-};
+/// Binds FunctionType.get(inputs, results, context=None) and the read-only
+/// `inputs` / `results` properties, each returned as a fresh Python list.
+void PyFunctionType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](std::vector<MlirType> inputs, std::vector<MlirType> results,
+         DefaultingPyMlirContext context) {
+        MlirType fnType =
+            mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
+                                results.size(), results.data());
+        return PyFunctionType(context->getRef(), fnType);
+      },
+      nb::arg("inputs"), nb::arg("results"),
+      nb::arg("context").none() = nb::none(),
+      "Gets a FunctionType from a list of input and result types");
+  c.def_prop_ro(
+      "inputs",
+      [](PyFunctionType &self) {
+        nb::list inputTypes;
+        const intptr_t numInputs = mlirFunctionTypeGetNumInputs(self);
+        for (intptr_t idx = 0; idx < numInputs; ++idx)
+          inputTypes.append(mlirFunctionTypeGetInput(self, idx));
+        return inputTypes;
+      },
+      "Returns the list of input types in the FunctionType.");
+  c.def_prop_ro(
+      "results",
+      [](PyFunctionType &self) {
+        nb::list resultTypes;
+        const intptr_t numResults = mlirFunctionTypeGetNumResults(self);
+        for (intptr_t idx = 0; idx < numResults; ++idx)
+          resultTypes.append(mlirFunctionTypeGetResult(self, idx));
+        return resultTypes;
+      },
+      "Returns the list of result types in the FunctionType.");
+}
+} // namespace mlir::python
-static MlirStringRef toMlirStringRef(const std::string &s) {
+namespace {
+/// Wraps `s` as an MlirStringRef built from its data pointer and size (no copy
+/// is made here), so `s` must outlive the returned reference.
+MlirStringRef toMlirStringRef(const std::string &s) {
return mlirStringRefCreate(s.data(), s.size());
}
-
-/// Opaque Type subclass - OpaqueType.
-class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirOpaqueTypeGetTypeID;
- static constexpr const char *pyClassName = "OpaqueType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](const std::string &dialectNamespace, const std::string &typeData,
- DefaultingPyMlirContext context) {
- MlirType type = mlirOpaqueTypeGet(context->get(),
- toMlirStringRef(dialectNamespace),
- toMlirStringRef(typeData));
- return PyOpaqueType(context->getRef(), type);
- },
- nb::arg("dialect_namespace"), nb::arg("buffer"),
- nb::arg("context").none() = nb::none(),
- "Create an unregistered (opaque) dialect type.");
- c.def_prop_ro(
- "dialect_namespace",
- [](PyOpaqueType &self) {
- MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self);
- return nb::str(stringRef.data, stringRef.length);
- },
- "Returns the dialect namespace for the Opaque type as a string.");
- c.def_prop_ro(
- "data",
- [](PyOpaqueType &self) {
- MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
- return nb::str(stringRef.data, stringRef.length);
- },
- "Returns the data for the Opaque type as a string.");
- }
-};
-
} // namespace
+namespace mlir::python {
+/// Binds OpaqueType.get(dialect_namespace, buffer, context=None) and the
+/// read-only `dialect_namespace` and `data` string properties.
+void PyOpaqueType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](const std::string &dialectNamespace, const std::string &typeData,
+         DefaultingPyMlirContext context) {
+        MlirStringRef nsRef = toMlirStringRef(dialectNamespace);
+        MlirStringRef dataRef = toMlirStringRef(typeData);
+        MlirType opaque = mlirOpaqueTypeGet(context->get(), nsRef, dataRef);
+        return PyOpaqueType(context->getRef(), opaque);
+      },
+      nb::arg("dialect_namespace"), nb::arg("buffer"),
+      nb::arg("context").none() = nb::none(),
+      "Create an unregistered (opaque) dialect type.");
+  // Both properties copy the C string ref (pointer + length) into a Python str.
+  c.def_prop_ro(
+      "dialect_namespace",
+      [](PyOpaqueType &self) {
+        MlirStringRef ns = mlirOpaqueTypeGetDialectNamespace(self);
+        return nb::str(ns.data, ns.length);
+      },
+      "Returns the dialect namespace for the Opaque type as a string.");
+  c.def_prop_ro(
+      "data",
+      [](PyOpaqueType &self) {
+        MlirStringRef payload = mlirOpaqueTypeGetData(self);
+        return nb::str(payload.data, payload.length);
+      },
+      "Returns the data for the Opaque type as a string.");
+}
+} // namespace mlir::python
+
void mlir::python::populateIRTypes(nb::module_ &m) {
PyIntegerType::bind(m);
PyFloatType::bind(m);
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 278847e7ac7f5..0be68e730e186 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -6,12 +6,12 @@
//
//===----------------------------------------------------------------------===//
-#include "Globals.h"
-#include "IRModule.h"
-#include "NanobindUtils.h"
#include "Pass.h"
#include "Rewrite.h"
+#include "mlir/Bindings/Python/Globals.h"
+#include "mlir/Bindings/Python/IRModule.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
namespace nb = nanobind;
using namespace mlir;
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 1030dea7f364c..1f699930565fd 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -8,10 +8,12 @@
#include "Pass.h"
-#include "IRModule.h"
#include "mlir-c/Pass.h"
+#include "mlir/Bindings/Python/IRModule.h"
+// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+#include "mlir-c/Bindings/Python/Interop.h" // Must come after nanobind (required on Windows).
+// clang-format on
namespace nb = nanobind;
using namespace nb::literals;
diff --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h
index bc40943521829..0221bd10e723e 100644
--- a/mlir/lib/Bindings/Python/Pass.h
+++ b/mlir/lib/Bindings/Python/Pass.h
@@ -9,7 +9,7 @@
#ifndef MLIR_BINDINGS_PYTHON_PASS_H
#define MLIR_BINDINGS_PYTHON_PASS_H
-#include "NanobindUtils.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
namespace mlir {
namespace python {
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 0373f9c7affe9..28f050bc05562 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -8,11 +8,13 @@
#include "Rewrite.h"
-#include "IRModule.h"
#include "mlir-c/Rewrite.h"
-#include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+#include "mlir/Bindings/Python/IRModule.h"
#include "mlir/Config/mlir-config.h"
+// clang-format off
+#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir-c/Bindings/Python/Interop.h" // Must come after nanobind (required on Windows).
+// clang-format on
namespace nb = nanobind;
using namespace mlir;
diff --git a/mlir/lib/Bindings/Python/Rewrite.h b/mlir/lib/Bindings/Python/Rewrite.h
index ae89e2b9589f1..f8ffdc7bdc458 100644
--- a/mlir/lib/Bindings/Python/Rewrite.h
+++ b/mlir/lib/Bindings/Python/Rewrite.h
@@ -9,7 +9,7 @@
#ifndef MLIR_BINDINGS_PYTHON_REWRITE_H
#define MLIR_BINDINGS_PYTHON_REWRITE_H
-#include "NanobindUtils.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
namespace mlir {
namespace python {
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 7a0c95ebb8200..56327cbe4a463 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -488,10 +488,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
Rewrite.cpp
# Headers must be included explicitly so they are installed.
- Globals.h
- IRModule.h
Pass.h
- NanobindUtils.h
Rewrite.h
PRIVATE_LINK_LIBS
LLVMSupport
@@ -698,8 +695,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Pybind
PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
DialectSMT.cpp
- # Headers must be included explicitly so they are installed.
- NanobindUtils.h
PRIVATE_LINK_LIBS
LLVMSupport
EMBED_CAPI_LINK_LIBS
>From 28cd226227413fc5f27bc6a36573b48b636457c5 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Thu, 28 Aug 2025 02:59:43 -0400
Subject: [PATCH 2/3] Update all_requirements.txt
---
.ci/all_requirements.txt | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/.ci/all_requirements.txt b/.ci/all_requirements.txt
index f73500efdc7e0..4a203444c404c 100644
--- a/.ci/all_requirements.txt
+++ b/.ci/all_requirements.txt
@@ -34,9 +34,7 @@ ml-dtypes==0.5.1 ; python_version < "3.13" \
--hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \
--hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1
# via -r ./mlir/python/requirements.txt
-nanobind==2.7.0 \
- --hash=sha256:73b12d0e751d140d6c1bf4b215e18818a8debfdb374f08dc3776ad208d808e74 \
- --hash=sha256:f9f1b160580c50dcf37b6495a0fd5ec61dc0d95dae5f8004f87dd9ad7eb46b34
+nanobind @ git+https://github.com/wjakob/nanobind
# via -r ./mlir/python/requirements.txt
numpy==2.0.2 \
--hash=sha256:0123ffdaa88fa4ab64835dcbde75dcdf89c453c922f18dced6e27c90d1d0ec5a \
>From 0be37a40b145ca0bebf3d0d583de1fa944264475 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Thu, 28 Aug 2025 03:13:16 -0400
Subject: [PATCH 3/3] Update all_requirements.txt
---
.ci/all_requirements.txt | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/.ci/all_requirements.txt b/.ci/all_requirements.txt
index 4a203444c404c..f73500efdc7e0 100644
--- a/.ci/all_requirements.txt
+++ b/.ci/all_requirements.txt
@@ -34,7 +34,9 @@ ml-dtypes==0.5.1 ; python_version < "3.13" \
--hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \
--hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1
# via -r ./mlir/python/requirements.txt
-nanobind @ git+https://github.com/wjakob/nanobind
+nanobind==2.7.0 \
+ --hash=sha256:73b12d0e751d140d6c1bf4b215e18818a8debfdb374f08dc3776ad208d808e74 \
+ --hash=sha256:f9f1b160580c50dcf37b6495a0fd5ec61dc0d95dae5f8004f87dd9ad7eb46b34
# via -r ./mlir/python/requirements.txt
numpy==2.0.2 \
--hash=sha256:0123ffdaa88fa4ab64835dcbde75dcdf89c453c922f18dced6e27c90d1d0ec5a \
More information about the Mlir-commits
mailing list