[llvm-branch-commits] [mlir] [mlir][Python] move IRTypes and IRAttributes to MLIRPythonSupport (PR #174118)

Maksim Levental via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Jan 2 13:02:09 PST 2026


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/174118

>From 8bc6e7931ff4b61089581b8960717c0b0d2a8afb Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 31 Dec 2025 14:20:39 -0800
Subject: [PATCH] [mlir][Python] move IRTypes and IRAttributes to public
 headers

---
 .../mlir/Bindings/Python/IRAttributes.h       |  593 +++++
 mlir/include/mlir/Bindings/Python/IRCore.h    |   17 +-
 mlir/include/mlir/Bindings/Python/IRTypes.h   |  393 ++-
 mlir/lib/Bindings/Python/IRAttributes.cpp     | 2297 +++++++----------
 mlir/lib/Bindings/Python/IRTypes.cpp          | 1609 +++++-------
 mlir/lib/Bindings/Python/MainModule.cpp       |    2 +
 mlir/python/CMakeLists.txt                    |    6 +-
 .../python/lib/PythonTestModuleNanobind.cpp   |  131 +-
 8 files changed, 2632 insertions(+), 2416 deletions(-)
 create mode 100644 mlir/include/mlir/Bindings/Python/IRAttributes.h

diff --git a/mlir/include/mlir/Bindings/Python/IRAttributes.h b/mlir/include/mlir/Bindings/Python/IRAttributes.h
new file mode 100644
index 0000000000000..d64e32037664c
--- /dev/null
+++ b/mlir/include/mlir/Bindings/Python/IRAttributes.h
@@ -0,0 +1,593 @@
+//===- IRAttributes.h - Exports builtin and standard attributes -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_IRATTRIBUTES_H
+#define MLIR_BINDINGS_PYTHON_IRATTRIBUTES_H
+
+#include <optional>
+#include <string>
+#include <string_view>
+#include <utility>
+
+#include "mlir-c/BuiltinAttributes.h"
+#include "mlir-c/BuiltinTypes.h"
+#include "mlir/Bindings/Python/IRCore.h"
+#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
+
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+
+struct nb_buffer_info {
+  void *ptr = nullptr;
+  ssize_t itemsize = 0;
+  ssize_t size = 0;
+  const char *format = nullptr;
+  ssize_t ndim = 0;
+  SmallVector<ssize_t, 4> shape;
+  SmallVector<ssize_t, 4> strides;
+  bool readonly = false;
+
+  nb_buffer_info(
+      void *ptr, ssize_t itemsize, const char *format, ssize_t ndim,
+      SmallVector<ssize_t, 4> shape_in, SmallVector<ssize_t, 4> strides_in,
+      bool readonly = false,
+      std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view_in =
+          std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(nullptr, nullptr));
+
+  explicit nb_buffer_info(Py_buffer *view)
+      : nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim,
+                       {view->shape, view->shape + view->ndim},
+                       // TODO(phawkins): check for null strides
+                       {view->strides, view->strides + view->ndim},
+                       view->readonly != 0,
+                       std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(
+                           view, PyBuffer_Release)) {}
+
+  nb_buffer_info(const nb_buffer_info &) = delete;
+  nb_buffer_info(nb_buffer_info &&) = default;
+  nb_buffer_info &operator=(const nb_buffer_info &) = delete;
+  nb_buffer_info &operator=(nb_buffer_info &&) = default;
+
+private:
+  std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view;
+};
+
+class MLIR_PYTHON_API_EXPORTED nb_buffer : public nanobind::object {
+  NB_OBJECT_DEFAULT(nb_buffer, object, "Buffer", PyObject_CheckBuffer);
+
+  nb_buffer_info request() const;
+};
+
+template <typename T>
+struct nb_format_descriptor {};
+
+class MLIR_PYTHON_API_EXPORTED PyAffineMapAttribute
+    : public PyConcreteAttribute<PyAffineMapAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
+  static constexpr const char *pyClassName = "AffineMapAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirAffineMapAttrGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+class MLIR_PYTHON_API_EXPORTED PyIntegerSetAttribute
+    : public PyConcreteAttribute<PyIntegerSetAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet;
+  static constexpr const char *pyClassName = "IntegerSetAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirIntegerSetAttrGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Casts `object` to `T`, rethrowing a failed cast as std::runtime_error.
+template <typename T>
+static T pyTryCast(nanobind::handle object) {
+  try {
+    return nanobind::cast<T>(object);
+  } catch (const nanobind::cast_error &err) {
+    throw std::runtime_error(std::string("Invalid attribute when attempting "
+                                         "to create an ArrayAttribute (") +
+                             err.what() + ")");
+  } catch (const std::runtime_error &err) {
+    throw std::runtime_error(std::string("Invalid attribute (None?) when "
+                                         "attempting to create an "
+                                         "ArrayAttribute (") +
+                             err.what() + ")");
+  }
+}
+
+/// A python-wrapped dense array attribute with an element type and a derived
+/// implementation class.
+template <typename EltTy, typename DerivedT>
+class MLIR_PYTHON_API_EXPORTED PyDenseArrayAttribute
+    : public PyConcreteAttribute<DerivedT> {
+public:
+  using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;
+
+  /// Iterator over the elements (of type `EltTy`) of a dense array.
+  class PyDenseArrayIterator {
+  public:
+    PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
+
+    /// Return a copy of the iterator.
+    PyDenseArrayIterator dunderIter() { return *this; }
+
+    /// Return the next element, or signal StopIteration when exhausted.
+    EltTy dunderNext() {
+      // Throw if the index has reached the end.
+      if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
+        throw nanobind::stop_iteration();
+      return DerivedT::getElement(attr.get(), nextIndex++);
+    }
+
+    /// Bind the iterator class under `DerivedT::pyIteratorName`.
+    static void bind(nanobind::module_ &m) {
+      nanobind::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName)
+          .def("__iter__", &PyDenseArrayIterator::dunderIter)
+          .def("__next__", &PyDenseArrayIterator::dunderNext);
+    }
+
+  private:
+    /// The referenced dense array attribute.
+    PyAttribute attr;
+    /// The next index to read (same type as the C API element count).
+    intptr_t nextIndex = 0;
+  };
+
+  /// Get the element at index `i`; no bounds check is performed here.
+  EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
+
+  /// Bind the attribute class.
+  static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
+    // Bind the constructor.
+    if constexpr (std::is_same_v<EltTy, bool>) {
+      c.def_static(
+          "get",
+          [](const nanobind::sequence &py_values, DefaultingPyMlirContext ctx) {
+            std::vector<bool> values;
+            for (nanobind::handle py_value : py_values) {
+              int is_true = PyObject_IsTrue(py_value.ptr());
+              if (is_true < 0) {
+                throw nanobind::python_error();
+              }
+              values.push_back(is_true);
+            }
+            return getAttribute(values, ctx->getRef());
+          },
+          nanobind::arg("values"), nanobind::arg("context") = nanobind::none(),
+          "Gets a uniqued dense array attribute");
+    } else {
+      c.def_static(
+          "get",
+          [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
+            return getAttribute(values, ctx->getRef());
+          },
+          nanobind::arg("values"), nanobind::arg("context") = nanobind::none(),
+          "Gets a uniqued dense array attribute");
+    }
+    // Bind the array methods; negative indices are rejected, not wrapped.
+    c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
+      if (i < 0 || i >= mlirDenseArrayGetNumElements(arr))
+        throw nanobind::index_error("DenseArray index out of range");
+      return arr.getItem(i);
+    });
+    c.def("__len__", [](const DerivedT &arr) {
+      return mlirDenseArrayGetNumElements(arr);
+    });
+    c.def("__iter__",
+          [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
+    c.def("__add__", [](DerivedT &arr, const nanobind::list &extras) {
+      std::vector<EltTy> values;
+      intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
+      values.reserve(numOldElements + nanobind::len(extras));
+      for (intptr_t i = 0; i < numOldElements; ++i)
+        values.push_back(arr.getItem(i));
+      for (nanobind::handle attr : extras)
+        values.push_back(pyTryCast<EltTy>(attr));
+      return getAttribute(values, arr.getContext());
+    });
+  }
+
+private:
+  static DerivedT getAttribute(const std::vector<EltTy> &values,
+                               PyMlirContextRef ctx) {
+    if constexpr (std::is_same_v<EltTy, bool>) {
+      std::vector<int> intValues(values.begin(), values.end());
+      MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(),
+                                                  intValues.data());
+      return DerivedT(ctx, attr);
+    } else {
+      MlirAttribute attr =
+          DerivedT::getAttribute(ctx->get(), values.size(), values.data());
+      return DerivedT(ctx, attr);
+    }
+  }
+};
+
+/// Instantiate the python dense array classes.
+struct PyDenseBoolArrayAttribute
+    : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
+  static constexpr auto getAttribute = mlirDenseBoolArrayGet;
+  static constexpr auto getElement = mlirDenseBoolArrayGetElement;
+  static constexpr const char *pyClassName = "DenseBoolArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+struct PyDenseI8ArrayAttribute
+    : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
+  static constexpr auto getAttribute = mlirDenseI8ArrayGet;
+  static constexpr auto getElement = mlirDenseI8ArrayGetElement;
+  static constexpr const char *pyClassName = "DenseI8ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+struct PyDenseI16ArrayAttribute
+    : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
+  static constexpr auto getAttribute = mlirDenseI16ArrayGet;
+  static constexpr auto getElement = mlirDenseI16ArrayGetElement;
+  static constexpr const char *pyClassName = "DenseI16ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+struct PyDenseI32ArrayAttribute
+    : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
+  static constexpr auto getAttribute = mlirDenseI32ArrayGet;
+  static constexpr auto getElement = mlirDenseI32ArrayGetElement;
+  static constexpr const char *pyClassName = "DenseI32ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+struct PyDenseI64ArrayAttribute
+    : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
+  static constexpr auto getAttribute = mlirDenseI64ArrayGet;
+  static constexpr auto getElement = mlirDenseI64ArrayGetElement;
+  static constexpr const char *pyClassName = "DenseI64ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+struct PyDenseF32ArrayAttribute
+    : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
+  static constexpr auto getAttribute = mlirDenseF32ArrayGet;
+  static constexpr auto getElement = mlirDenseF32ArrayGetElement;
+  static constexpr const char *pyClassName = "DenseF32ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+struct PyDenseF64ArrayAttribute
+    : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
+  static constexpr auto getAttribute = mlirDenseF64ArrayGet;
+  static constexpr auto getElement = mlirDenseF64ArrayGetElement;
+  static constexpr const char *pyClassName = "DenseF64ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+
+class MLIR_PYTHON_API_EXPORTED PyArrayAttribute
+    : public PyConcreteAttribute<PyArrayAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
+  static constexpr const char *pyClassName = "ArrayAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirArrayAttrGetTypeID;
+
+  class PyArrayAttributeIterator {
+  public:
+    PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
+
+    PyArrayAttributeIterator &dunderIter() { return *this; }
+
+    nanobind::typed<nanobind::object, PyAttribute> dunderNext();
+
+    static void bind(nanobind::module_ &m);
+
+  private:
+    PyAttribute attr;
+    int nextIndex = 0;
+  };
+
+  MlirAttribute getItem(intptr_t i) const;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Attribute subclass - FloatAttr.
+class MLIR_PYTHON_API_EXPORTED PyFloatAttribute
+    : public PyConcreteAttribute<PyFloatAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
+  static constexpr const char *pyClassName = "FloatAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloatAttrGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Integer Attribute subclass - IntegerAttr.
+class MLIR_PYTHON_API_EXPORTED PyIntegerAttribute
+    : public PyConcreteAttribute<PyIntegerAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
+  static constexpr const char *pyClassName = "IntegerAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c);
+
+private:
+  static int64_t toPyInt(PyIntegerAttribute &self);
+};
+
+/// Bool Attribute subclass - BoolAttr.
+class MLIR_PYTHON_API_EXPORTED PyBoolAttribute
+    : public PyConcreteAttribute<PyBoolAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
+  static constexpr const char *pyClassName = "BoolAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c);
+};
+
+class MLIR_PYTHON_API_EXPORTED PySymbolRefAttribute
+    : public PyConcreteAttribute<PySymbolRefAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
+  static constexpr const char *pyClassName = "SymbolRefAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static PySymbolRefAttribute fromList(const std::vector<std::string> &symbols,
+                                       PyMlirContext &context);
+
+  static void bindDerived(ClassTy &c);
+};
+
+class MLIR_PYTHON_API_EXPORTED PyFlatSymbolRefAttribute
+    : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
+  static constexpr const char *pyClassName = "FlatSymbolRefAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c);
+};
+
+class MLIR_PYTHON_API_EXPORTED PyOpaqueAttribute
+    : public PyConcreteAttribute<PyOpaqueAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
+  static constexpr const char *pyClassName = "OpaqueAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirOpaqueAttrGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+// TODO: Support construction of string elements.
+class MLIR_PYTHON_API_EXPORTED PyDenseElementsAttribute
+    : public PyConcreteAttribute<PyDenseElementsAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
+  static constexpr const char *pyClassName = "DenseElementsAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static PyDenseElementsAttribute
+  getFromList(const nanobind::list &attributes,
+              std::optional<PyType> explicitType,
+              DefaultingPyMlirContext contextWrapper);
+
+  static PyDenseElementsAttribute
+  getFromBuffer(const nb_buffer &array, bool signless,
+                const std::optional<PyType> &explicitType,
+                std::optional<std::vector<int64_t>> explicitShape,
+                DefaultingPyMlirContext contextWrapper);
+
+  static PyDenseElementsAttribute getSplat(const PyType &shapedType,
+                                           PyAttribute &elementAttr);
+
+  intptr_t dunderLen() const;
+
+  std::unique_ptr<nb_buffer_info> accessBuffer();
+
+  static void bindDerived(ClassTy &c);
+
+  static PyType_Slot slots[];
+
+private:
+  static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags);
+  static void bf_releasebuffer(PyObject *, Py_buffer *buffer);
+
+  static bool isUnsignedIntegerFormat(std::string_view format);
+
+  static bool isSignedIntegerFormat(std::string_view format);
+
+  static MlirType
+  getShapedType(std::optional<MlirType> bulkLoadElementType,
+                std::optional<std::vector<int64_t>> explicitShape,
+                Py_buffer &view);
+
+  static MlirAttribute getAttributeFromBuffer(
+      Py_buffer &view, bool signless, std::optional<PyType> explicitType,
+      const std::optional<std::vector<int64_t>> &explicitShape,
+      MlirContext &context);
+
+  // There is a complication for boolean numpy arrays, as numpy represents
+  // them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8
+  // booleans per byte.
+  static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
+      Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
+      MlirContext &context);
+
+  // This does the opposite transformation of
+  // `getBitpackedAttributeFromBooleanBuffer`
+  std::unique_ptr<nb_buffer_info>
+  getBooleanBufferFromBitpackedAttribute() const;
+
+  template <typename Type>
+  std::unique_ptr<nb_buffer_info>
+  bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) {
+    intptr_t rank = mlirShapedTypeGetRank(shapedType);
+    // Prepare the data for the buffer_info.
+    // Buffer is configured for read-only access below.
+    Type *data = static_cast<Type *>(
+        const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
+    // Prepare the shape for the buffer_info.
+    SmallVector<intptr_t, 4> shape;
+    for (intptr_t i = 0; i < rank; ++i)
+      shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
+    // Prepare the strides for the buffer_info.
+    SmallVector<intptr_t, 4> strides;
+    if (mlirDenseElementsAttrIsSplat(*this)) {
+      // Splats are special, only the single value is stored.
+      strides.assign(rank, 0);
+    } else {
+      for (intptr_t i = 1; i < rank; ++i) {
+        intptr_t strideFactor = 1;
+        for (intptr_t j = i; j < rank; ++j)
+          strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
+        strides.push_back(sizeof(Type) * strideFactor);
+      }
+      strides.push_back(sizeof(Type));
+    }
+    const char *format;
+    if (explicitFormat) {
+      format = explicitFormat;
+    } else {
+      format = nb_format_descriptor<Type>::format();
+    }
+    return std::make_unique<nb_buffer_info>(
+        data, sizeof(Type), format, rank, std::move(shape), std::move(strides),
+        /*readonly=*/true);
+  }
+};
+
+/// Refinement of the PyDenseElementsAttribute for attributes containing
+/// integer (and boolean) values. Supports element access.
+class MLIR_PYTHON_API_EXPORTED PyDenseIntElementsAttribute
+    : public PyConcreteAttribute<PyDenseIntElementsAttribute,
+                                 PyDenseElementsAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
+  static constexpr const char *pyClassName = "DenseIntElementsAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  /// Returns the element at the given linear position. Asserts if the index
+  /// is out of range.
+  nanobind::int_ dunderGetItem(intptr_t pos) const;
+
+  static void bindDerived(ClassTy &c);
+};
+
+class MLIR_PYTHON_API_EXPORTED PyDenseResourceElementsAttribute
+    : public PyConcreteAttribute<PyDenseResourceElementsAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction =
+      mlirAttributeIsADenseResourceElements;
+  static constexpr const char *pyClassName = "DenseResourceElementsAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static PyDenseResourceElementsAttribute
+  getFromBuffer(const nb_buffer &buffer, const std::string &name,
+                const PyType &type, std::optional<size_t> alignment,
+                bool isMutable, DefaultingPyMlirContext contextWrapper);
+
+  static void bindDerived(ClassTy &c);
+};
+
+class MLIR_PYTHON_API_EXPORTED PyDictAttribute
+    : public PyConcreteAttribute<PyDictAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
+  static constexpr const char *pyClassName = "DictAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirDictionaryAttrGetTypeID;
+
+  intptr_t dunderLen() const;
+
+  bool dunderContains(const std::string &name) const;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Refinement of PyDenseElementsAttribute for attributes containing
+/// floating-point values. Supports element access.
+class MLIR_PYTHON_API_EXPORTED PyDenseFPElementsAttribute
+    : public PyConcreteAttribute<PyDenseFPElementsAttribute,
+                                 PyDenseElementsAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
+  static constexpr const char *pyClassName = "DenseFPElementsAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  nanobind::float_ dunderGetItem(intptr_t pos) const;
+
+  static void bindDerived(ClassTy &c);
+};
+
+class MLIR_PYTHON_API_EXPORTED PyTypeAttribute
+    : public PyConcreteAttribute<PyTypeAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
+  static constexpr const char *pyClassName = "TypeAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirTypeAttrGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Unit Attribute subclass. Unit attributes don't have values.
+class MLIR_PYTHON_API_EXPORTED PyUnitAttribute
+    : public PyConcreteAttribute<PyUnitAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
+  static constexpr const char *pyClassName = "UnitAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirUnitAttrGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Strided layout attribute subclass.
+class MLIR_PYTHON_API_EXPORTED PyStridedLayoutAttribute
+    : public PyConcreteAttribute<PyStridedLayoutAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
+  static constexpr const char *pyClassName = "StridedLayoutAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirStridedLayoutAttrGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+MLIR_PYTHON_API_EXPORTED void populateIRAttributes(nanobind::module_ &m);
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index d8662137b60e7..53c55b7086ce8 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -989,7 +989,8 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteType : public BaseTy {
       PyGlobals::get().registerTypeCaster(
           DerivedTy::getTypeIdFunction(),
           nanobind::cast<nanobind::callable>(nanobind::cpp_function(
-              [](PyType pyType) -> DerivedTy { return pyType; })));
+              [](PyType pyType) -> DerivedTy { return pyType; })),
+          /*replace*/ true);
     }
 
     DerivedTy::bindDerived(cls);
@@ -1133,7 +1134,8 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteAttribute : public BaseTy {
           nanobind::cast<nanobind::callable>(
               nanobind::cpp_function([](PyAttribute pyAttribute) -> DerivedTy {
                 return pyAttribute;
-              })));
+              })),
+          /*replace*/ true);
     }
 
     DerivedTy::bindDerived(cls);
@@ -1517,6 +1519,8 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteValue : public PyValue {
   // and redefine bindDerived.
   using ClassTy = nanobind::class_<DerivedTy, PyValue>;
   using IsAFunctionTy = bool (*)(MlirValue);
+  using GetTypeIDFunctionTy = MlirTypeID (*)();
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr;
 
   PyConcreteValue() = default;
   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
@@ -1559,6 +1563,15 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteValue : public PyValue {
         [](DerivedTy &self) -> nanobind::typed<nanobind::object, DerivedTy> {
           return self.maybeDownCast();
         });
+
+    if (DerivedTy::getTypeIdFunction) {
+      PyGlobals::get().registerValueCaster(
+          DerivedTy::getTypeIdFunction(),
+          nanobind::cast<nanobind::callable>(nanobind::cpp_function(
+              [](PyValue pyValue) -> DerivedTy { return pyValue; })),
+          /*replace*/ true);
+    }
+
     DerivedTy::bindDerived(cls);
   }
 
diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h
index 87e0e10764bd8..a0901fefec5ce 100644
--- a/mlir/include/mlir/Bindings/Python/IRTypes.h
+++ b/mlir/include/mlir/Bindings/Python/IRTypes.h
@@ -9,13 +9,284 @@
 #ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H
 #define MLIR_BINDINGS_PYTHON_IRTYPES_H
 
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir-c/BuiltinTypes.h"
 
 namespace mlir {
 namespace python {
 namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+
+MLIR_PYTHON_API_EXPORTED int mlirTypeIsAIntegerOrFloat(MlirType type);
+
+class MLIR_PYTHON_API_EXPORTED PyIntegerType
+    : public PyConcreteType<PyIntegerType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirIntegerTypeGetTypeID;
+  static constexpr const char *pyClassName = "IntegerType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Index Type subclass - IndexType.
+class MLIR_PYTHON_API_EXPORTED PyIndexType
+    : public PyConcreteType<PyIndexType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirIndexTypeGetTypeID;
+  static constexpr const char *pyClassName = "IndexType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+class MLIR_PYTHON_API_EXPORTED PyFloatType
+    : public PyConcreteType<PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
+  static constexpr const char *pyClassName = "FloatType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float4E2M1FNType.
+class MLIR_PYTHON_API_EXPORTED PyFloat4E2M1FNType
+    : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat4E2M1FNTypeGetTypeID;
+  static constexpr const char *pyClassName = "Float4E2M1FNType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float6E2M3FNType.
+class MLIR_PYTHON_API_EXPORTED PyFloat6E2M3FNType
+    : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat6E2M3FNTypeGetTypeID;
+  static constexpr const char *pyClassName = "Float6E2M3FNType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float6E3M2FNType.
+class MLIR_PYTHON_API_EXPORTED PyFloat6E3M2FNType
+    : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat6E3M2FNTypeGetTypeID;
+  static constexpr const char *pyClassName = "Float6E3M2FNType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E4M3FNType.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3FNType
+    : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E4M3FNTypeGetTypeID;
+  static constexpr const char *pyClassName = "Float8E4M3FNType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E5M2Type.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E5M2Type
+    : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E5M2TypeGetTypeID;
+  static constexpr const char *pyClassName = "Float8E5M2Type";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E4M3Type.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3Type
+    : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E4M3TypeGetTypeID;
+  static constexpr const char *pyClassName = "Float8E4M3Type";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E4M3FNUZType.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3FNUZType
+    : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E4M3FNUZTypeGetTypeID;
+  static constexpr const char *pyClassName = "Float8E4M3FNUZType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E4M3B11FNUZType.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3B11FNUZType
+    : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E4M3B11FNUZTypeGetTypeID;
+  static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E5M2FNUZType.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E5M2FNUZType
+    : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E5M2FNUZTypeGetTypeID;
+  static constexpr const char *pyClassName = "Float8E5M2FNUZType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E3M4Type.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E3M4Type
+    : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E3M4TypeGetTypeID;
+  static constexpr const char *pyClassName = "Float8E3M4Type";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E8M0FNUType.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E8M0FNUType
+    : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E8M0FNUTypeGetTypeID;
+  static constexpr const char *pyClassName = "Float8E8M0FNUType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - BF16Type.
+class MLIR_PYTHON_API_EXPORTED PyBF16Type
+    : public PyConcreteType<PyBF16Type, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirBFloat16TypeGetTypeID;
+  static constexpr const char *pyClassName = "BF16Type";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - F16Type.
+class MLIR_PYTHON_API_EXPORTED PyF16Type
+    : public PyConcreteType<PyF16Type, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat16TypeGetTypeID;
+  static constexpr const char *pyClassName = "F16Type";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - FloatTF32Type.
+class MLIR_PYTHON_API_EXPORTED PyTF32Type
+    : public PyConcreteType<PyTF32Type, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloatTF32TypeGetTypeID;
+  static constexpr const char *pyClassName = "FloatTF32Type";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - F32Type.
+class MLIR_PYTHON_API_EXPORTED PyF32Type
+    : public PyConcreteType<PyF32Type, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat32TypeGetTypeID;
+  static constexpr const char *pyClassName = "F32Type";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - F64Type.
+class MLIR_PYTHON_API_EXPORTED PyF64Type
+    : public PyConcreteType<PyF64Type, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat64TypeGetTypeID;
+  static constexpr const char *pyClassName = "F64Type";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// None Type subclass - NoneType.
+class MLIR_PYTHON_API_EXPORTED PyNoneType : public PyConcreteType<PyNoneType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirNoneTypeGetTypeID;
+  static constexpr const char *pyClassName = "NoneType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Complex Type subclass - ComplexType.
+class MLIR_PYTHON_API_EXPORTED PyComplexType
+    : public PyConcreteType<PyComplexType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirComplexTypeGetTypeID;
+  static constexpr const char *pyClassName = "ComplexType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
 /// Shaped Type Interface - ShapedType
 class MLIR_PYTHON_API_EXPORTED PyShapedType
     : public PyConcreteType<PyShapedType> {
 public:
   static const IsAFunctionTy isaFunction;
@@ -27,6 +298,124 @@ class MLIR_PYTHON_API_EXPORTED PyShapedType
 private:
   void requireHasRank();
 };
+
+/// Vector Type subclass - VectorType.
+class MLIR_PYTHON_API_EXPORTED PyVectorType
+    : public PyConcreteType<PyVectorType, PyShapedType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirVectorTypeGetTypeID;
+  static constexpr const char *pyClassName = "VectorType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+
+private:
+  static PyVectorType
+  getChecked(std::vector<int64_t> shape, PyType &elementType,
+             std::optional<nanobind::list> scalable,
+             std::optional<std::vector<int64_t>> scalableDims,
+             DefaultingPyLocation loc);
+
+  static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
+                          std::optional<nanobind::list> scalable,
+                          std::optional<std::vector<int64_t>> scalableDims,
+                          DefaultingPyMlirContext context);
+};
+
+/// Ranked Tensor Type subclass - RankedTensorType.
+class MLIR_PYTHON_API_EXPORTED PyRankedTensorType
+    : public PyConcreteType<PyRankedTensorType, PyShapedType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirRankedTensorTypeGetTypeID;
+  static constexpr const char *pyClassName = "RankedTensorType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Unranked Tensor Type subclass - UnrankedTensorType.
+class MLIR_PYTHON_API_EXPORTED PyUnrankedTensorType
+    : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirUnrankedTensorTypeGetTypeID;
+  static constexpr const char *pyClassName = "UnrankedTensorType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Ranked MemRef Type subclass - MemRefType.
+class MLIR_PYTHON_API_EXPORTED PyMemRefType
+    : public PyConcreteType<PyMemRefType, PyShapedType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirMemRefTypeGetTypeID;
+  static constexpr const char *pyClassName = "MemRefType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Unranked MemRef Type subclass - UnrankedMemRefType.
+class MLIR_PYTHON_API_EXPORTED PyUnrankedMemRefType
+    : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirUnrankedMemRefTypeGetTypeID;
+  static constexpr const char *pyClassName = "UnrankedMemRefType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Tuple Type subclass - TupleType.
+class MLIR_PYTHON_API_EXPORTED PyTupleType
+    : public PyConcreteType<PyTupleType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirTupleTypeGetTypeID;
+  static constexpr const char *pyClassName = "TupleType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Function type.
+class MLIR_PYTHON_API_EXPORTED PyFunctionType
+    : public PyConcreteType<PyFunctionType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFunctionTypeGetTypeID;
+  static constexpr const char *pyClassName = "FunctionType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Opaque Type subclass - OpaqueType.
+class MLIR_PYTHON_API_EXPORTED PyOpaqueType
+    : public PyConcreteType<PyOpaqueType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirOpaqueTypeGetTypeID;
+  static constexpr const char *pyClassName = "OpaqueType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+MLIR_PYTHON_API_EXPORTED void populateIRTypes(nanobind::module_ &m);
 } // namespace MLIR_BINDINGS_PYTHON_DOMAIN
 } // namespace python
 } // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index f0f0ae9ba741e..3cd3ce5c4c0ee 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
+#include "mlir/Bindings/Python/IRAttributes.h"
 #include "mlir/Bindings/Python/IRCore.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
@@ -125,65 +126,29 @@ namespace mlir {
 namespace python {
 namespace MLIR_BINDINGS_PYTHON_DOMAIN {
 
-struct nb_buffer_info {
-  void *ptr = nullptr;
-  ssize_t itemsize = 0;
-  ssize_t size = 0;
-  const char *format = nullptr;
-  ssize_t ndim = 0;
-  SmallVector<ssize_t, 4> shape;
-  SmallVector<ssize_t, 4> strides;
-  bool readonly = false;
-
-  nb_buffer_info(
-      void *ptr, ssize_t itemsize, const char *format, ssize_t ndim,
-      SmallVector<ssize_t, 4> shape_in, SmallVector<ssize_t, 4> strides_in,
-      bool readonly = false,
-      std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view_in =
-          std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(nullptr, nullptr))
-      : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim),
-        shape(std::move(shape_in)), strides(std::move(strides_in)),
-        readonly(readonly), owned_view(std::move(owned_view_in)) {
-    size = 1;
-    for (ssize_t i = 0; i < ndim; ++i) {
-      size *= shape[i];
-    }
+nb_buffer_info::nb_buffer_info(
+    void *ptr, ssize_t itemsize, const char *format, ssize_t ndim,
+    SmallVector<ssize_t, 4> shape_in, SmallVector<ssize_t, 4> strides_in,
+    bool readonly,
+    std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view_in)
+    : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim),
+      shape(std::move(shape_in)), strides(std::move(strides_in)),
+      readonly(readonly), owned_view(std::move(owned_view_in)) {
+  size = 1;
+  for (ssize_t i = 0; i < ndim; ++i) {
+    size *= shape[i];
   }
+}
 
-  explicit nb_buffer_info(Py_buffer *view)
-      : nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim,
-                       {view->shape, view->shape + view->ndim},
-                       // TODO(phawkins): check for null strides
-                       {view->strides, view->strides + view->ndim},
-                       view->readonly != 0,
-                       std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(
-                           view, PyBuffer_Release)) {}
-
-  nb_buffer_info(const nb_buffer_info &) = delete;
-  nb_buffer_info(nb_buffer_info &&) = default;
-  nb_buffer_info &operator=(const nb_buffer_info &) = delete;
-  nb_buffer_info &operator=(nb_buffer_info &&) = default;
-
-private:
-  std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view;
-};
-
-class nb_buffer : public nb::object {
-  NB_OBJECT_DEFAULT(nb_buffer, object, "Buffer", PyObject_CheckBuffer);
-
-  nb_buffer_info request() const {
-    int flags = PyBUF_STRIDES | PyBUF_FORMAT;
-    auto *view = new Py_buffer();
-    if (PyObject_GetBuffer(ptr(), view, flags) != 0) {
-      delete view;
-      throw nb::python_error();
-    }
-    return nb_buffer_info(view);
+nb_buffer_info nb_buffer::request() const {
+  int flags = PyBUF_STRIDES | PyBUF_FORMAT;
+  auto *view = new Py_buffer();
+  if (PyObject_GetBuffer(ptr(), view, flags) != 0) {
+    delete view;
+    throw nb::python_error();
   }
-};
-
-template <typename T>
-struct nb_format_descriptor {};
+  return nb_buffer_info(view);
+}
 
 template <>
 struct nb_format_descriptor<bool> {
@@ -230,1052 +195,719 @@ struct nb_format_descriptor<double> {
   static const char *format() { return "d"; }
 };
 
-class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
-  static constexpr const char *pyClassName = "AffineMapAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirAffineMapAttrGetTypeID;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](PyAffineMap &affineMap) {
-          MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
-          return PyAffineMapAttribute(affineMap.getContext(), attr);
-        },
-        nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
-    c.def_prop_ro(
-        "value",
-        [](PyAffineMapAttribute &self) {
-          return PyAffineMap(self.getContext(),
-                             mlirAffineMapAttrGetValue(self));
-        },
-        "Returns the value of the AffineMap attribute");
-  }
-};
+void PyAffineMapAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](PyAffineMap &affineMap) {
+        MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
+        return PyAffineMapAttribute(affineMap.getContext(), attr);
+      },
+      nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
+  c.def_prop_ro(
+      "value",
+      [](PyAffineMapAttribute &self) {
+        return PyAffineMap(self.getContext(), mlirAffineMapAttrGetValue(self));
+      },
+      "Returns the value of the AffineMap attribute");
+}
 
-class PyIntegerSetAttribute
-    : public PyConcreteAttribute<PyIntegerSetAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet;
-  static constexpr const char *pyClassName = "IntegerSetAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirIntegerSetAttrGetTypeID;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](PyIntegerSet &integerSet) {
-          MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get());
-          return PyIntegerSetAttribute(integerSet.getContext(), attr);
-        },
-        nb::arg("integer_set"), "Gets an attribute wrapping an IntegerSet.");
-  }
-};
+void PyIntegerSetAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](PyIntegerSet &integerSet) {
+        MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get());
+        return PyIntegerSetAttribute(integerSet.getContext(), attr);
+      },
+      nb::arg("integer_set"), "Gets an attribute wrapping an IntegerSet.");
+}
 
-template <typename T>
-static T pyTryCast(nb::handle object) {
-  try {
-    return nb::cast<T>(object);
-  } catch (nb::cast_error &err) {
-    std::string msg = std::string("Invalid attribute when attempting to "
-                                  "create an ArrayAttribute (") +
-                      err.what() + ")";
-    throw std::runtime_error(msg.c_str());
-  } catch (std::runtime_error &err) {
-    std::string msg = std::string("Invalid attribute (None?) when attempting "
-                                  "to create an ArrayAttribute (") +
-                      err.what() + ")";
-    throw std::runtime_error(msg.c_str());
-  }
+nb::typed<nb::object, PyAttribute>
+PyArrayAttribute::PyArrayAttributeIterator::dunderNext() {
+  // TODO: Throw is an inefficient way to stop iteration.
+  if (PyArrayAttribute::PyArrayAttributeIterator::nextIndex >=
+      mlirArrayAttrGetNumElements(
+          PyArrayAttribute::PyArrayAttributeIterator::attr.get()))
+    throw nb::stop_iteration();
+  return PyAttribute(
+             this->PyArrayAttribute::PyArrayAttributeIterator::attr
+                 .getContext(),
+             mlirArrayAttrGetElement(
+                 PyArrayAttribute::PyArrayAttributeIterator::attr.get(),
+                 PyArrayAttribute::PyArrayAttributeIterator::nextIndex++))
+      .maybeDownCast();
 }
 
-/// A python-wrapped dense array attribute with an element type and a derived
-/// implementation class.
-template <typename EltTy, typename DerivedT>
-class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
-public:
-  using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;
-
-  /// Iterator over the integer elements of a dense array.
-  class PyDenseArrayIterator {
-  public:
-    PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
-
-    /// Return a copy of the iterator.
-    PyDenseArrayIterator dunderIter() { return *this; }
-
-    /// Return the next element.
-    EltTy dunderNext() {
-      // Throw if the index has reached the end.
-      if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
-        throw nb::stop_iteration();
-      return DerivedT::getElement(attr.get(), nextIndex++);
-    }
+void PyArrayAttribute::PyArrayAttributeIterator::bind(nb::module_ &m) {
+  nb::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
+      .def("__iter__", &PyArrayAttributeIterator::dunderIter)
+      .def("__next__", &PyArrayAttributeIterator::dunderNext);
+}
 
-    /// Bind the iterator class.
-    static void bind(nb::module_ &m) {
-      nb::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName)
-          .def("__iter__", &PyDenseArrayIterator::dunderIter)
-          .def("__next__", &PyDenseArrayIterator::dunderNext);
-    }
+MlirAttribute PyArrayAttribute::getItem(intptr_t i) const {
+  return mlirArrayAttrGetElement(*this, i);
+}
 
-  private:
-    /// The referenced dense array attribute.
-    PyAttribute attr;
-    /// The next index to read.
-    int nextIndex = 0;
-  };
+void PyArrayAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](const nb::list &attributes, DefaultingPyMlirContext context) {
+        SmallVector<MlirAttribute> mlirAttributes;
+        mlirAttributes.reserve(nb::len(attributes));
+        for (auto attribute : attributes)
+          mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
+        MlirAttribute attr = mlirArrayAttrGet(
+            context->get(), mlirAttributes.size(), mlirAttributes.data());
+        return PyArrayAttribute(context->getRef(), attr);
+      },
+      nb::arg("attributes"), nb::arg("context") = nb::none(),
+      "Gets a uniqued Array attribute");
+  c.def("__getitem__",
+        [](PyArrayAttribute &arr,
+           intptr_t i) -> nb::typed<nb::object, PyAttribute> {
+          // Reject negative indices rather than pass them to the C API.
+          if (i < 0 || i >= mlirArrayAttrGetNumElements(arr))
+            throw nb::index_error("ArrayAttribute index out of range");
+          return PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast();
+        })
+      .def("__len__",
+           [](const PyArrayAttribute &arr) {
+             return mlirArrayAttrGetNumElements(arr);
+           })
+      .def("__iter__", [](const PyArrayAttribute &arr) {
+        return PyArrayAttributeIterator(arr);
+      });
+  c.def("__add__", [](PyArrayAttribute arr, const nb::list &extras) {
+    std::vector<MlirAttribute> attributes;
+    intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
+    attributes.reserve(numOldElements + nb::len(extras));
+    for (intptr_t i = 0; i < numOldElements; ++i)
+      attributes.push_back(arr.getItem(i));
+    for (nb::handle attr : extras)
+      attributes.push_back(pyTryCast<PyAttribute>(attr));
+    MlirAttribute arrayAttr = mlirArrayAttrGet(
+        arr.getContext()->get(), attributes.size(), attributes.data());
+    return PyArrayAttribute(arr.getContext(), arrayAttr);
+  });
+}
+void PyFloatAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](PyType &type, double value, DefaultingPyLocation loc) {
+        PyMlirContext::ErrorCapture errors(loc->getContext());
+        MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
+        if (mlirAttributeIsNull(attr))
+          throw MLIRError("Invalid attribute", errors.take());
+        return PyFloatAttribute(type.getContext(), attr);
+      },
+      nb::arg("type"), nb::arg("value"), nb::arg("loc") = nb::none(),
+      "Gets an uniqued float point attribute associated to a type");
+  c.def_static(
+      "get_unchecked",
+      [](PyType &type, double value, DefaultingPyMlirContext context) {
+        PyMlirContext::ErrorCapture errors(context->getRef());
+        MlirAttribute attr =
+            mlirFloatAttrDoubleGet(context.get()->get(), type, value);
+        if (mlirAttributeIsNull(attr))
+          throw MLIRError("Invalid attribute", errors.take());
+        return PyFloatAttribute(type.getContext(), attr);
+      },
+      nb::arg("type"), nb::arg("value"), nb::arg("context") = nb::none(),
+      "Gets an uniqued float point attribute associated to a type");
+  c.def_static(
+      "get_f32",
+      [](double value, DefaultingPyMlirContext context) {
+        MlirAttribute attr = mlirFloatAttrDoubleGet(
+            context->get(), mlirF32TypeGet(context->get()), value);
+        return PyFloatAttribute(context->getRef(), attr);
+      },
+      nb::arg("value"), nb::arg("context") = nb::none(),
+      "Gets an uniqued float point attribute associated to a f32 type");
+  c.def_static(
+      "get_f64",
+      [](double value, DefaultingPyMlirContext context) {
+        MlirAttribute attr = mlirFloatAttrDoubleGet(
+            context->get(), mlirF64TypeGet(context->get()), value);
+        return PyFloatAttribute(context->getRef(), attr);
+      },
+      nb::arg("value"), nb::arg("context") = nb::none(),
+      "Gets an uniqued float point attribute associated to a f64 type");
+  c.def_prop_ro("value", mlirFloatAttrGetValueDouble,
+                "Returns the value of the float attribute");
+  c.def("__float__", mlirFloatAttrGetValueDouble,
+        "Converts the value of the float attribute to a Python float");
+}
 
-  /// Get the element at the given index.
-  EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
-
-  /// Bind the attribute class.
-  static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
-    // Bind the constructor.
-    if constexpr (std::is_same_v<EltTy, bool>) {
-      c.def_static(
-          "get",
-          [](const nb::sequence &py_values, DefaultingPyMlirContext ctx) {
-            std::vector<bool> values;
-            for (nb::handle py_value : py_values) {
-              int is_true = PyObject_IsTrue(py_value.ptr());
-              if (is_true < 0) {
-                throw nb::python_error();
-              }
-              values.push_back(is_true);
-            }
-            return getAttribute(values, ctx->getRef());
-          },
-          nb::arg("values"), nb::arg("context") = nb::none(),
-          "Gets a uniqued dense array attribute");
-    } else {
-      c.def_static(
-          "get",
-          [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
-            return getAttribute(values, ctx->getRef());
-          },
-          nb::arg("values"), nb::arg("context") = nb::none(),
-          "Gets a uniqued dense array attribute");
-    }
-    // Bind the array methods.
-    c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
-      if (i >= mlirDenseArrayGetNumElements(arr))
-        throw nb::index_error("DenseArray index out of range");
-      return arr.getItem(i);
-    });
-    c.def("__len__", [](const DerivedT &arr) {
-      return mlirDenseArrayGetNumElements(arr);
-    });
-    c.def("__iter__",
-          [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
-    c.def("__add__", [](DerivedT &arr, const nb::list &extras) {
-      std::vector<EltTy> values;
-      intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
-      values.reserve(numOldElements + nb::len(extras));
-      for (intptr_t i = 0; i < numOldElements; ++i)
-        values.push_back(arr.getItem(i));
-      for (nb::handle attr : extras)
-        values.push_back(pyTryCast<EltTy>(attr));
-      return getAttribute(values, arr.getContext());
-    });
-  }
+void PyIntegerAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](PyType &type, int64_t value) {
+        MlirAttribute attr = mlirIntegerAttrGet(type, value);
+        return PyIntegerAttribute(type.getContext(), attr);
+      },
+      nb::arg("type"), nb::arg("value"),
+      "Gets an uniqued integer attribute associated to a type");
+  c.def_prop_ro("value", toPyInt, "Returns the value of the integer attribute");
+  c.def("__int__", toPyInt,
+        "Converts the value of the integer attribute to a Python int");
+  c.def_prop_ro_static(
+      "static_typeid",
+      [](nb::object & /*class*/) {
+        return PyTypeID(mlirIntegerAttrGetTypeID());
+      },
+      nb::sig("def static_typeid(/) -> TypeID"));
+}
 
-private:
-  static DerivedT getAttribute(const std::vector<EltTy> &values,
-                               PyMlirContextRef ctx) {
-    if constexpr (std::is_same_v<EltTy, bool>) {
-      std::vector<int> intValues(values.begin(), values.end());
-      MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(),
-                                                  intValues.data());
-      return DerivedT(ctx, attr);
-    } else {
-      MlirAttribute attr =
-          DerivedT::getAttribute(ctx->get(), values.size(), values.data());
-      return DerivedT(ctx, attr);
-    }
-  }
-};
+/// Reads the integer payload of `self` through the C API accessor selected by
+/// its type: `...GetValueInt` for index and signless integer types,
+/// `...GetValueSInt` for signed integer types, `...GetValueUInt` otherwise.
+int64_t PyIntegerAttribute::toPyInt(PyIntegerAttribute &self) {
+  const MlirType attrType = mlirAttributeGetType(self);
+  // The index check short-circuits, so the signless query only runs on
+  // non-index types.
+  const bool usePlainAccessor =
+      mlirTypeIsAIndex(attrType) || mlirIntegerTypeIsSignless(attrType);
+  if (!usePlainAccessor) {
+    if (mlirIntegerTypeIsSigned(attrType))
+      return mlirIntegerAttrGetValueSInt(self);
+    return mlirIntegerAttrGetValueUInt(self);
+  }
+  return mlirIntegerAttrGetValueInt(self);
+}
 
-/// Instantiate the python dense array classes.
-struct PyDenseBoolArrayAttribute
-    : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> {
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
-  static constexpr auto getAttribute = mlirDenseBoolArrayGet;
-  static constexpr auto getElement = mlirDenseBoolArrayGetElement;
-  static constexpr const char *pyClassName = "DenseBoolArrayAttr";
-  static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
-  using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseI8ArrayAttribute
-    : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
-  static constexpr auto getAttribute = mlirDenseI8ArrayGet;
-  static constexpr auto getElement = mlirDenseI8ArrayGetElement;
-  static constexpr const char *pyClassName = "DenseI8ArrayAttr";
-  static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
-  using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseI16ArrayAttribute
-    : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
-  static constexpr auto getAttribute = mlirDenseI16ArrayGet;
-  static constexpr auto getElement = mlirDenseI16ArrayGetElement;
-  static constexpr const char *pyClassName = "DenseI16ArrayAttr";
-  static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
-  using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseI32ArrayAttribute
-    : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
-  static constexpr auto getAttribute = mlirDenseI32ArrayGet;
-  static constexpr auto getElement = mlirDenseI32ArrayGetElement;
-  static constexpr const char *pyClassName = "DenseI32ArrayAttr";
-  static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
-  using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseI64ArrayAttribute
-    : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
-  static constexpr auto getAttribute = mlirDenseI64ArrayGet;
-  static constexpr auto getElement = mlirDenseI64ArrayGetElement;
-  static constexpr const char *pyClassName = "DenseI64ArrayAttr";
-  static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
-  using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseF32ArrayAttribute
-    : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
-  static constexpr auto getAttribute = mlirDenseF32ArrayGet;
-  static constexpr auto getElement = mlirDenseF32ArrayGetElement;
-  static constexpr const char *pyClassName = "DenseF32ArrayAttr";
-  static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
-  using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseF64ArrayAttribute
-    : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
-  static constexpr auto getAttribute = mlirDenseF64ArrayGet;
-  static constexpr auto getElement = mlirDenseF64ArrayGetElement;
-  static constexpr const char *pyClassName = "DenseF64ArrayAttr";
-  static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
-  using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
+/// Populates the `BoolAttr` Python class `c` with its `get` factory and the
+/// `value` / `__bool__` accessors.
+void PyBoolAttribute::bindDerived(ClassTy &c) {
+  // Factory: `context` is optional on the Python side (defaults to None).
+  auto buildFromValue = [](bool value, DefaultingPyMlirContext context) {
+    MlirAttribute created = mlirBoolAttrGet(context->get(), value);
+    return PyBoolAttribute(context->getRef(), created);
+  };
+  c.def_static("get", buildFromValue, nb::arg("value"),
+               nb::arg("context") = nb::none(),
+               "Gets an uniqued bool attribute");
+  // The C API getter is bound directly for both the property and `__bool__`.
+  c.def_prop_ro("value", mlirBoolAttrGetValue,
+                "Returns the value of the bool attribute");
+  c.def("__bool__", mlirBoolAttrGetValue,
+        "Converts the value of the bool attribute to a Python bool");
+}
 
-class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
-  static constexpr const char *pyClassName = "ArrayAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirArrayAttrGetTypeID;
-
-  class PyArrayAttributeIterator {
-  public:
-    PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
-
-    PyArrayAttributeIterator &dunderIter() { return *this; }
-
-    nb::typed<nb::object, PyAttribute> dunderNext() {
-      // TODO: Throw is an inefficient way to stop iteration.
-      if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
-        throw nb::stop_iteration();
-      return PyAttribute(this->attr.getContext(),
-                         mlirArrayAttrGetElement(attr.get(), nextIndex++))
-          .maybeDownCast();
-    }
+/// Builds a SymbolRef attribute from `symbols`: the first entry becomes the
+/// root reference and every following entry becomes a nested flat symbol
+/// reference, in order. Throws `std::runtime_error` when `symbols` is empty.
+PySymbolRefAttribute
+PySymbolRefAttribute::fromList(const std::vector<std::string> &symbols,
+                               PyMlirContext &context) {
+  if (symbols.empty())
+    throw std::runtime_error("SymbolRefAttr must be composed of at least "
+                             "one symbol.");
+  MlirContext ctx = context.get();
+  // Everything after the root is wrapped as its own flat symbol reference.
+  SmallVector<MlirAttribute, 3> nestedRefs;
+  for (auto it = symbols.begin() + 1, end = symbols.end(); it != end; ++it)
+    nestedRefs.push_back(mlirFlatSymbolRefAttrGet(ctx, toMlirStringRef(*it)));
+  MlirAttribute symbolRef =
+      mlirSymbolRefAttrGet(ctx, toMlirStringRef(symbols.front()),
+                           nestedRefs.size(), nestedRefs.data());
+  return PySymbolRefAttribute(context.getRef(), symbolRef);
+}
 
-    static void bind(nb::module_ &m) {
-      nb::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
-          .def("__iter__", &PyArrayAttributeIterator::dunderIter)
-          .def("__next__", &PyArrayAttributeIterator::dunderNext);
-    }
+/// Populates the `SymbolRefAttr` Python class `c` with its `get` factory and
+/// the `value` property.
+void PySymbolRefAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](const std::vector<std::string> &symbols,
+         DefaultingPyMlirContext context) {
+        return PySymbolRefAttribute::fromList(symbols, context.resolve());
+      },
+      nb::arg("symbols"), nb::arg("context") = nb::none(),
+      "Gets a uniqued SymbolRef attribute from a list of symbol names");
+  c.def_prop_ro(
+      "value",
+      [](PySymbolRefAttribute &self) {
+        // Query the nested-reference count once, and index with `intptr_t`
+        // (the C API's count/position type) rather than a narrower `int`.
+        const intptr_t numNested =
+            mlirSymbolRefAttrGetNumNestedReferences(self);
+        std::vector<std::string> symbols;
+        symbols.reserve(static_cast<size_t>(numNested) + 1);
+        // Root symbol first, then each nested reference in order.
+        symbols.push_back(
+            unwrap(mlirSymbolRefAttrGetRootReference(self)).str());
+        for (intptr_t i = 0; i < numNested; ++i) {
+          MlirAttribute nested = mlirSymbolRefAttrGetNestedReference(self, i);
+          symbols.push_back(
+              unwrap(mlirSymbolRefAttrGetRootReference(nested)).str());
+        }
+        return symbols;
+      },
+      "Returns the value of the SymbolRef attribute as a list[str]");
+}
 
-  private:
-    PyAttribute attr;
-    int nextIndex = 0;
-  };
+/// Populates the `FlatSymbolRefAttr` Python class `c` with its `get` factory
+/// and the `value` property.
+void PyFlatSymbolRefAttribute::bindDerived(ClassTy &c) {
+  // Factory: `context` is optional on the Python side (defaults to None).
+  auto buildFromName = [](const std::string &value,
+                          DefaultingPyMlirContext context) {
+    MlirAttribute created =
+        mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
+    return PyFlatSymbolRefAttribute(context->getRef(), created);
+  };
+  // Getter: copies the referenced symbol name into a Python `str`.
+  auto readName = [](PyFlatSymbolRefAttribute &self) {
+    MlirStringRef name = mlirFlatSymbolRefAttrGetValue(self);
+    return nb::str(name.data, name.length);
+  };
+  c.def_static("get", buildFromName, nb::arg("value"),
+               nb::arg("context") = nb::none(),
+               "Gets a uniqued FlatSymbolRef attribute");
+  c.def_prop_ro("value", readName,
+                "Returns the value of the FlatSymbolRef attribute as a string");
+}
 
-  MlirAttribute getItem(intptr_t i) {
-    return mlirArrayAttrGetElement(*this, i);
-  }
+/// Populates the `OpaqueAttr` Python class `c` with its `get` factory and the
+/// `dialect_namespace` / `data` properties.
+void PyOpaqueAttribute::bindDerived(ClassTy &c) {
+  // Factory: hands the bytes exposed by `buffer` to the C API along with the
+  // dialect namespace and the attribute's type.
+  auto buildFromBuffer = [](const std::string &dialectNamespace,
+                            const nb_buffer &buffer, PyType &type,
+                            DefaultingPyMlirContext context) {
+    const nb_buffer_info info = buffer.request();
+    intptr_t numBytes = info.size;
+    MlirAttribute created = mlirOpaqueAttrGet(
+        context->get(), toMlirStringRef(dialectNamespace), numBytes,
+        static_cast<char *>(info.ptr), type);
+    return PyOpaqueAttribute(context->getRef(), created);
+  };
+  // Getter: the dialect namespace as a Python `str`.
+  auto readNamespace = [](PyOpaqueAttribute &self) {
+    MlirStringRef ns = mlirOpaqueAttrGetDialectNamespace(self);
+    return nb::str(ns.data, ns.length);
+  };
+  // Getter: the raw payload as Python `bytes`.
+  auto readData = [](PyOpaqueAttribute &self) {
+    MlirStringRef payload = mlirOpaqueAttrGetData(self);
+    return nb::bytes(payload.data, payload.length);
+  };
+  c.def_static(
+      "get", buildFromBuffer, nb::arg("dialect_namespace"), nb::arg("buffer"),
+      nb::arg("type"), nb::arg("context") = nb::none(),
+      // clang-format off
+        nb::sig("def get(dialect_namespace: str, buffer: typing_extensions.Buffer, type: Type, context: Context | None = None) -> OpaqueAttr"),
+      // clang-format on
+      "Gets an Opaque attribute.");
+  c.def_prop_ro(
+      "dialect_namespace", readNamespace,
+      "Returns the dialect namespace for the Opaque attribute as a string");
+  c.def_prop_ro("data", readData,
+                "Returns the data for the Opaqued attributes as `bytes`");
+}
 
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](const nb::list &attributes, DefaultingPyMlirContext context) {
-          SmallVector<MlirAttribute> mlirAttributes;
-          mlirAttributes.reserve(nb::len(attributes));
-          for (auto attribute : attributes) {
-            mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
-          }
-          MlirAttribute attr = mlirArrayAttrGet(
-              context->get(), mlirAttributes.size(), mlirAttributes.data());
-          return PyArrayAttribute(context->getRef(), attr);
-        },
-        nb::arg("attributes"), nb::arg("context") = nb::none(),
-        "Gets a uniqued Array attribute");
-    c.def(
-         "__getitem__",
-         [](PyArrayAttribute &arr,
-            intptr_t i) -> nb::typed<nb::object, PyAttribute> {
-           if (i >= mlirArrayAttrGetNumElements(arr))
-             throw nb::index_error("ArrayAttribute index out of range");
-           return PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast();
-         })
-        .def("__len__",
-             [](const PyArrayAttribute &arr) {
-               return mlirArrayAttrGetNumElements(arr);
-             })
-        .def("__iter__", [](const PyArrayAttribute &arr) {
-          return PyArrayAttributeIterator(arr);
-        });
-    c.def("__add__", [](PyArrayAttribute arr, const nb::list &extras) {
-      std::vector<MlirAttribute> attributes;
-      intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
-      attributes.reserve(numOldElements + nb::len(extras));
-      for (intptr_t i = 0; i < numOldElements; ++i)
-        attributes.push_back(arr.getItem(i));
-      for (nb::handle attr : extras)
-        attributes.push_back(pyTryCast<PyAttribute>(attr));
-      MlirAttribute arrayAttr = mlirArrayAttrGet(
-          arr.getContext()->get(), attributes.size(), attributes.data());
-      return PyArrayAttribute(arr.getContext(), arrayAttr);
-    });
+PyDenseElementsAttribute
+PyDenseElementsAttribute::getFromList(const nb::list &attributes,
+                                      std::optional<PyType> explicitType,
+                                      DefaultingPyMlirContext contextWrapper) {
+  const size_t numAttributes = nb::len(attributes);
+  if (numAttributes == 0)
+    throw nb::value_error("Attributes list must be non-empty.");
+
+  MlirType shapedType;
+  if (explicitType) {
+    if ((!mlirTypeIsAShaped(*explicitType) ||
+         !mlirShapedTypeHasStaticShape(*explicitType))) {
+
+      std::string message;
+      llvm::raw_string_ostream os(message);
+      os << "Expected a static ShapedType for the shaped_type parameter: "
+         << nb::cast<std::string>(nb::repr(nb::cast(*explicitType)));
+      throw nb::value_error(message.c_str());
+    }
+    shapedType = *explicitType;
+  } else {
+    SmallVector<int64_t> shape = {static_cast<int64_t>(numAttributes)};
+    shapedType = mlirRankedTensorTypeGet(
+        shape.size(), shape.data(),
+        mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
+        mlirAttributeGetNull());
   }
-};
 
-/// Float Point Attribute subclass - FloatAttr.
-class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
-  static constexpr const char *pyClassName = "FloatAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloatAttrGetTypeID;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](PyType &type, double value, DefaultingPyLocation loc) {
-          PyMlirContext::ErrorCapture errors(loc->getContext());
-          MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
-          if (mlirAttributeIsNull(attr))
-            throw MLIRError("Invalid attribute", errors.take());
-          return PyFloatAttribute(type.getContext(), attr);
-        },
-        nb::arg("type"), nb::arg("value"), nb::arg("loc") = nb::none(),
-        "Gets an uniqued float point attribute associated to a type");
-    c.def_static(
-        "get_unchecked",
-        [](PyType &type, double value, DefaultingPyMlirContext context) {
-          PyMlirContext::ErrorCapture errors(context->getRef());
-          MlirAttribute attr =
-              mlirFloatAttrDoubleGet(context.get()->get(), type, value);
-          if (mlirAttributeIsNull(attr))
-            throw MLIRError("Invalid attribute", errors.take());
-          return PyFloatAttribute(type.getContext(), attr);
-        },
-        nb::arg("type"), nb::arg("value"), nb::arg("context") = nb::none(),
-        "Gets an uniqued float point attribute associated to a type");
-    c.def_static(
-        "get_f32",
-        [](double value, DefaultingPyMlirContext context) {
-          MlirAttribute attr = mlirFloatAttrDoubleGet(
-              context->get(), mlirF32TypeGet(context->get()), value);
-          return PyFloatAttribute(context->getRef(), attr);
-        },
-        nb::arg("value"), nb::arg("context") = nb::none(),
-        "Gets an uniqued float point attribute associated to a f32 type");
-    c.def_static(
-        "get_f64",
-        [](double value, DefaultingPyMlirContext context) {
-          MlirAttribute attr = mlirFloatAttrDoubleGet(
-              context->get(), mlirF64TypeGet(context->get()), value);
-          return PyFloatAttribute(context->getRef(), attr);
-        },
-        nb::arg("value"), nb::arg("context") = nb::none(),
-        "Gets an uniqued float point attribute associated to a f64 type");
-    c.def_prop_ro("value", mlirFloatAttrGetValueDouble,
-                  "Returns the value of the float attribute");
-    c.def("__float__", mlirFloatAttrGetValueDouble,
-          "Converts the value of the float attribute to a Python float");
+  SmallVector<MlirAttribute> mlirAttributes;
+  mlirAttributes.reserve(numAttributes);
+  for (const nb::handle &attribute : attributes) {
+    MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
+    MlirType attrType = mlirAttributeGetType(mlirAttribute);
+    mlirAttributes.push_back(mlirAttribute);
+
+    if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
+      std::string message;
+      llvm::raw_string_ostream os(message);
+      os << "All attributes must be of the same type and match "
+         << "the type parameter: expected="
+         << nb::cast<std::string>(nb::repr(nb::cast(shapedType)))
+         << ", but got=" << nb::cast<std::string>(nb::repr(nb::cast(attrType)));
+      throw nb::value_error(message.c_str());
+    }
   }
-};
 
-/// Integer Attribute subclass - IntegerAttr.
-class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
-  static constexpr const char *pyClassName = "IntegerAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](PyType &type, int64_t value) {
-          MlirAttribute attr = mlirIntegerAttrGet(type, value);
-          return PyIntegerAttribute(type.getContext(), attr);
-        },
-        nb::arg("type"), nb::arg("value"),
-        "Gets an uniqued integer attribute associated to a type");
-    c.def_prop_ro("value", toPyInt,
-                  "Returns the value of the integer attribute");
-    c.def("__int__", toPyInt,
-          "Converts the value of the integer attribute to a Python int");
-    c.def_prop_ro_static(
-        "static_typeid",
-        [](nb::object & /*class*/) {
-          return PyTypeID(mlirIntegerAttrGetTypeID());
-        },
-        nanobind::sig("def static_typeid(/) -> TypeID"));
-  }
+  MlirAttribute elements = mlirDenseElementsAttrGet(
+      shapedType, mlirAttributes.size(), mlirAttributes.data());
 
-private:
-  static int64_t toPyInt(PyIntegerAttribute &self) {
-    MlirType type = mlirAttributeGetType(self);
-    if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
-      return mlirIntegerAttrGetValueInt(self);
-    if (mlirIntegerTypeIsSigned(type))
-      return mlirIntegerAttrGetValueSInt(self);
-    return mlirIntegerAttrGetValueUInt(self);
-  }
-};
+  return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
+}
 
-/// Bool Attribute subclass - BoolAttr.
-class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
-  static constexpr const char *pyClassName = "BoolAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](bool value, DefaultingPyMlirContext context) {
-          MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
-          return PyBoolAttribute(context->getRef(), attr);
-        },
-        nb::arg("value"), nb::arg("context") = nb::none(),
-        "Gets an uniqued bool attribute");
-    c.def_prop_ro("value", mlirBoolAttrGetValue,
-                  "Returns the value of the bool attribute");
-    c.def("__bool__", mlirBoolAttrGetValue,
-          "Converts the value of the bool attribute to a Python bool");
+PyDenseElementsAttribute PyDenseElementsAttribute::getFromBuffer(
+    const nb_buffer &array, bool signless,
+    const std::optional<PyType> &explicitType,
+    std::optional<std::vector<int64_t>> explicitShape,
+    DefaultingPyMlirContext contextWrapper) {
+  // Request a contiguous view. In exotic cases, this will cause a copy.
+  int flags = PyBUF_ND;
+  if (!explicitType) {
+    flags |= PyBUF_FORMAT;
   }
-};
-
-class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
-  static constexpr const char *pyClassName = "SymbolRefAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-
-  static PySymbolRefAttribute fromList(const std::vector<std::string> &symbols,
-                                       PyMlirContext &context) {
-    if (symbols.empty())
-      throw std::runtime_error("SymbolRefAttr must be composed of at least "
-                               "one symbol.");
-    MlirStringRef rootSymbol = toMlirStringRef(symbols[0]);
-    SmallVector<MlirAttribute, 3> referenceAttrs;
-    for (size_t i = 1; i < symbols.size(); ++i) {
-      referenceAttrs.push_back(
-          mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
-    }
-    return PySymbolRefAttribute(context.getRef(),
-                                mlirSymbolRefAttrGet(context.get(), rootSymbol,
-                                                     referenceAttrs.size(),
-                                                     referenceAttrs.data()));
+  Py_buffer view;
+  if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
+    throw nb::python_error();
   }
+  auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
 
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](const std::vector<std::string> &symbols,
-           DefaultingPyMlirContext context) {
-          return PySymbolRefAttribute::fromList(symbols, context.resolve());
-        },
-        nb::arg("symbols"), nb::arg("context") = nb::none(),
-        "Gets a uniqued SymbolRef attribute from a list of symbol names");
-    c.def_prop_ro(
-        "value",
-        [](PySymbolRefAttribute &self) {
-          std::vector<std::string> symbols = {
-              unwrap(mlirSymbolRefAttrGetRootReference(self)).str()};
-          for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self);
-               ++i)
-            symbols.push_back(
-                unwrap(mlirSymbolRefAttrGetRootReference(
-                           mlirSymbolRefAttrGetNestedReference(self, i)))
-                    .str());
-          return symbols;
-        },
-        "Returns the value of the SymbolRef attribute as a list[str]");
+  MlirContext context = contextWrapper->get();
+  MlirAttribute attr = getAttributeFromBuffer(
+      view, signless, explicitType, std::move(explicitShape), context);
+  if (mlirAttributeIsNull(attr)) {
+    throw std::invalid_argument(
+        "DenseElementsAttr could not be constructed from the given buffer. "
+        "This may mean that the Python buffer layout does not match that "
+        "MLIR expected layout and is a bug.");
   }
-};
+  return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
+}
 
-class PyFlatSymbolRefAttribute
-    : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
-  static constexpr const char *pyClassName = "FlatSymbolRefAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](const std::string &value, DefaultingPyMlirContext context) {
-          MlirAttribute attr =
-              mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
-          return PyFlatSymbolRefAttribute(context->getRef(), attr);
-        },
-        nb::arg("value"), nb::arg("context") = nb::none(),
-        "Gets a uniqued FlatSymbolRef attribute");
-    c.def_prop_ro(
-        "value",
-        [](PyFlatSymbolRefAttribute &self) {
-          MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
-          return nb::str(stringRef.data, stringRef.length);
-        },
-        "Returns the value of the FlatSymbolRef attribute as a string");
+PyDenseElementsAttribute
+PyDenseElementsAttribute::getSplat(const PyType &shapedType,
+                                   PyAttribute &elementAttr) {
+  auto contextWrapper =
+      PyMlirContext::forContext(mlirTypeGetContext(shapedType));
+  if (!mlirAttributeIsAInteger(elementAttr) &&
+      !mlirAttributeIsAFloat(elementAttr)) {
+    std::string message = "Illegal element type for DenseElementsAttr: ";
+    message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
+    throw nb::value_error(message.c_str());
   }
-};
-
-class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
-  static constexpr const char *pyClassName = "OpaqueAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirOpaqueAttrGetTypeID;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](const std::string &dialectNamespace, const nb_buffer &buffer,
-           PyType &type, DefaultingPyMlirContext context) {
-          const nb_buffer_info bufferInfo = buffer.request();
-          intptr_t bufferSize = bufferInfo.size;
-          MlirAttribute attr = mlirOpaqueAttrGet(
-              context->get(), toMlirStringRef(dialectNamespace), bufferSize,
-              static_cast<char *>(bufferInfo.ptr), type);
-          return PyOpaqueAttribute(context->getRef(), attr);
-        },
-        nb::arg("dialect_namespace"), nb::arg("buffer"), nb::arg("type"),
-        nb::arg("context") = nb::none(),
-        // clang-format off
-        nb::sig("def get(dialect_namespace: str, buffer: typing_extensions.Buffer, type: Type, context: Context | None = None) -> OpaqueAttr"),
-        // clang-format on
-        "Gets an Opaque attribute.");
-    c.def_prop_ro(
-        "dialect_namespace",
-        [](PyOpaqueAttribute &self) {
-          MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self);
-          return nb::str(stringRef.data, stringRef.length);
-        },
-        "Returns the dialect namespace for the Opaque attribute as a string");
-    c.def_prop_ro(
-        "data",
-        [](PyOpaqueAttribute &self) {
-          MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
-          return nb::bytes(stringRef.data, stringRef.length);
-        },
-        "Returns the data for the Opaqued attributes as `bytes`");
+  if (!mlirTypeIsAShaped(shapedType) ||
+      !mlirShapedTypeHasStaticShape(shapedType)) {
+    std::string message =
+        "Expected a static ShapedType for the shaped_type parameter: ";
+    message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
+    throw nb::value_error(message.c_str());
+  }
+  MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
+  MlirType attrType = mlirAttributeGetType(elementAttr);
+  if (!mlirTypeEqual(shapedElementType, attrType)) {
+    std::string message =
+        "Shaped element type and attribute type must be equal: shaped=";
+    message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
+    message.append(", element=");
+    message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
+    throw nb::value_error(message.c_str());
   }
-};
 
-// TODO: Support construction of string elements.
-class PyDenseElementsAttribute
-    : public PyConcreteAttribute<PyDenseElementsAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
-  static constexpr const char *pyClassName = "DenseElementsAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-
-  static PyDenseElementsAttribute
-  getFromList(const nb::list &attributes, std::optional<PyType> explicitType,
-              DefaultingPyMlirContext contextWrapper) {
-    const size_t numAttributes = nb::len(attributes);
-    if (numAttributes == 0)
-      throw nb::value_error("Attributes list must be non-empty.");
-
-    MlirType shapedType;
-    if (explicitType) {
-      if ((!mlirTypeIsAShaped(*explicitType) ||
-           !mlirShapedTypeHasStaticShape(*explicitType))) {
-
-        std::string message;
-        llvm::raw_string_ostream os(message);
-        os << "Expected a static ShapedType for the shaped_type parameter: "
-           << nb::cast<std::string>(nb::repr(nb::cast(*explicitType)));
-        throw nb::value_error(message.c_str());
-      }
-      shapedType = *explicitType;
-    } else {
-      SmallVector<int64_t> shape = {static_cast<int64_t>(numAttributes)};
-      shapedType = mlirRankedTensorTypeGet(
-          shape.size(), shape.data(),
-          mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
-          mlirAttributeGetNull());
-    }
+  MlirAttribute elements =
+      mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
+  return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
+}
 
-    SmallVector<MlirAttribute> mlirAttributes;
-    mlirAttributes.reserve(numAttributes);
-    for (const nb::handle &attribute : attributes) {
-      MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
-      MlirType attrType = mlirAttributeGetType(mlirAttribute);
-      mlirAttributes.push_back(mlirAttribute);
-
-      if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
-        std::string message;
-        llvm::raw_string_ostream os(message);
-        os << "All attributes must be of the same type and match "
-           << "the type parameter: expected="
-           << nb::cast<std::string>(nb::repr(nb::cast(shapedType)))
-           << ", but got="
-           << nb::cast<std::string>(nb::repr(nb::cast(attrType)));
-        throw nb::value_error(message.c_str());
-      }
-    }
+/// Returns the element count that `mlirElementsAttrGetNumElements` reports
+/// for this attribute.
+intptr_t PyDenseElementsAttribute::dunderLen() const {
+  const intptr_t numElements = mlirElementsAttrGetNumElements(*this);
+  return numElements;
+}
 
-    MlirAttribute elements = mlirDenseElementsAttrGet(
-        shapedType, mlirAttributes.size(), mlirAttributes.data());
+std::unique_ptr<nb_buffer_info> PyDenseElementsAttribute::accessBuffer() {
+  MlirType shapedType = mlirAttributeGetType(*this);
+  MlirType elementType = mlirShapedTypeGetElementType(shapedType);
+  std::string format;
 
-    return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
+  if (mlirTypeIsAF32(elementType)) {
+    // f32
+    return bufferInfo<float>(shapedType);
   }
-
-  static PyDenseElementsAttribute
-  getFromBuffer(const nb_buffer &array, bool signless,
-                const std::optional<PyType> &explicitType,
-                std::optional<std::vector<int64_t>> explicitShape,
-                DefaultingPyMlirContext contextWrapper) {
-    // Request a contiguous view. In exotic cases, this will cause a copy.
-    int flags = PyBUF_ND;
-    if (!explicitType) {
-      flags |= PyBUF_FORMAT;
-    }
-    Py_buffer view;
-    if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
-      throw nb::python_error();
-    }
-    auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
-
-    MlirContext context = contextWrapper->get();
-    MlirAttribute attr = getAttributeFromBuffer(
-        view, signless, explicitType, std::move(explicitShape), context);
-    if (mlirAttributeIsNull(attr)) {
-      throw std::invalid_argument(
-          "DenseElementsAttr could not be constructed from the given buffer. "
-          "This may mean that the Python buffer layout does not match that "
-          "MLIR expected layout and is a bug.");
-    }
-    return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
+  if (mlirTypeIsAF64(elementType)) {
+    // f64
+    return bufferInfo<double>(shapedType);
   }
-
-  static PyDenseElementsAttribute getSplat(const PyType &shapedType,
-                                           PyAttribute &elementAttr) {
-    auto contextWrapper =
-        PyMlirContext::forContext(mlirTypeGetContext(shapedType));
-    if (!mlirAttributeIsAInteger(elementAttr) &&
-        !mlirAttributeIsAFloat(elementAttr)) {
-      std::string message = "Illegal element type for DenseElementsAttr: ";
-      message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
-      throw nb::value_error(message.c_str());
+  if (mlirTypeIsAF16(elementType)) {
+    // f16
+    return bufferInfo<uint16_t>(shapedType, "e");
+  }
+  if (mlirTypeIsAIndex(elementType)) {
+    // Same as IndexType::kInternalStorageBitWidth
+    return bufferInfo<int64_t>(shapedType);
+  }
+  if (mlirTypeIsAInteger(elementType) &&
+      mlirIntegerTypeGetWidth(elementType) == 32) {
+    if (mlirIntegerTypeIsSignless(elementType) ||
+        mlirIntegerTypeIsSigned(elementType)) {
+      // i32
+      return bufferInfo<int32_t>(shapedType);
     }
-    if (!mlirTypeIsAShaped(shapedType) ||
-        !mlirShapedTypeHasStaticShape(shapedType)) {
-      std::string message =
-          "Expected a static ShapedType for the shaped_type parameter: ";
-      message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
-      throw nb::value_error(message.c_str());
+    if (mlirIntegerTypeIsUnsigned(elementType)) {
+      // unsigned i32
+      return bufferInfo<uint32_t>(shapedType);
     }
-    MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
-    MlirType attrType = mlirAttributeGetType(elementAttr);
-    if (!mlirTypeEqual(shapedElementType, attrType)) {
-      std::string message =
-          "Shaped element type and attribute type must be equal: shaped=";
-      message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
-      message.append(", element=");
-      message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
-      throw nb::value_error(message.c_str());
+  } else if (mlirTypeIsAInteger(elementType) &&
+             mlirIntegerTypeGetWidth(elementType) == 64) {
+    if (mlirIntegerTypeIsSignless(elementType) ||
+        mlirIntegerTypeIsSigned(elementType)) {
+      // i64
+      return bufferInfo<int64_t>(shapedType);
     }
-
-    MlirAttribute elements =
-        mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
-    return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
-  }
-
-  intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
-
-  std::unique_ptr<nb_buffer_info> accessBuffer() {
-    MlirType shapedType = mlirAttributeGetType(*this);
-    MlirType elementType = mlirShapedTypeGetElementType(shapedType);
-    std::string format;
-
-    if (mlirTypeIsAF32(elementType)) {
-      // f32
-      return bufferInfo<float>(shapedType);
+    if (mlirIntegerTypeIsUnsigned(elementType)) {
+      // unsigned i64
+      return bufferInfo<uint64_t>(shapedType);
     }
-    if (mlirTypeIsAF64(elementType)) {
-      // f64
-      return bufferInfo<double>(shapedType);
+  } else if (mlirTypeIsAInteger(elementType) &&
+             mlirIntegerTypeGetWidth(elementType) == 8) {
+    if (mlirIntegerTypeIsSignless(elementType) ||
+        mlirIntegerTypeIsSigned(elementType)) {
+      // i8
+      return bufferInfo<int8_t>(shapedType);
     }
-    if (mlirTypeIsAF16(elementType)) {
-      // f16
-      return bufferInfo<uint16_t>(shapedType, "e");
+    if (mlirIntegerTypeIsUnsigned(elementType)) {
+      // unsigned i8
+      return bufferInfo<uint8_t>(shapedType);
     }
-    if (mlirTypeIsAIndex(elementType)) {
-      // Same as IndexType::kInternalStorageBitWidth
-      return bufferInfo<int64_t>(shapedType);
+  } else if (mlirTypeIsAInteger(elementType) &&
+             mlirIntegerTypeGetWidth(elementType) == 16) {
+    if (mlirIntegerTypeIsSignless(elementType) ||
+        mlirIntegerTypeIsSigned(elementType)) {
+      // i16
+      return bufferInfo<int16_t>(shapedType);
     }
-    if (mlirTypeIsAInteger(elementType) &&
-        mlirIntegerTypeGetWidth(elementType) == 32) {
-      if (mlirIntegerTypeIsSignless(elementType) ||
-          mlirIntegerTypeIsSigned(elementType)) {
-        // i32
-        return bufferInfo<int32_t>(shapedType);
-      }
-      if (mlirIntegerTypeIsUnsigned(elementType)) {
-        // unsigned i32
-        return bufferInfo<uint32_t>(shapedType);
-      }
-    } else if (mlirTypeIsAInteger(elementType) &&
-               mlirIntegerTypeGetWidth(elementType) == 64) {
-      if (mlirIntegerTypeIsSignless(elementType) ||
-          mlirIntegerTypeIsSigned(elementType)) {
-        // i64
-        return bufferInfo<int64_t>(shapedType);
-      }
-      if (mlirIntegerTypeIsUnsigned(elementType)) {
-        // unsigned i64
-        return bufferInfo<uint64_t>(shapedType);
-      }
-    } else if (mlirTypeIsAInteger(elementType) &&
-               mlirIntegerTypeGetWidth(elementType) == 8) {
-      if (mlirIntegerTypeIsSignless(elementType) ||
-          mlirIntegerTypeIsSigned(elementType)) {
-        // i8
-        return bufferInfo<int8_t>(shapedType);
-      }
-      if (mlirIntegerTypeIsUnsigned(elementType)) {
-        // unsigned i8
-        return bufferInfo<uint8_t>(shapedType);
-      }
-    } else if (mlirTypeIsAInteger(elementType) &&
-               mlirIntegerTypeGetWidth(elementType) == 16) {
-      if (mlirIntegerTypeIsSignless(elementType) ||
-          mlirIntegerTypeIsSigned(elementType)) {
-        // i16
-        return bufferInfo<int16_t>(shapedType);
-      }
-      if (mlirIntegerTypeIsUnsigned(elementType)) {
-        // unsigned i16
-        return bufferInfo<uint16_t>(shapedType);
-      }
-    } else if (mlirTypeIsAInteger(elementType) &&
-               mlirIntegerTypeGetWidth(elementType) == 1) {
-      // i1 / bool
-      // We can not send the buffer directly back to Python, because the i1
-      // values are bitpacked within MLIR. We call numpy's unpackbits function
-      // to convert the bytes.
-      return getBooleanBufferFromBitpackedAttribute();
+    if (mlirIntegerTypeIsUnsigned(elementType)) {
+      // unsigned i16
+      return bufferInfo<uint16_t>(shapedType);
     }
-
-    // TODO: Currently crashes the program.
-    // Reported as https://github.com/pybind/pybind11/issues/3336
-    throw std::invalid_argument(
-        "unsupported data type for conversion to Python buffer");
+  } else if (mlirTypeIsAInteger(elementType) &&
+             mlirIntegerTypeGetWidth(elementType) == 1) {
+    // i1 / bool
+    // We can not send the buffer directly back to Python, because the i1
+    // values are bitpacked within MLIR. We call numpy's unpackbits function
+    // to convert the bytes.
+    return getBooleanBufferFromBitpackedAttribute();
   }
 
-  static void bindDerived(ClassTy &c) {
+  // TODO: Currently crashes the program.
+  // Reported as https://github.com/pybind/pybind11/issues/3336
+  throw std::invalid_argument(
+      "unsupported data type for conversion to Python buffer");
+}
+
+void PyDenseElementsAttribute::bindDerived(ClassTy &c) {
 #if PY_VERSION_HEX < 0x03090000
-    PyTypeObject *tp = reinterpret_cast<PyTypeObject *>(c.ptr());
-    tp->tp_as_buffer->bf_getbuffer = PyDenseElementsAttribute::bf_getbuffer;
-    tp->tp_as_buffer->bf_releasebuffer =
-        PyDenseElementsAttribute::bf_releasebuffer;
+  PyTypeObject *tp = reinterpret_cast<PyTypeObject *>(c.ptr());
+  tp->tp_as_buffer->bf_getbuffer = PyDenseElementsAttribute::bf_getbuffer;
+  tp->tp_as_buffer->bf_releasebuffer =
+      PyDenseElementsAttribute::bf_releasebuffer;
 #endif
-    c.def("__len__", &PyDenseElementsAttribute::dunderLen)
-        .def_static(
-            "get", PyDenseElementsAttribute::getFromBuffer, nb::arg("array"),
-            nb::arg("signless") = true, nb::arg("type") = nb::none(),
-            nb::arg("shape") = nb::none(), nb::arg("context") = nb::none(),
-            // clang-format off
+  c.def("__len__", &PyDenseElementsAttribute::dunderLen)
+      .def_static(
+          "get", PyDenseElementsAttribute::getFromBuffer, nb::arg("array"),
+          nb::arg("signless") = true, nb::arg("type") = nb::none(),
+          nb::arg("shape") = nb::none(), nb::arg("context") = nb::none(),
+          // clang-format off
             nb::sig("def get(array: typing_extensions.Buffer, signless: bool = True, type: Type | None = None, shape: Sequence[int] | None = None, context: Context | None = None) -> DenseElementsAttr"),
-            // clang-format on
-            kDenseElementsAttrGetDocstring)
-        .def_static("get", PyDenseElementsAttribute::getFromList,
-                    nb::arg("attrs"), nb::arg("type") = nb::none(),
-                    nb::arg("context") = nb::none(),
-                    kDenseElementsAttrGetFromListDocstring)
-        .def_static("get_splat", PyDenseElementsAttribute::getSplat,
-                    nb::arg("shaped_type"), nb::arg("element_attr"),
-                    "Gets a DenseElementsAttr where all values are the same")
-        .def_prop_ro("is_splat",
-                     [](PyDenseElementsAttribute &self) -> bool {
-                       return mlirDenseElementsAttrIsSplat(self);
-                     })
-        .def("get_splat_value",
-             [](PyDenseElementsAttribute &self)
-                 -> nb::typed<nb::object, PyAttribute> {
-               if (!mlirDenseElementsAttrIsSplat(self))
-                 throw nb::value_error(
-                     "get_splat_value called on a non-splat attribute");
-               return PyAttribute(self.getContext(),
-                                  mlirDenseElementsAttrGetSplatValue(self))
-                   .maybeDownCast();
-             });
-  }
-
-  static PyType_Slot slots[];
+          // clang-format on
+          kDenseElementsAttrGetDocstring)
+      .def_static("get", PyDenseElementsAttribute::getFromList,
+                  nb::arg("attrs"), nb::arg("type") = nb::none(),
+                  nb::arg("context") = nb::none(),
+                  kDenseElementsAttrGetFromListDocstring)
+      .def_static("get_splat", PyDenseElementsAttribute::getSplat,
+                  nb::arg("shaped_type"), nb::arg("element_attr"),
+                  "Gets a DenseElementsAttr where all values are the same")
+      .def_prop_ro("is_splat",
+                   [](PyDenseElementsAttribute &self) -> bool {
+                     return mlirDenseElementsAttrIsSplat(self);
+                   })
+      .def("get_splat_value",
+           [](PyDenseElementsAttribute &self)
+               -> nb::typed<nb::object, PyAttribute> {
+             if (!mlirDenseElementsAttrIsSplat(self))
+               throw nb::value_error(
+                   "get_splat_value called on a non-splat attribute");
+             return PyAttribute(self.getContext(),
+                                mlirDenseElementsAttrGetSplatValue(self))
+                 .maybeDownCast();
+           });
+}
 
-private:
-  static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags);
-  static void bf_releasebuffer(PyObject *, Py_buffer *buffer);
+bool PyDenseElementsAttribute::isUnsignedIntegerFormat(
+    std::string_view format) {
+  if (format.empty())
+    return false;
+  char code = format[0];
+  return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
+         code == 'Q';
+}
 
-  static bool isUnsignedIntegerFormat(std::string_view format) {
-    if (format.empty())
-      return false;
-    char code = format[0];
-    return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
-           code == 'Q';
-  }
+bool PyDenseElementsAttribute::isSignedIntegerFormat(std::string_view format) {
+  if (format.empty())
+    return false;
+  char code = format[0];
+  return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
+         code == 'q';
+}
 
-  static bool isSignedIntegerFormat(std::string_view format) {
-    if (format.empty())
-      return false;
-    char code = format[0];
-    return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
-           code == 'q';
+MlirType PyDenseElementsAttribute::getShapedType(
+    std::optional<MlirType> bulkLoadElementType,
+    std::optional<std::vector<int64_t>> explicitShape, Py_buffer &view) {
+  SmallVector<int64_t> shape;
+  if (explicitShape) {
+    shape.append(explicitShape->begin(), explicitShape->end());
+  } else {
+    shape.append(view.shape, view.shape + view.ndim);
   }
 
-  static MlirType
-  getShapedType(std::optional<MlirType> bulkLoadElementType,
-                std::optional<std::vector<int64_t>> explicitShape,
-                Py_buffer &view) {
-    SmallVector<int64_t> shape;
+  if (mlirTypeIsAShaped(*bulkLoadElementType)) {
     if (explicitShape) {
-      shape.append(explicitShape->begin(), explicitShape->end());
-    } else {
-      shape.append(view.shape, view.shape + view.ndim);
-    }
-
-    if (mlirTypeIsAShaped(*bulkLoadElementType)) {
-      if (explicitShape) {
-        throw std::invalid_argument("Shape can only be specified explicitly "
-                                    "when the type is not a shaped type.");
-      }
-      return *bulkLoadElementType;
+      throw std::invalid_argument("Shape can only be specified explicitly "
+                                  "when the type is not a shaped type.");
     }
-    MlirAttribute encodingAttr = mlirAttributeGetNull();
-    return mlirRankedTensorTypeGet(shape.size(), shape.data(),
-                                   *bulkLoadElementType, encodingAttr);
+    return *bulkLoadElementType;
   }
+  MlirAttribute encodingAttr = mlirAttributeGetNull();
+  return mlirRankedTensorTypeGet(shape.size(), shape.data(),
+                                 *bulkLoadElementType, encodingAttr);
+}
 
-  static MlirAttribute getAttributeFromBuffer(
-      Py_buffer &view, bool signless, std::optional<PyType> explicitType,
-      const std::optional<std::vector<int64_t>> &explicitShape,
-      MlirContext &context) {
-    // Detect format codes that are suitable for bulk loading. This includes
-    // all byte aligned integer and floating point types up to 8 bytes.
-    // Notably, this excludes exotics types which do not have a direct
-    // representation in the buffer protocol (i.e. complex, etc).
-    std::optional<MlirType> bulkLoadElementType;
-    if (explicitType) {
-      bulkLoadElementType = *explicitType;
-    } else {
-      std::string_view format(view.format);
-      if (format == "f") {
-        // f32
-        assert(view.itemsize == 4 && "mismatched array itemsize");
-        bulkLoadElementType = mlirF32TypeGet(context);
-      } else if (format == "d") {
-        // f64
-        assert(view.itemsize == 8 && "mismatched array itemsize");
-        bulkLoadElementType = mlirF64TypeGet(context);
-      } else if (format == "e") {
-        // f16
-        assert(view.itemsize == 2 && "mismatched array itemsize");
-        bulkLoadElementType = mlirF16TypeGet(context);
-      } else if (format == "?") {
-        // i1
-        // The i1 type needs to be bit-packed, so we will handle it separately
-        return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
-                                                      context);
-      } else if (isSignedIntegerFormat(format)) {
-        if (view.itemsize == 4) {
-          // i32
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 32)
-                                    : mlirIntegerTypeSignedGet(context, 32);
-        } else if (view.itemsize == 8) {
-          // i64
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 64)
-                                    : mlirIntegerTypeSignedGet(context, 64);
-        } else if (view.itemsize == 1) {
-          // i8
-          bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
-                                         : mlirIntegerTypeSignedGet(context, 8);
-        } else if (view.itemsize == 2) {
-          // i16
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 16)
-                                    : mlirIntegerTypeSignedGet(context, 16);
-        }
-      } else if (isUnsignedIntegerFormat(format)) {
-        if (view.itemsize == 4) {
-          // unsigned i32
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 32)
-                                    : mlirIntegerTypeUnsignedGet(context, 32);
-        } else if (view.itemsize == 8) {
-          // unsigned i64
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 64)
-                                    : mlirIntegerTypeUnsignedGet(context, 64);
-        } else if (view.itemsize == 1) {
-          // i8
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 8)
-                                    : mlirIntegerTypeUnsignedGet(context, 8);
-        } else if (view.itemsize == 2) {
-          // i16
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 16)
-                                    : mlirIntegerTypeUnsignedGet(context, 16);
-        }
+MlirAttribute PyDenseElementsAttribute::getAttributeFromBuffer(
+    Py_buffer &view, bool signless, std::optional<PyType> explicitType,
+    const std::optional<std::vector<int64_t>> &explicitShape,
+    MlirContext &context) {
+  // Detect format codes that are suitable for bulk loading. This includes
+  // all byte aligned integer and floating point types up to 8 bytes.
+  // Notably, this excludes exotic types which do not have a direct
+  // representation in the buffer protocol (e.g. complex).
+  std::optional<MlirType> bulkLoadElementType;
+  if (explicitType) {
+    bulkLoadElementType = *explicitType;
+  } else {
+    std::string_view format(view.format);
+    if (format == "f") {
+      // f32
+      assert(view.itemsize == 4 && "mismatched array itemsize");
+      bulkLoadElementType = mlirF32TypeGet(context);
+    } else if (format == "d") {
+      // f64
+      assert(view.itemsize == 8 && "mismatched array itemsize");
+      bulkLoadElementType = mlirF64TypeGet(context);
+    } else if (format == "e") {
+      // f16
+      assert(view.itemsize == 2 && "mismatched array itemsize");
+      bulkLoadElementType = mlirF16TypeGet(context);
+    } else if (format == "?") {
+      // i1
+      // The i1 type needs to be bit-packed, so we will handle it separately
+      return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
+                                                    context);
+    } else if (isSignedIntegerFormat(format)) {
+      if (view.itemsize == 4) {
+        // i32
+        bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
+                                       : mlirIntegerTypeSignedGet(context, 32);
+      } else if (view.itemsize == 8) {
+        // i64
+        bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
+                                       : mlirIntegerTypeSignedGet(context, 64);
+      } else if (view.itemsize == 1) {
+        // i8
+        bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
+                                       : mlirIntegerTypeSignedGet(context, 8);
+      } else if (view.itemsize == 2) {
+        // i16
+        bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
+                                       : mlirIntegerTypeSignedGet(context, 16);
       }
-      if (!bulkLoadElementType) {
-        throw std::invalid_argument(
-            std::string("unimplemented array format conversion from format: ") +
-            std::string(format));
+    } else if (isUnsignedIntegerFormat(format)) {
+      if (view.itemsize == 4) {
+        // unsigned i32
+        bulkLoadElementType = signless
+                                  ? mlirIntegerTypeGet(context, 32)
+                                  : mlirIntegerTypeUnsignedGet(context, 32);
+      } else if (view.itemsize == 8) {
+        // unsigned i64
+        bulkLoadElementType = signless
+                                  ? mlirIntegerTypeGet(context, 64)
+                                  : mlirIntegerTypeUnsignedGet(context, 64);
+      } else if (view.itemsize == 1) {
+        // unsigned i8
+        bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
+                                       : mlirIntegerTypeUnsignedGet(context, 8);
+      } else if (view.itemsize == 2) {
+        // unsigned i16
+        bulkLoadElementType = signless
+                                  ? mlirIntegerTypeGet(context, 16)
+                                  : mlirIntegerTypeUnsignedGet(context, 16);
       }
     }
-
-    MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
-    return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
-  }
-
-  // There is a complication for boolean numpy arrays, as numpy represents
-  // them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8
-  // booleans per byte.
-  static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
-      Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
-      MlirContext &context) {
-    if (llvm::endianness::native != llvm::endianness::little) {
-      // Given we have no good way of testing the behavior on big-endian
-      // systems we will throw
-      throw nb::type_error("Constructing a bit-packed MLIR attribute is "
-                           "unsupported on big-endian systems");
+    if (!bulkLoadElementType) {
+      throw std::invalid_argument(
+          std::string("unimplemented array format conversion from format: ") +
+          std::string(format));
     }
-    nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> unpackedArray(
-        /*data=*/static_cast<uint8_t *>(view.buf),
-        /*shape=*/{static_cast<size_t>(view.len)});
-
-    nb::module_ numpy = nb::module_::import_("numpy");
-    nb::object packbitsFunc = numpy.attr("packbits");
-    nb::object packedBooleans =
-        packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little");
-    nb_buffer_info pythonBuffer = nb::cast<nb_buffer>(packedBooleans).request();
-
-    MlirType bitpackedType = getShapedType(mlirIntegerTypeGet(context, 1),
-                                           std::move(explicitShape), view);
-    assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8");
-    // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of
-    // packedBooleans, hence the MlirAttribute will remain valid even when
-    // packedBooleans get reclaimed by the end of the function.
-    return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
-                                             pythonBuffer.ptr);
   }
 
-  // This does the opposite transformation of
-  // `getBitpackedAttributeFromBooleanBuffer`
-  std::unique_ptr<nb_buffer_info> getBooleanBufferFromBitpackedAttribute() {
-    if (llvm::endianness::native != llvm::endianness::little) {
-      // Given we have no good way of testing the behavior on big-endian
-      // systems we will throw
-      throw nb::type_error("Constructing a numpy array from a MLIR attribute "
-                           "is unsupported on big-endian systems");
-    }
+  MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
+  return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
+}
 
-    int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
-    int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8);
-    uint8_t *bitpackedData = static_cast<uint8_t *>(
-        const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
-    nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> packedArray(
-        /*data=*/bitpackedData,
-        /*shape=*/{static_cast<size_t>(numBitpackedBytes)});
-
-    nb::module_ numpy = nb::module_::import_("numpy");
-    nb::object unpackbitsFunc = numpy.attr("unpackbits");
-    nb::object equalFunc = numpy.attr("equal");
-    nb::object reshapeFunc = numpy.attr("reshape");
-    nb::object unpackedBooleans =
-        unpackbitsFunc(nb::cast(packedArray), "bitorder"_a = "little");
-
-    // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array.
-    // We need to:
-    //   1. Slice away the padded bits
-    //   2. Make the boolean array have the correct shape
-    //   3. Convert the array to a boolean array
-    unpackedBooleans = unpackedBooleans[nb::slice(
-        nb::int_(0), nb::int_(numBooleans), nb::int_(1))];
-    unpackedBooleans = equalFunc(unpackedBooleans, 1);
-
-    MlirType shapedType = mlirAttributeGetType(*this);
-    intptr_t rank = mlirShapedTypeGetRank(shapedType);
-    std::vector<intptr_t> shape(rank);
-    for (intptr_t i = 0; i < rank; ++i) {
-      shape[i] = mlirShapedTypeGetDimSize(shapedType, i);
-    }
-    unpackedBooleans = reshapeFunc(unpackedBooleans, shape);
+MlirAttribute PyDenseElementsAttribute::getBitpackedAttributeFromBooleanBuffer(
+    Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
+    MlirContext &context) {
+  if (llvm::endianness::native != llvm::endianness::little) {
+    // Given we have no good way of testing the behavior on big-endian
+    // systems, we will throw.
+    throw nb::type_error("Constructing a bit-packed MLIR attribute is "
+                         "unsupported on big-endian systems");
+  }
+  nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> unpackedArray(
+      /*data=*/static_cast<uint8_t *>(view.buf),
+      /*shape=*/{static_cast<size_t>(view.len)});
+
+  nb::module_ numpy = nb::module_::import_("numpy");
+  nb::object packbitsFunc = numpy.attr("packbits");
+  nb::object packedBooleans =
+      packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little");
+  nb_buffer_info pythonBuffer = nb::cast<nb_buffer>(packedBooleans).request();
+
+  MlirType bitpackedType = getShapedType(mlirIntegerTypeGet(context, 1),
+                                         std::move(explicitShape), view);
+  assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8");
+  // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of
+  // packedBooleans, hence the MlirAttribute will remain valid even when
+  // packedBooleans gets reclaimed at the end of the function.
+  return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
+                                           pythonBuffer.ptr);
+}
 
-    // Make sure the returned nb::buffer_view claims ownership of the data in
-    // `pythonBuffer` so it remains valid when Python reads it
-    nb_buffer pythonBuffer = nb::cast<nb_buffer>(unpackedBooleans);
-    return std::make_unique<nb_buffer_info>(pythonBuffer.request());
+std::unique_ptr<nb_buffer_info>
+PyDenseElementsAttribute::getBooleanBufferFromBitpackedAttribute() const {
+  if (llvm::endianness::native != llvm::endianness::little) {
+    // Given we have no good way of testing the behavior on big-endian
+    // systems, we will throw.
+    throw nb::type_error("Constructing a numpy array from a MLIR attribute "
+                         "is unsupported on big-endian systems");
   }
 
-  template <typename Type>
-  std::unique_ptr<nb_buffer_info>
-  bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) {
-    intptr_t rank = mlirShapedTypeGetRank(shapedType);
-    // Prepare the data for the buffer_info.
-    // Buffer is configured for read-only access below.
-    Type *data = static_cast<Type *>(
-        const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
-    // Prepare the shape for the buffer_info.
-    SmallVector<intptr_t, 4> shape;
-    for (intptr_t i = 0; i < rank; ++i)
-      shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
-    // Prepare the strides for the buffer_info.
-    SmallVector<intptr_t, 4> strides;
-    if (mlirDenseElementsAttrIsSplat(*this)) {
-      // Splats are special, only the single value is stored.
-      strides.assign(rank, 0);
-    } else {
-      for (intptr_t i = 1; i < rank; ++i) {
-        intptr_t strideFactor = 1;
-        for (intptr_t j = i; j < rank; ++j)
-          strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
-        strides.push_back(sizeof(Type) * strideFactor);
-      }
-      strides.push_back(sizeof(Type));
-    }
-    const char *format;
-    if (explicitFormat) {
-      format = explicitFormat;
-    } else {
-      format = nb_format_descriptor<Type>::format();
-    }
-    return std::make_unique<nb_buffer_info>(
-        data, sizeof(Type), format, rank, std::move(shape), std::move(strides),
-        /*readonly=*/true);
+  int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
+  int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8);
+  uint8_t *bitpackedData = static_cast<uint8_t *>(
+      const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
+  nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> packedArray(
+      /*data=*/bitpackedData,
+      /*shape=*/{static_cast<size_t>(numBitpackedBytes)});
+
+  nb::module_ numpy = nb::module_::import_("numpy");
+  nb::object unpackbitsFunc = numpy.attr("unpackbits");
+  nb::object equalFunc = numpy.attr("equal");
+  nb::object reshapeFunc = numpy.attr("reshape");
+  nb::object unpackedBooleans =
+      unpackbitsFunc(nb::cast(packedArray), "bitorder"_a = "little");
+
+  // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array.
+  // We need to:
+  //   1. Slice away the padded bits
+  //   2. Make the boolean array have the correct shape
+  //   3. Convert the array to a boolean array
+  unpackedBooleans = unpackedBooleans[nb::slice(
+      nb::int_(0), nb::int_(numBooleans), nb::int_(1))];
+  unpackedBooleans = equalFunc(unpackedBooleans, 1);
+
+  MlirType shapedType = mlirAttributeGetType(*this);
+  intptr_t rank = mlirShapedTypeGetRank(shapedType);
+  std::vector<intptr_t> shape(rank);
+  for (intptr_t i = 0; i < rank; ++i) {
+    shape[i] = mlirShapedTypeGetDimSize(shapedType, i);
   }
-}; // namespace
+  unpackedBooleans = reshapeFunc(unpackedBooleans, shape);
+
+  // Make sure the returned nb::buffer_view claims ownership of the data
+  // in `pythonBuffer` so it remains valid when Python reads it
+  nb_buffer pythonBuffer = nb::cast<nb_buffer>(unpackedBooleans);
+  return std::make_unique<nb_buffer_info>(pythonBuffer.request());
+}
 
 PyType_Slot PyDenseElementsAttribute::slots[] = {
 // Python 3.8 doesn't allow setting the buffer protocol slots from a type spec.
@@ -1333,364 +965,294 @@ PyType_Slot PyDenseElementsAttribute::slots[] = {
   delete reinterpret_cast<nb_buffer_info *>(view->internal);
 }
 
-/// Refinement of the PyDenseElementsAttribute for attributes containing
-/// integer (and boolean) values. Supports element access.
-class PyDenseIntElementsAttribute
-    : public PyConcreteAttribute<PyDenseIntElementsAttribute,
-                                 PyDenseElementsAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
-  static constexpr const char *pyClassName = "DenseIntElementsAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-
-  /// Returns the element at the given linear position. Asserts if the index
-  /// is out of range.
-  nb::int_ dunderGetItem(intptr_t pos) {
-    if (pos < 0 || pos >= dunderLen()) {
-      throw nb::index_error("attempt to access out of bounds element");
-    }
+nb::int_ PyDenseIntElementsAttribute::dunderGetItem(intptr_t pos) const {
+  if (pos < 0 || pos >= dunderLen()) {
+    throw nb::index_error("attempt to access out of bounds element");
+  }
 
-    MlirType type = mlirAttributeGetType(*this);
-    type = mlirShapedTypeGetElementType(type);
-    // Index type can also appear as a DenseIntElementsAttr and therefore can be
-    // casted to integer.
-    assert(mlirTypeIsAInteger(type) ||
-           mlirTypeIsAIndex(type) && "expected integer/index element type in "
-                                     "dense int elements attribute");
-    // Dispatch element extraction to an appropriate C function based on the
-    // elemental type of the attribute. nb::int_ is implicitly constructible
-    // from any C++ integral type and handles bitwidth correctly.
-    // TODO: consider caching the type properties in the constructor to avoid
-    // querying them on each element access.
-    if (mlirTypeIsAIndex(type)) {
-      return nb::int_(mlirDenseElementsAttrGetIndexValue(*this, pos));
+  MlirType type = mlirAttributeGetType(*this);
+  type = mlirShapedTypeGetElementType(type);
+  // Index type can also appear as a DenseIntElementsAttr and therefore can be
+  // cast to integer.
+  assert((mlirTypeIsAInteger(type) || mlirTypeIsAIndex(type)) &&
+         "expected integer/index element type in "
+         "dense int elements attribute");
+  // Dispatch element extraction to an appropriate C function based on the
+  // elemental type of the attribute. nb::int_ is implicitly
+  // constructible from any C++ integral type and handles bitwidth correctly.
+  // TODO: consider caching the type properties in the constructor to avoid
+  // querying them on each element access.
+  if (mlirTypeIsAIndex(type)) {
+    return nb::int_(mlirDenseElementsAttrGetIndexValue(*this, pos));
+  }
+  unsigned width = mlirIntegerTypeGetWidth(type);
+  bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
+  if (isUnsigned) {
+    if (width == 1) {
+      return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
     }
-    unsigned width = mlirIntegerTypeGetWidth(type);
-    bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
-    if (isUnsigned) {
-      if (width == 1) {
-        return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
-      }
-      if (width == 8) {
-        return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos));
-      }
-      if (width == 16) {
-        return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos));
-      }
-      if (width == 32) {
-        return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos));
-      }
-      if (width == 64) {
-        return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos));
-      }
-    } else {
-      if (width == 1) {
-        return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
-      }
-      if (width == 8) {
-        return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos));
-      }
-      if (width == 16) {
-        return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos));
-      }
-      if (width == 32) {
-        return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos));
-      }
-      if (width == 64) {
-        return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos));
-      }
+    if (width == 8) {
+      return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos));
+    }
+    if (width == 16) {
+      return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos));
+    }
+    if (width == 32) {
+      return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos));
+    }
+    if (width == 64) {
+      return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos));
+    }
+  } else {
+    if (width == 1) {
+      return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
+    }
+    if (width == 8) {
+      return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos));
+    }
+    if (width == 16) {
+      return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos));
+    }
+    if (width == 32) {
+      return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos));
+    }
+    if (width == 64) {
+      return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos));
     }
-    throw nb::type_error("Unsupported integer type");
-  }
-
-  static void bindDerived(ClassTy &c) {
-    c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
   }
-};
+  throw nb::type_error("Unsupported integer type");
+}
 
+void PyDenseIntElementsAttribute::bindDerived(ClassTy &c) {
+  c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
+}
 // Check if the python version is less than 3.13. Py_IsFinalizing is a part
 // of stable ABI since 3.13 and before it was available as _Py_IsFinalizing.
 #if PY_VERSION_HEX < 0x030d0000
 #define Py_IsFinalizing _Py_IsFinalizing
 #endif
 
-class PyDenseResourceElementsAttribute
-    : public PyConcreteAttribute<PyDenseResourceElementsAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction =
-      mlirAttributeIsADenseResourceElements;
-  static constexpr const char *pyClassName = "DenseResourceElementsAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-
-  static PyDenseResourceElementsAttribute
-  getFromBuffer(const nb_buffer &buffer, const std::string &name,
-                const PyType &type, std::optional<size_t> alignment,
-                bool isMutable, DefaultingPyMlirContext contextWrapper) {
-    if (!mlirTypeIsAShaped(type)) {
-      throw std::invalid_argument(
-          "Constructing a DenseResourceElementsAttr requires a ShapedType.");
-    }
-
-    // Do not request any conversions as we must ensure to use caller
-    // managed memory.
-    int flags = PyBUF_STRIDES;
-    std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>();
-    if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
-      throw nb::python_error();
-    }
+PyDenseResourceElementsAttribute
+PyDenseResourceElementsAttribute::getFromBuffer(
+    const nb_buffer &buffer, const std::string &name, const PyType &type,
+    std::optional<size_t> alignment, bool isMutable,
+    DefaultingPyMlirContext contextWrapper) {
+  if (!mlirTypeIsAShaped(type)) {
+    throw std::invalid_argument(
+        "Constructing a DenseResourceElementsAttr requires a ShapedType.");
+  }
 
-    // This scope releaser will only release if we haven't yet transferred
-    // ownership.
-    auto freeBuffer = llvm::make_scope_exit([&]() {
-      if (view)
-        PyBuffer_Release(view.get());
-    });
+  // Do not request any conversions as we must ensure to use caller
+  // managed memory.
+  int flags = PyBUF_STRIDES;
+  std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>();
+  if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
+    throw nb::python_error();
+  }
 
-    if (!PyBuffer_IsContiguous(view.get(), 'A')) {
-      throw std::invalid_argument("Contiguous buffer is required.");
-    }
+  // This scope releaser will only release if we haven't yet transferred
+  // ownership.
+  auto freeBuffer = llvm::make_scope_exit([&]() {
+    if (view)
+      PyBuffer_Release(view.get());
+  });
 
-    // Infer alignment to be the stride of one element if not explicit.
-    size_t inferredAlignment;
-    if (alignment)
-      inferredAlignment = *alignment;
-    else
-      inferredAlignment = view->strides[view->ndim - 1];
-
-    // The userData is a Py_buffer* that the deleter owns.
-    auto deleter = [](void *userData, const void *data, size_t size,
-                      size_t align) {
-      if (Py_IsFinalizing())
-        return;
-      assert(Py_IsInitialized() && "expected interpreter to be initialized");
-      Py_buffer *ownedView = static_cast<Py_buffer *>(userData);
-      nb::gil_scoped_acquire gil;
-      PyBuffer_Release(ownedView);
-      delete ownedView;
-    };
-
-    size_t rawBufferSize = view->len;
-    MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
-        type, toMlirStringRef(name), view->buf, rawBufferSize,
-        inferredAlignment, isMutable, deleter, static_cast<void *>(view.get()));
-    if (mlirAttributeIsNull(attr)) {
-      throw std::invalid_argument(
-          "DenseResourceElementsAttr could not be constructed from the given "
-          "buffer. "
-          "This may mean that the Python buffer layout does not match that "
-          "MLIR expected layout and is a bug.");
-    }
-    view.release();
-    return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
+  if (!PyBuffer_IsContiguous(view.get(), 'A')) {
+    throw std::invalid_argument("Contiguous buffer is required.");
   }
 
-  static void bindDerived(ClassTy &c) {
-    c.def_static("get_from_buffer",
-                 PyDenseResourceElementsAttribute::getFromBuffer,
-                 nb::arg("array"), nb::arg("name"), nb::arg("type"),
-                 nb::arg("alignment") = nb::none(),
-                 nb::arg("is_mutable") = false, nb::arg("context") = nb::none(),
-                 // clang-format off
-                 nb::sig("def get_from_buffer(array: typing_extensions.Buffer, name: str, type: Type, alignment: int | None = None, is_mutable: bool = False, context: Context | None = None) -> DenseResourceElementsAttr"),
-                 // clang-format on
-                 kDenseResourceElementsAttrGetFromBufferDocstring);
+  // Default alignment: innermost stride, or item size for a 0-d buffer.
+  size_t inferredAlignment = view->itemsize; // 0-d: `strides` is NULL.
+  if (alignment)
+    inferredAlignment = *alignment;
+  else if (view->ndim > 0)
+    inferredAlignment = view->strides[view->ndim - 1];
+
+  // The userData is a Py_buffer* that the deleter owns.
+  auto deleter = [](void *userData, const void *data, size_t size,
+                    size_t align) {
+    if (Py_IsFinalizing())
+      return;
+    assert(Py_IsInitialized() && "expected interpreter to be initialized");
+    Py_buffer *ownedView = static_cast<Py_buffer *>(userData);
+    nb::gil_scoped_acquire gil;
+    PyBuffer_Release(ownedView);
+    delete ownedView;
+  };
+
+  size_t rawBufferSize = view->len;
+  MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
+      type, toMlirStringRef(name), view->buf, rawBufferSize, inferredAlignment,
+      isMutable, deleter, static_cast<void *>(view.get()));
+  if (mlirAttributeIsNull(attr)) {
+    throw std::invalid_argument(
+        "DenseResourceElementsAttr could not be constructed from the given "
+        "buffer. "
+        "This may mean that the Python buffer layout does not match that "
+        "MLIR expected layout and is a bug.");
   }
-};
+  view.release();
+  return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
+}
 
-class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
-  static constexpr const char *pyClassName = "DictAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirDictionaryAttrGetTypeID;
+void PyDenseResourceElementsAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get_from_buffer", PyDenseResourceElementsAttribute::getFromBuffer,
+      nb::arg("array"), nb::arg("name"), nb::arg("type"),
+      nb::arg("alignment") = nb::none(), nb::arg("is_mutable") = false,
+      nb::arg("context") = nb::none(),
+      // clang-format off
+      nb::sig("def get_from_buffer(array: typing_extensions.Buffer, name: str, type: Type, alignment: int | None = None, is_mutable: bool = False, context: Context | None = None) -> DenseResourceElementsAttr"),
+      // clang-format on
+      kDenseResourceElementsAttrGetFromBufferDocstring);
+}
 
-  intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
+intptr_t PyDictAttribute::dunderLen() const {
+  return mlirDictionaryAttrGetNumElements(*this);
+}
 
-  bool dunderContains(const std::string &name) {
-    return !mlirAttributeIsNull(
-        mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
-  }
+bool PyDictAttribute::dunderContains(const std::string &name) const {
+  return !mlirAttributeIsNull(
+      mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
+}
 
-  static void bindDerived(ClassTy &c) {
-    c.def("__contains__", &PyDictAttribute::dunderContains);
-    c.def("__len__", &PyDictAttribute::dunderLen);
-    c.def_static(
-        "get",
-        [](const nb::dict &attributes, DefaultingPyMlirContext context) {
-          SmallVector<MlirNamedAttribute> mlirNamedAttributes;
-          mlirNamedAttributes.reserve(attributes.size());
-          for (std::pair<nb::handle, nb::handle> it : attributes) {
-            auto &mlirAttr = nb::cast<PyAttribute &>(it.second);
-            auto name = nb::cast<std::string>(it.first);
-            mlirNamedAttributes.push_back(mlirNamedAttributeGet(
-                mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
-                                  toMlirStringRef(name)),
-                mlirAttr));
-          }
+void PyDictAttribute::bindDerived(ClassTy &c) {
+  c.def("__contains__", &PyDictAttribute::dunderContains);
+  c.def("__len__", &PyDictAttribute::dunderLen);
+  c.def_static(
+      "get",
+      [](const nb::dict &attributes, DefaultingPyMlirContext context) {
+        SmallVector<MlirNamedAttribute> mlirNamedAttributes;
+        mlirNamedAttributes.reserve(attributes.size());
+        for (std::pair<nb::handle, nb::handle> it : attributes) {
+          auto &mlirAttr = nb::cast<PyAttribute &>(it.second);
+          auto name = nb::cast<std::string>(it.first);
+          mlirNamedAttributes.push_back(mlirNamedAttributeGet(
+              mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
+                                toMlirStringRef(name)),
+              mlirAttr));
+        }
+        MlirAttribute attr =
+            mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
+                                  mlirNamedAttributes.data());
+        return PyDictAttribute(context->getRef(), attr);
+      },
+      nb::arg("value") = nb::dict(), nb::arg("context") = nb::none(),
+      "Gets an uniqued dict attribute");
+  c.def("__getitem__",
+        [](PyDictAttribute &self,
+           const std::string &name) -> nb::typed<nb::object, PyAttribute> {
           MlirAttribute attr =
-              mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
-                                    mlirNamedAttributes.data());
-          return PyDictAttribute(context->getRef(), attr);
-        },
-        nb::arg("value") = nb::dict(), nb::arg("context") = nb::none(),
-        "Gets an uniqued dict attribute");
-    c.def("__getitem__",
-          [](PyDictAttribute &self,
-             const std::string &name) -> nb::typed<nb::object, PyAttribute> {
-            MlirAttribute attr =
-                mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
-            if (mlirAttributeIsNull(attr))
-              throw nb::key_error("attempt to access a non-existent attribute");
-            return PyAttribute(self.getContext(), attr).maybeDownCast();
-          });
-    c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
-      if (index < 0 || index >= self.dunderLen()) {
-        throw nb::index_error("attempt to access out of bounds attribute");
-      }
-      MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
-      return PyNamedAttribute(
-          namedAttr.attribute,
-          std::string(mlirIdentifierStr(namedAttr.name).data));
-    });
-  }
-};
-
-/// Refinement of PyDenseElementsAttribute for attributes containing
-/// floating-point values. Supports element access.
-class PyDenseFPElementsAttribute
-    : public PyConcreteAttribute<PyDenseFPElementsAttribute,
-                                 PyDenseElementsAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
-  static constexpr const char *pyClassName = "DenseFPElementsAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-
-  nb::float_ dunderGetItem(intptr_t pos) {
-    if (pos < 0 || pos >= dunderLen()) {
-      throw nb::index_error("attempt to access out of bounds element");
+              mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
+          if (mlirAttributeIsNull(attr))
+            throw nb::key_error("attempt to access a non-existent attribute");
+          return PyAttribute(self.getContext(), attr).maybeDownCast();
+        });
+  c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
+    if (index < 0 || index >= self.dunderLen()) {
+      throw nb::index_error("attempt to access out of bounds attribute");
     }
+    // MlirStringRef is sized and may lack a NUL terminator; honor `length`.
+    MlirNamedAttribute entry = mlirDictionaryAttrGetElement(self, index);
+    MlirStringRef id = mlirIdentifierStr(entry.name);
+    return PyNamedAttribute(entry.attribute, std::string(id.data, id.length));
+  });
+}
 
-    MlirType type = mlirAttributeGetType(*this);
-    type = mlirShapedTypeGetElementType(type);
-    // Dispatch element extraction to an appropriate C function based on the
-    // elemental type of the attribute. nb::float_ is implicitly constructible
-    // from float and double.
-    // TODO: consider caching the type properties in the constructor to avoid
-    // querying them on each element access.
-    if (mlirTypeIsAF32(type)) {
-      return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos));
-    }
-    if (mlirTypeIsAF64(type)) {
-      return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos));
-    }
-    throw nb::type_error("Unsupported floating-point type");
+nb::float_ PyDenseFPElementsAttribute::dunderGetItem(intptr_t pos) const {
+  if (pos < 0 || pos >= dunderLen()) {
+    throw nb::index_error("attempt to access out of bounds element");
   }
 
-  static void bindDerived(ClassTy &c) {
-    c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
+  MlirType type = mlirAttributeGetType(*this);
+  type = mlirShapedTypeGetElementType(type);
+  // Dispatch element extraction to an appropriate C function based on the
+  // elemental type of the attribute. nb::float_ is implicitly
+  // constructible from float and double.
+  // TODO: consider caching the type properties in the constructor to avoid
+  // querying them on each element access.
+  if (mlirTypeIsAF32(type)) {
+    return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos));
   }
-};
-
-class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
-  static constexpr const char *pyClassName = "TypeAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirTypeAttrGetTypeID;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](const PyType &value, DefaultingPyMlirContext context) {
-          MlirAttribute attr = mlirTypeAttrGet(value.get());
-          return PyTypeAttribute(context->getRef(), attr);
-        },
-        nb::arg("value"), nb::arg("context") = nb::none(),
-        "Gets a uniqued Type attribute");
-    c.def_prop_ro(
-        "value", [](PyTypeAttribute &self) -> nb::typed<nb::object, PyType> {
-          return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
-              .maybeDownCast();
-        });
+  if (mlirTypeIsAF64(type)) {
+    return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos));
   }
-};
+  throw nb::type_error("Unsupported floating-point type");
+}
 
-/// Unit Attribute subclass. Unit attributes don't have values.
-class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
-  static constexpr const char *pyClassName = "UnitAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirUnitAttrGetTypeID;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          return PyUnitAttribute(context->getRef(),
-                                 mlirUnitAttrGet(context->get()));
-        },
-        nb::arg("context") = nb::none(), "Create a Unit attribute.");
-  }
-};
+void PyDenseFPElementsAttribute::bindDerived(ClassTy &c) {
+  c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
+}
 
-/// Strided layout attribute subclass.
-class PyStridedLayoutAttribute
-    : public PyConcreteAttribute<PyStridedLayoutAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
-  static constexpr const char *pyClassName = "StridedLayoutAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirStridedLayoutAttrGetTypeID;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](int64_t offset, const std::vector<int64_t> &strides,
-           DefaultingPyMlirContext ctx) {
-          MlirAttribute attr = mlirStridedLayoutAttrGet(
-              ctx->get(), offset, strides.size(), strides.data());
-          return PyStridedLayoutAttribute(ctx->getRef(), attr);
-        },
-        nb::arg("offset"), nb::arg("strides"), nb::arg("context") = nb::none(),
-        "Gets a strided layout attribute.");
-    c.def_static(
-        "get_fully_dynamic",
-        [](int64_t rank, DefaultingPyMlirContext ctx) {
-          auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
-          std::vector<int64_t> strides(rank);
-          llvm::fill(strides, dynamic);
-          MlirAttribute attr = mlirStridedLayoutAttrGet(
-              ctx->get(), dynamic, strides.size(), strides.data());
-          return PyStridedLayoutAttribute(ctx->getRef(), attr);
-        },
-        nb::arg("rank"), nb::arg("context") = nb::none(),
-        "Gets a strided layout attribute with dynamic offset and strides of "
-        "a "
-        "given rank.");
-    c.def_prop_ro(
-        "offset",
-        [](PyStridedLayoutAttribute &self) {
-          return mlirStridedLayoutAttrGetOffset(self);
-        },
-        "Returns the value of the float point attribute");
-    c.def_prop_ro(
-        "strides",
-        [](PyStridedLayoutAttribute &self) {
-          intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
-          std::vector<int64_t> strides(size);
-          for (intptr_t i = 0; i < size; i++) {
-            strides[i] = mlirStridedLayoutAttrGetStride(self, i);
-          }
-          return strides;
-        },
-        "Returns the value of the float point attribute");
-  }
-};
+// Binds the static `get` factory and the read-only `value` property.
+void PyTypeAttribute::bindDerived(ClassTy &c) {
+  auto fromType = [](const PyType &value, DefaultingPyMlirContext context) {
+    return PyTypeAttribute(context->getRef(), mlirTypeAttrGet(value.get()));
+  };
+  c.def_static("get", fromType, nb::arg("value"),
+               nb::arg("context") = nb::none(),
+               "Gets a uniqued Type attribute");
+
+  auto heldType = [](PyTypeAttribute &self) -> nb::typed<nb::object, PyType> {
+    MlirType wrapped = mlirTypeAttrGetValue(self.get());
+    return PyType(self.getContext(), wrapped).maybeDownCast();
+  };
+  c.def_prop_ro("value", heldType);
+}
+
+// Binds the static `get` factory for unit attributes.
+void PyUnitAttribute::bindDerived(ClassTy &c) {
+  auto makeUnit = [](DefaultingPyMlirContext context) {
+    MlirAttribute unit = mlirUnitAttrGet(context->get());
+    return PyUnitAttribute(context->getRef(), unit);
+  };
+  c.def_static("get", makeUnit, nb::arg("context") = nb::none(),
+               "Create a Unit attribute.");
+}
+
+// Binds the strided-layout factories and the read-only offset/strides.
+void PyStridedLayoutAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](int64_t offset, const std::vector<int64_t> &strides,
+         DefaultingPyMlirContext ctx) {
+        MlirAttribute attr = mlirStridedLayoutAttrGet(
+            ctx->get(), offset, strides.size(), strides.data());
+        return PyStridedLayoutAttribute(ctx->getRef(), attr);
+      },
+      nb::arg("offset"), nb::arg("strides"), nb::arg("context") = nb::none(),
+      "Gets a strided layout attribute.");
+  c.def_static(
+      "get_fully_dynamic",
+      [](int64_t rank, DefaultingPyMlirContext ctx) {
+        auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
+        std::vector<int64_t> strides(rank);
+        llvm::fill(strides, dynamic);
+        MlirAttribute attr = mlirStridedLayoutAttrGet(
+            ctx->get(), dynamic, strides.size(), strides.data());
+        return PyStridedLayoutAttribute(ctx->getRef(), attr);
+      },
+      nb::arg("rank"), nb::arg("context") = nb::none(),
+      "Gets a strided layout attribute with dynamic offset and strides of a "
+      "given rank.");
+  c.def_prop_ro(
+      "offset",
+      [](PyStridedLayoutAttribute &self) {
+        return mlirStridedLayoutAttrGetOffset(self);
+      },
+      "Returns the offset of the strided layout attribute");
+  c.def_prop_ro(
+      "strides",
+      [](PyStridedLayoutAttribute &self) {
+        intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
+        std::vector<int64_t> strides(size);
+        for (intptr_t i = 0; i < size; i++) {
+          strides[i] = mlirStridedLayoutAttrGetStride(self, i);
+        }
+        return strides;
+      },
+      "Returns the strides of the strided layout attribute");
+}
 
 nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
   if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute))
@@ -1747,10 +1309,6 @@ nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
   throw nb::type_error(msg.c_str());
 }
 
-} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
-} // namespace python
-} // namespace mlir
-
 void PyStringAttribute::bindDerived(ClassTy &c) {
   c.def_static(
       "get",
@@ -1795,9 +1353,6 @@ void PyStringAttribute::bindDerived(ClassTy &c) {
       "Returns the value of the string attribute as `bytes`");
 }
 
-namespace mlir {
-namespace python {
-namespace MLIR_BINDINGS_PYTHON_DOMAIN {
 void populateIRAttributes(nb::module_ &m) {
   PyAffineMapAttribute::bind(m);
   PyDenseBoolArrayAttribute::bind(m);
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 7350046f428c7..ca56fc3248ed8 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -29,490 +29,269 @@ namespace mlir {
 namespace python {
 namespace MLIR_BINDINGS_PYTHON_DOMAIN {
 
-/// Checks whether the given type is an integer or float type.
-static int mlirTypeIsAIntegerOrFloat(MlirType type) {
+int mlirTypeIsAIntegerOrFloat(MlirType type) {
   return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
          mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
 }
 
-class PyIntegerType : public PyConcreteType<PyIntegerType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirIntegerTypeGetTypeID;
-  static constexpr const char *pyClassName = "IntegerType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get_signless",
-        [](unsigned width, DefaultingPyMlirContext context) {
-          MlirType t = mlirIntegerTypeGet(context->get(), width);
-          return PyIntegerType(context->getRef(), t);
-        },
-        nb::arg("width"), nb::arg("context") = nb::none(),
-        "Create a signless integer type");
-    c.def_static(
-        "get_signed",
-        [](unsigned width, DefaultingPyMlirContext context) {
-          MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
-          return PyIntegerType(context->getRef(), t);
-        },
-        nb::arg("width"), nb::arg("context") = nb::none(),
-        "Create a signed integer type");
-    c.def_static(
-        "get_unsigned",
-        [](unsigned width, DefaultingPyMlirContext context) {
-          MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
-          return PyIntegerType(context->getRef(), t);
-        },
-        nb::arg("width"), nb::arg("context") = nb::none(),
-        "Create an unsigned integer type");
-    c.def_prop_ro(
-        "width",
-        [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
-        "Returns the width of the integer type");
-    c.def_prop_ro(
-        "is_signless",
-        [](PyIntegerType &self) -> bool {
-          return mlirIntegerTypeIsSignless(self);
-        },
-        "Returns whether this is a signless integer");
-    c.def_prop_ro(
-        "is_signed",
-        [](PyIntegerType &self) -> bool {
-          return mlirIntegerTypeIsSigned(self);
-        },
-        "Returns whether this is a signed integer");
-    c.def_prop_ro(
-        "is_unsigned",
-        [](PyIntegerType &self) -> bool {
-          return mlirIntegerTypeIsUnsigned(self);
-        },
-        "Returns whether this is an unsigned integer");
-  }
-};
-
-/// Index Type subclass - IndexType.
-class PyIndexType : public PyConcreteType<PyIndexType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirIndexTypeGetTypeID;
-  static constexpr const char *pyClassName = "IndexType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirIndexTypeGet(context->get());
-          return PyIndexType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a index type.");
-  }
-};
-
-class PyFloatType : public PyConcreteType<PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
-  static constexpr const char *pyClassName = "FloatType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_prop_ro(
-        "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
-        "Returns the width of the floating-point type");
-  }
-};
-
-/// Floating Point Type subclass - Float4E2M1FNType.
-class PyFloat4E2M1FNType
-    : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat4E2M1FNTypeGetTypeID;
-  static constexpr const char *pyClassName = "Float4E2M1FNType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat4E2M1FNTypeGet(context->get());
-          return PyFloat4E2M1FNType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float4_e2m1fn type.");
-  }
-};
-
-/// Floating Point Type subclass - Float6E2M3FNType.
-class PyFloat6E2M3FNType
-    : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat6E2M3FNTypeGetTypeID;
-  static constexpr const char *pyClassName = "Float6E2M3FNType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
-          return PyFloat6E2M3FNType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float6_e2m3fn type.");
-  }
-};
-
-/// Floating Point Type subclass - Float6E3M2FNType.
-class PyFloat6E3M2FNType
-    : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat6E3M2FNTypeGetTypeID;
-  static constexpr const char *pyClassName = "Float6E3M2FNType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
-          return PyFloat6E3M2FNType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float6_e3m2fn type.");
-  }
-};
-
-/// Floating Point Type subclass - Float8E4M3FNType.
-class PyFloat8E4M3FNType
-    : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat8E4M3FNTypeGetTypeID;
-  static constexpr const char *pyClassName = "Float8E4M3FNType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
-          return PyFloat8E4M3FNType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float8_e4m3fn type.");
-  }
-};
-
-/// Floating Point Type subclass - Float8E5M2Type.
-class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat8E5M2TypeGetTypeID;
-  static constexpr const char *pyClassName = "Float8E5M2Type";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat8E5M2TypeGet(context->get());
-          return PyFloat8E5M2Type(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float8_e5m2 type.");
-  }
-};
-
-/// Floating Point Type subclass - Float8E4M3Type.
-class PyFloat8E4M3Type : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat8E4M3TypeGetTypeID;
-  static constexpr const char *pyClassName = "Float8E4M3Type";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat8E4M3TypeGet(context->get());
-          return PyFloat8E4M3Type(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float8_e4m3 type.");
-  }
-};
-
-/// Floating Point Type subclass - Float8E4M3FNUZ.
-class PyFloat8E4M3FNUZType
-    : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat8E4M3FNUZTypeGetTypeID;
-  static constexpr const char *pyClassName = "Float8E4M3FNUZType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
-          return PyFloat8E4M3FNUZType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float8_e4m3fnuz type.");
-  }
-};
-
-/// Floating Point Type subclass - Float8E4M3B11FNUZ.
-class PyFloat8E4M3B11FNUZType
-    : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat8E4M3B11FNUZTypeGetTypeID;
-  static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
-          return PyFloat8E4M3B11FNUZType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float8_e4m3b11fnuz type.");
-  }
-};
-
-/// Floating Point Type subclass - Float8E5M2FNUZ.
-class PyFloat8E5M2FNUZType
-    : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat8E5M2FNUZTypeGetTypeID;
-  static constexpr const char *pyClassName = "Float8E5M2FNUZType";
-  using PyConcreteType::PyConcreteType;
+/// Binds the IntegerType factories and read-only properties to Python.
+void PyIntegerType::bindDerived(ClassTy &c) {
+  // Static factories, one per signedness flavor.
+  auto makeSignless = [](unsigned bits, DefaultingPyMlirContext ctx) {
+    MlirType ty = mlirIntegerTypeGet(ctx->get(), bits);
+    return PyIntegerType(ctx->getRef(), ty);
+  };
+  auto makeSigned = [](unsigned bits, DefaultingPyMlirContext ctx) {
+    MlirType ty = mlirIntegerTypeSignedGet(ctx->get(), bits);
+    return PyIntegerType(ctx->getRef(), ty);
+  };
+  auto makeUnsigned = [](unsigned bits, DefaultingPyMlirContext ctx) {
+    MlirType ty = mlirIntegerTypeUnsignedGet(ctx->get(), bits);
+    return PyIntegerType(ctx->getRef(), ty);
+  };
+  c.def_static("get_signless", makeSignless, nb::arg("width"),
+               nb::arg("context") = nb::none(),
+               "Create a signless integer type");
+  c.def_static("get_signed", makeSigned, nb::arg("width"),
+               nb::arg("context") = nb::none(),
+               "Create a signed integer type");
+  c.def_static("get_unsigned", makeUnsigned, nb::arg("width"),
+               nb::arg("context") = nb::none(),
+               "Create an unsigned integer type");
+
+  // Read-only properties: the bit width plus three signedness predicates.
+  auto bitWidth = [](PyIntegerType &ty) {
+    return mlirIntegerTypeGetWidth(ty);
+  };
+  auto isSignless = [](PyIntegerType &ty) -> bool {
+    return mlirIntegerTypeIsSignless(ty);
+  };
+  auto isSigned = [](PyIntegerType &ty) -> bool {
+    return mlirIntegerTypeIsSigned(ty);
+  };
+  auto isUnsigned = [](PyIntegerType &ty) -> bool {
+    return mlirIntegerTypeIsUnsigned(ty);
+  };
+  c.def_prop_ro("width", bitWidth, "Returns the width of the integer type");
+  c.def_prop_ro("is_signless", isSignless,
+                "Returns whether this is a signless integer");
+  c.def_prop_ro("is_signed", isSigned,
+                "Returns whether this is a signed integer");
+  c.def_prop_ro("is_unsigned", isUnsigned,
+                "Returns whether this is an unsigned integer");
+}
 
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
-          return PyFloat8E5M2FNUZType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float8_e5m2fnuz type.");
-  }
-};
+/// Registers the static `get` factory that creates an index type.
+void PyIndexType::bindDerived(ClassTy &c) {
+  auto create = [](DefaultingPyMlirContext ctx) {
+    MlirType handle = mlirIndexTypeGet(ctx->get());
+    return PyIndexType(ctx->getRef(), handle);
+  };
+  c.def_static("get", create, nb::arg("context") = nb::none(),
+               "Create a index type.");
+}
 
-/// Floating Point Type subclass - Float8E3M4Type.
-class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat8E3M4TypeGetTypeID;
-  static constexpr const char *pyClassName = "Float8E3M4Type";
-  using PyConcreteType::PyConcreteType;
+/// Adds the read-only `width` property to the bound FloatType class.
+void PyFloatType::bindDerived(ClassTy &c) {
+  auto width = [](PyFloatType &ty) { return mlirFloatTypeGetWidth(ty); };
+  c.def_prop_ro("width", width, "Returns the width of the floating-point type");
+}
 
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat8E3M4TypeGet(context->get());
-          return PyFloat8E3M4Type(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float8_e3m4 type.");
-  }
-};
+/// Registers the static `get` factory that creates a float4_e2m1fn type.
+void PyFloat4E2M1FNType::bindDerived(ClassTy &c) {
+  auto create = [](DefaultingPyMlirContext ctx) {
+    MlirType handle = mlirFloat4E2M1FNTypeGet(ctx->get());
+    return PyFloat4E2M1FNType(ctx->getRef(), handle);
+  };
+  c.def_static("get", create, nb::arg("context") = nb::none(),
+               "Create a float4_e2m1fn type.");
+}
 
-/// Floating Point Type subclass - Float8E8M0FNUType.
-class PyFloat8E8M0FNUType
-    : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat8E8M0FNUTypeGetTypeID;
-  static constexpr const char *pyClassName = "Float8E8M0FNUType";
-  using PyConcreteType::PyConcreteType;
+/// Registers the static `get` factory that creates a float6_e2m3fn type.
+void PyFloat6E2M3FNType::bindDerived(ClassTy &c) {
+  auto create = [](DefaultingPyMlirContext ctx) {
+    MlirType handle = mlirFloat6E2M3FNTypeGet(ctx->get());
+    return PyFloat6E2M3FNType(ctx->getRef(), handle);
+  };
+  c.def_static("get", create, nb::arg("context") = nb::none(),
+               "Create a float6_e2m3fn type.");
+}
 
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
-          return PyFloat8E8M0FNUType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float8_e8m0fnu type.");
-  }
-};
+/// Registers the static `get` factory that creates a float6_e3m2fn type.
+void PyFloat6E3M2FNType::bindDerived(ClassTy &c) {
+  auto create = [](DefaultingPyMlirContext ctx) {
+    MlirType handle = mlirFloat6E3M2FNTypeGet(ctx->get());
+    return PyFloat6E3M2FNType(ctx->getRef(), handle);
+  };
+  c.def_static("get", create, nb::arg("context") = nb::none(),
+               "Create a float6_e3m2fn type.");
+}
 
-/// Floating Point Type subclass - BF16Type.
-class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirBFloat16TypeGetTypeID;
-  static constexpr const char *pyClassName = "BF16Type";
-  using PyConcreteType::PyConcreteType;
+/// Registers the static `get` factory that creates a float8_e4m3fn type.
+void PyFloat8E4M3FNType::bindDerived(ClassTy &c) {
+  auto create = [](DefaultingPyMlirContext ctx) {
+    MlirType handle = mlirFloat8E4M3FNTypeGet(ctx->get());
+    return PyFloat8E4M3FNType(ctx->getRef(), handle);
+  };
+  c.def_static("get", create, nb::arg("context") = nb::none(),
+               "Create a float8_e4m3fn type.");
+}
 
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirBF16TypeGet(context->get());
-          return PyBF16Type(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a bf16 type.");
-  }
-};
+/// Registers the static `get` factory that creates a float8_e5m2 type.
+void PyFloat8E5M2Type::bindDerived(ClassTy &c) {
+  auto create = [](DefaultingPyMlirContext ctx) {
+    MlirType handle = mlirFloat8E5M2TypeGet(ctx->get());
+    return PyFloat8E5M2Type(ctx->getRef(), handle);
+  };
+  c.def_static("get", create, nb::arg("context") = nb::none(),
+               "Create a float8_e5m2 type.");
+}
 
-/// Floating Point Type subclass - F16Type.
-class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat16TypeGetTypeID;
-  static constexpr const char *pyClassName = "F16Type";
-  using PyConcreteType::PyConcreteType;
+/// Registers the static `get` factory that creates a float8_e4m3 type.
+void PyFloat8E4M3Type::bindDerived(ClassTy &c) {
+  auto create = [](DefaultingPyMlirContext ctx) {
+    MlirType handle = mlirFloat8E4M3TypeGet(ctx->get());
+    return PyFloat8E4M3Type(ctx->getRef(), handle);
+  };
+  c.def_static("get", create, nb::arg("context") = nb::none(),
+               "Create a float8_e4m3 type.");
+}
 
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirF16TypeGet(context->get());
-          return PyF16Type(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a f16 type.");
-  }
-};
+/// Registers the static `get` factory that creates a float8_e4m3fnuz type.
+void PyFloat8E4M3FNUZType::bindDerived(ClassTy &c) {
+  auto create = [](DefaultingPyMlirContext ctx) {
+    MlirType handle = mlirFloat8E4M3FNUZTypeGet(ctx->get());
+    return PyFloat8E4M3FNUZType(ctx->getRef(), handle);
+  };
+  c.def_static("get", create, nb::arg("context") = nb::none(),
+               "Create a float8_e4m3fnuz type.");
+}
 
-/// Floating Point Type subclass - TF32Type.
-class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloatTF32TypeGetTypeID;
-  static constexpr const char *pyClassName = "FloatTF32Type";
-  using PyConcreteType::PyConcreteType;
+/// Registers the static `get` factory that creates a float8_e4m3b11fnuz type.
+void PyFloat8E4M3B11FNUZType::bindDerived(ClassTy &c) {
+  auto create = [](DefaultingPyMlirContext ctx) {
+    MlirType handle = mlirFloat8E4M3B11FNUZTypeGet(ctx->get());
+    return PyFloat8E4M3B11FNUZType(ctx->getRef(), handle);
+  };
+  c.def_static("get", create, nb::arg("context") = nb::none(),
+               "Create a float8_e4m3b11fnuz type.");
+}
 
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirTF32TypeGet(context->get());
-          return PyTF32Type(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a tf32 type.");
-  }
-};
+/// Registers the static `get` factory that creates a float8_e5m2fnuz type.
+void PyFloat8E5M2FNUZType::bindDerived(ClassTy &c) {
+  auto create = [](DefaultingPyMlirContext ctx) {
+    MlirType handle = mlirFloat8E5M2FNUZTypeGet(ctx->get());
+    return PyFloat8E5M2FNUZType(ctx->getRef(), handle);
+  };
+  c.def_static("get", create, nb::arg("context") = nb::none(),
+               "Create a float8_e5m2fnuz type.");
+}
 
-/// Floating Point Type subclass - F32Type.
-class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat32TypeGetTypeID;
-  static constexpr const char *pyClassName = "F32Type";
-  using PyConcreteType::PyConcreteType;
+/// Registers the static `get` factory that creates a float8_e3m4 type.
+void PyFloat8E3M4Type::bindDerived(ClassTy &c) {
+  auto create = [](DefaultingPyMlirContext ctx) {
+    MlirType handle = mlirFloat8E3M4TypeGet(ctx->get());
+    return PyFloat8E3M4Type(ctx->getRef(), handle);
+  };
+  c.def_static("get", create, nb::arg("context") = nb::none(),
+               "Create a float8_e3m4 type.");
+}
 
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirF32TypeGet(context->get());
-          return PyF32Type(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a f32 type.");
-  }
-};
+/// Registers the static `get` factory that creates a float8_e8m0fnu type.
+void PyFloat8E8M0FNUType::bindDerived(ClassTy &c) {
+  auto create = [](DefaultingPyMlirContext ctx) {
+    MlirType handle = mlirFloat8E8M0FNUTypeGet(ctx->get());
+    return PyFloat8E8M0FNUType(ctx->getRef(), handle);
+  };
+  c.def_static("get", create, nb::arg("context") = nb::none(),
+               "Create a float8_e8m0fnu type.");
+}
 
-/// Floating Point Type subclass - F64Type.
-class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat64TypeGetTypeID;
-  static constexpr const char *pyClassName = "F64Type";
-  using PyConcreteType::PyConcreteType;
+/// Registers the static `get` factory that creates a bf16 type.
+void PyBF16Type::bindDerived(ClassTy &c) {
+  auto create = [](DefaultingPyMlirContext ctx) {
+    MlirType handle = mlirBF16TypeGet(ctx->get());
+    return PyBF16Type(ctx->getRef(), handle);
+  };
+  c.def_static("get", create, nb::arg("context") = nb::none(),
+               "Create a bf16 type.");
+}
 
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirF64TypeGet(context->get());
-          return PyF64Type(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a f64 type.");
-  }
-};
+/// Registers the static `get` factory that creates an f16 type.
+void PyF16Type::bindDerived(ClassTy &c) {
+  auto create = [](DefaultingPyMlirContext ctx) {
+    MlirType handle = mlirF16TypeGet(ctx->get());
+    return PyF16Type(ctx->getRef(), handle);
+  };
+  c.def_static("get", create, nb::arg("context") = nb::none(),
+               "Create a f16 type.");
+}
 
-/// None Type subclass - NoneType.
-class PyNoneType : public PyConcreteType<PyNoneType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirNoneTypeGetTypeID;
-  static constexpr const char *pyClassName = "NoneType";
-  using PyConcreteType::PyConcreteType;
+/// Registers the static `get` factory that creates a tf32 type.
+void PyTF32Type::bindDerived(ClassTy &c) {
+  auto create = [](DefaultingPyMlirContext ctx) {
+    MlirType handle = mlirTF32TypeGet(ctx->get());
+    return PyTF32Type(ctx->getRef(), handle);
+  };
+  c.def_static("get", create, nb::arg("context") = nb::none(),
+               "Create a tf32 type.");
+}
 
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirNoneTypeGet(context->get());
-          return PyNoneType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a none type.");
-  }
-};
+/// Registers the static `get` factory that creates an f32 type.
+void PyF32Type::bindDerived(ClassTy &c) {
+  auto create = [](DefaultingPyMlirContext ctx) {
+    MlirType handle = mlirF32TypeGet(ctx->get());
+    return PyF32Type(ctx->getRef(), handle);
+  };
+  c.def_static("get", create, nb::arg("context") = nb::none(),
+               "Create a f32 type.");
+}
 
-/// Complex Type subclass - ComplexType.
-class PyComplexType : public PyConcreteType<PyComplexType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirComplexTypeGetTypeID;
-  static constexpr const char *pyClassName = "ComplexType";
-  using PyConcreteType::PyConcreteType;
+/// Registers the static `get` factory that creates an f64 type.
+void PyF64Type::bindDerived(ClassTy &c) {
+  auto create = [](DefaultingPyMlirContext ctx) {
+    MlirType handle = mlirF64TypeGet(ctx->get());
+    return PyF64Type(ctx->getRef(), handle);
+  };
+  c.def_static("get", create, nb::arg("context") = nb::none(),
+               "Create a f64 type.");
+}
 
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](PyType &elementType) {
-          // The element must be a floating point or integer scalar type.
-          if (mlirTypeIsAIntegerOrFloat(elementType)) {
-            MlirType t = mlirComplexTypeGet(elementType);
-            return PyComplexType(elementType.getContext(), t);
-          }
-          throw nb::value_error(
-              (Twine("invalid '") +
-               nb::cast<std::string>(nb::repr(nb::cast(elementType))) +
-               "' and expected floating point or integer type.")
-                  .str()
-                  .c_str());
-        },
-        "Create a complex type");
-    c.def_prop_ro(
-        "element_type",
-        [](PyComplexType &self) -> nb::typed<nb::object, PyType> {
-          return PyType(self.getContext(), mlirComplexTypeGetElementType(self))
-              .maybeDownCast();
-        },
-        "Returns element type.");
-  }
-};
+/// Registers the static `get` factory that creates a none type.
+void PyNoneType::bindDerived(ClassTy &c) {
+  auto create = [](DefaultingPyMlirContext ctx) {
+    MlirType handle = mlirNoneTypeGet(ctx->get());
+    return PyNoneType(ctx->getRef(), handle);
+  };
+  c.def_static("get", create, nb::arg("context") = nb::none(),
+               "Create a none type.");
+}
 
-} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
-} // namespace python
-} // namespace mlir
+/// Binds the ComplexType `get` factory and `element_type` property to Python.
+void PyComplexType::bindDerived(ClassTy &c) {
+  auto create = [](PyType &elementType) {
+    // Only integer and floating point scalars are accepted as elements.
+    if (!mlirTypeIsAIntegerOrFloat(elementType)) {
+      std::string repr =
+          nb::cast<std::string>(nb::repr(nb::cast(elementType)));
+      std::string message =
+          (Twine("invalid '") + repr +
+           "' and expected floating point or integer type.")
+              .str();
+      throw nb::value_error(message.c_str());
+    }
+    MlirType complexTy = mlirComplexTypeGet(elementType);
+    return PyComplexType(elementType.getContext(), complexTy);
+  };
+  c.def_static("get", create, "Create a complex type");
+  // Wraps the element type in a PyType and returns its maybeDownCast().
+  auto elementTypeOf =
+      [](PyComplexType &self) -> nb::typed<nb::object, PyType> {
+    return PyType(self.getContext(), mlirComplexTypeGetElementType(self))
+        .maybeDownCast();
+  };
+  c.def_prop_ro("element_type", elementTypeOf, "Returns element type.");
+}
 
 // Shaped Type Interface - ShapedType
 void PyShapedType::bindDerived(ClassTy &c) {
@@ -629,526 +408,424 @@ void PyShapedType::requireHasRank() {
 
 const PyShapedType::IsAFunctionTy PyShapedType::isaFunction = mlirTypeIsAShaped;
 
-namespace mlir {
-namespace python {
-namespace MLIR_BINDINGS_PYTHON_DOMAIN {
-
-/// Vector Type subclass - VectorType.
-class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirVectorTypeGetTypeID;
-  static constexpr const char *pyClassName = "VectorType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static("get", &PyVectorType::getChecked, nb::arg("shape"),
-                 nb::arg("element_type"), nb::kw_only(),
-                 nb::arg("scalable") = nb::none(),
-                 nb::arg("scalable_dims") = nb::none(),
-                 nb::arg("loc") = nb::none(), "Create a vector type")
-        .def_static("get_unchecked", &PyVectorType::get, nb::arg("shape"),
-                    nb::arg("element_type"), nb::kw_only(),
-                    nb::arg("scalable") = nb::none(),
-                    nb::arg("scalable_dims") = nb::none(),
-                    nb::arg("context") = nb::none(), "Create a vector type")
-        .def_prop_ro(
-            "scalable",
-            [](MlirType self) { return mlirVectorTypeIsScalable(self); })
-        .def_prop_ro("scalable_dims", [](MlirType self) {
-          std::vector<bool> scalableDims;
-          size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
-          scalableDims.reserve(rank);
-          for (size_t i = 0; i < rank; ++i)
-            scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
-          return scalableDims;
-        });
-  }
-
-private:
-  static PyVectorType
-  getChecked(std::vector<int64_t> shape, PyType &elementType,
-             std::optional<nb::list> scalable,
-             std::optional<std::vector<int64_t>> scalableDims,
-             DefaultingPyLocation loc) {
-    if (scalable && scalableDims) {
-      throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
-                            "are mutually exclusive.");
-    }
-
-    PyMlirContext::ErrorCapture errors(loc->getContext());
-    MlirType type;
-    if (scalable) {
-      if (scalable->size() != shape.size())
-        throw nb::value_error("Expected len(scalable) == len(shape).");
-
-      SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
-          *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
-      type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
-                                              scalableDimFlags.data(),
-                                              elementType);
-    } else if (scalableDims) {
-      SmallVector<bool> scalableDimFlags(shape.size(), false);
-      for (int64_t dim : *scalableDims) {
-        if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
-          throw nb::value_error("Scalable dimension index out of bounds.");
-        scalableDimFlags[dim] = true;
-      }
-      type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
-                                              scalableDimFlags.data(),
-                                              elementType);
-    } else {
-      type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
-                                      elementType);
-    }
-    if (mlirTypeIsNull(type))
-      throw MLIRError("Invalid type", errors.take());
-    return PyVectorType(elementType.getContext(), type);
-  }
+void PyVectorType::bindDerived(ClassTy &c) {
+  c.def_static("get", &PyVectorType::getChecked, nb::arg("shape"),
+               nb::arg("element_type"), nb::kw_only(),
+               nb::arg("scalable") = nb::none(),
+               nb::arg("scalable_dims") = nb::none(),
+               nb::arg("loc") = nb::none(), "Create a vector type")
+      .def_static("get_unchecked", &PyVectorType::get, nb::arg("shape"),
+                  nb::arg("element_type"), nb::kw_only(),
+                  nb::arg("scalable") = nb::none(),
+                  nb::arg("scalable_dims") = nb::none(),
+                  nb::arg("context") = nb::none(), "Create a vector type")
+      .def_prop_ro("scalable",
+                   [](MlirType self) { return mlirVectorTypeIsScalable(self); })
+      .def_prop_ro("scalable_dims", [](MlirType self) {
+        std::vector<bool> scalableDims;
+        size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
+        scalableDims.reserve(rank);
+        for (size_t i = 0; i < rank; ++i)
+          scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
+        return scalableDims;
+      });
+}
 
-  static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
-                          std::optional<nb::list> scalable,
-                          std::optional<std::vector<int64_t>> scalableDims,
-                          DefaultingPyMlirContext context) {
-    if (scalable && scalableDims) {
-      throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
-                            "are mutually exclusive.");
+PyVectorType
+PyVectorType::getChecked(std::vector<int64_t> shape, PyType &elementType,
+                         std::optional<nb::list> scalable,
+                         std::optional<std::vector<int64_t>> scalableDims,
+                         DefaultingPyLocation loc) {
+  if (scalable && scalableDims) {
+    throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
+                          "are mutually exclusive.");
+  }
+
+  PyMlirContext::ErrorCapture errors(loc->getContext());
+  MlirType type;
+  if (scalable) {
+    if (scalable->size() != shape.size())
+      throw nb::value_error("Expected len(scalable) == len(shape).");
+
+    SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
+        *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
+    type = mlirVectorTypeGetScalableChecked(
+        loc, shape.size(), shape.data(), scalableDimFlags.data(), elementType);
+  } else if (scalableDims) {
+    SmallVector<bool> scalableDimFlags(shape.size(), false);
+    for (int64_t dim : *scalableDims) {
+      if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
+        throw nb::value_error("Scalable dimension index out of bounds.");
+      scalableDimFlags[dim] = true;
     }
+    type = mlirVectorTypeGetScalableChecked(
+        loc, shape.size(), shape.data(), scalableDimFlags.data(), elementType);
+  } else {
+    type =
+        mlirVectorTypeGetChecked(loc, shape.size(), shape.data(), elementType);
+  }
+  if (mlirTypeIsNull(type))
+    throw MLIRError("Invalid type", errors.take());
+  return PyVectorType(elementType.getContext(), type);
+}
 
-    PyMlirContext::ErrorCapture errors(context->getRef());
-    MlirType type;
-    if (scalable) {
-      if (scalable->size() != shape.size())
-        throw nb::value_error("Expected len(scalable) == len(shape).");
-
-      SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
-          *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
-      type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
-                                       scalableDimFlags.data(), elementType);
-    } else if (scalableDims) {
-      SmallVector<bool> scalableDimFlags(shape.size(), false);
-      for (int64_t dim : *scalableDims) {
-        if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
-          throw nb::value_error("Scalable dimension index out of bounds.");
-        scalableDimFlags[dim] = true;
-      }
-      type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
-                                       scalableDimFlags.data(), elementType);
-    } else {
-      type = mlirVectorTypeGet(shape.size(), shape.data(), elementType);
+PyVectorType PyVectorType::get(std::vector<int64_t> shape, PyType &elementType,
+                               std::optional<nb::list> scalable,
+                               std::optional<std::vector<int64_t>> scalableDims,
+                               DefaultingPyMlirContext context) {
+  if (scalable && scalableDims) {
+    throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
+                          "are mutually exclusive.");
+  }
+
+  PyMlirContext::ErrorCapture errors(context->getRef());
+  MlirType type;
+  if (scalable) {
+    if (scalable->size() != shape.size())
+      throw nb::value_error("Expected len(scalable) == len(shape).");
+
+    SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
+        *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
+    type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
+                                     scalableDimFlags.data(), elementType);
+  } else if (scalableDims) {
+    SmallVector<bool> scalableDimFlags(shape.size(), false);
+    for (int64_t dim : *scalableDims) {
+      if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
+        throw nb::value_error("Scalable dimension index out of bounds.");
+      scalableDimFlags[dim] = true;
     }
-    if (mlirTypeIsNull(type))
-      throw MLIRError("Invalid type", errors.take());
-    return PyVectorType(elementType.getContext(), type);
-  }
-};
-
-/// Ranked Tensor Type subclass - RankedTensorType.
-class PyRankedTensorType
-    : public PyConcreteType<PyRankedTensorType, PyShapedType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirRankedTensorTypeGetTypeID;
-  static constexpr const char *pyClassName = "RankedTensorType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](std::vector<int64_t> shape, PyType &elementType,
-           std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
-          PyMlirContext::ErrorCapture errors(loc->getContext());
-          MlirType t = mlirRankedTensorTypeGetChecked(
-              loc, shape.size(), shape.data(), elementType,
-              encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
-          if (mlirTypeIsNull(t))
-            throw MLIRError("Invalid type", errors.take());
-          return PyRankedTensorType(elementType.getContext(), t);
-        },
-        nb::arg("shape"), nb::arg("element_type"),
-        nb::arg("encoding") = nb::none(), nb::arg("loc") = nb::none(),
-        "Create a ranked tensor type");
-    c.def_static(
-        "get_unchecked",
-        [](std::vector<int64_t> shape, PyType &elementType,
-           std::optional<PyAttribute> &encodingAttr,
-           DefaultingPyMlirContext context) {
-          PyMlirContext::ErrorCapture errors(context->getRef());
-          MlirType t = mlirRankedTensorTypeGet(
-              shape.size(), shape.data(), elementType,
-              encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
-          if (mlirTypeIsNull(t))
-            throw MLIRError("Invalid type", errors.take());
-          return PyRankedTensorType(elementType.getContext(), t);
-        },
-        nb::arg("shape"), nb::arg("element_type"),
-        nb::arg("encoding") = nb::none(), nb::arg("context") = nb::none(),
-        "Create a ranked tensor type");
-    c.def_prop_ro(
-        "encoding",
-        [](PyRankedTensorType &self)
-            -> std::optional<nb::typed<nb::object, PyAttribute>> {
-          MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
-          if (mlirAttributeIsNull(encoding))
-            return std::nullopt;
-          return PyAttribute(self.getContext(), encoding).maybeDownCast();
-        });
-  }
-};
-
-/// Unranked Tensor Type subclass - UnrankedTensorType.
-class PyUnrankedTensorType
-    : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirUnrankedTensorTypeGetTypeID;
-  static constexpr const char *pyClassName = "UnrankedTensorType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](PyType &elementType, DefaultingPyLocation loc) {
-          PyMlirContext::ErrorCapture errors(loc->getContext());
-          MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
-          if (mlirTypeIsNull(t))
-            throw MLIRError("Invalid type", errors.take());
-          return PyUnrankedTensorType(elementType.getContext(), t);
-        },
-        nb::arg("element_type"), nb::arg("loc") = nb::none(),
-        "Create a unranked tensor type");
-    c.def_static(
-        "get_unchecked",
-        [](PyType &elementType, DefaultingPyMlirContext context) {
-          PyMlirContext::ErrorCapture errors(context->getRef());
-          MlirType t = mlirUnrankedTensorTypeGet(elementType);
-          if (mlirTypeIsNull(t))
-            throw MLIRError("Invalid type", errors.take());
-          return PyUnrankedTensorType(elementType.getContext(), t);
-        },
-        nb::arg("element_type"), nb::arg("context") = nb::none(),
-        "Create a unranked tensor type");
-  }
-};
-
-/// Ranked MemRef Type subclass - MemRefType.
-class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirMemRefTypeGetTypeID;
-  static constexpr const char *pyClassName = "MemRefType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-         "get",
-         [](std::vector<int64_t> shape, PyType &elementType,
-            PyAttribute *layout, PyAttribute *memorySpace,
-            DefaultingPyLocation loc) {
-           PyMlirContext::ErrorCapture errors(loc->getContext());
-           MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull();
-           MlirAttribute memSpaceAttr =
-               memorySpace ? *memorySpace : mlirAttributeGetNull();
-           MlirType t =
-               mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
-                                        shape.data(), layoutAttr, memSpaceAttr);
-           if (mlirTypeIsNull(t))
-             throw MLIRError("Invalid type", errors.take());
-           return PyMemRefType(elementType.getContext(), t);
-         },
-         nb::arg("shape"), nb::arg("element_type"),
-         nb::arg("layout") = nb::none(), nb::arg("memory_space") = nb::none(),
-         nb::arg("loc") = nb::none(), "Create a memref type")
-        .def_static(
-            "get_unchecked",
-            [](std::vector<int64_t> shape, PyType &elementType,
-               PyAttribute *layout, PyAttribute *memorySpace,
-               DefaultingPyMlirContext context) {
-              PyMlirContext::ErrorCapture errors(context->getRef());
-              MlirAttribute layoutAttr =
-                  layout ? *layout : mlirAttributeGetNull();
-              MlirAttribute memSpaceAttr =
-                  memorySpace ? *memorySpace : mlirAttributeGetNull();
-              MlirType t =
-                  mlirMemRefTypeGet(elementType, shape.size(), shape.data(),
-                                    layoutAttr, memSpaceAttr);
-              if (mlirTypeIsNull(t))
-                throw MLIRError("Invalid type", errors.take());
-              return PyMemRefType(elementType.getContext(), t);
-            },
-            nb::arg("shape"), nb::arg("element_type"),
-            nb::arg("layout") = nb::none(),
-            nb::arg("memory_space") = nb::none(),
-            nb::arg("context") = nb::none(), "Create a memref type")
-        .def_prop_ro(
-            "layout",
-            [](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
-              return PyAttribute(self.getContext(),
-                                 mlirMemRefTypeGetLayout(self))
-                  .maybeDownCast();
-            },
-            "The layout of the MemRef type.")
-        .def(
-            "get_strides_and_offset",
-            [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
-              std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
-              int64_t offset;
-              if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset(
-                      self, strides.data(), &offset)))
-                throw std::runtime_error(
-                    "Failed to extract strides and offset from memref.");
-              return {strides, offset};
-            },
-            "The strides and offset of the MemRef type.")
-        .def_prop_ro(
-            "affine_map",
-            [](PyMemRefType &self) -> PyAffineMap {
-              MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
-              return PyAffineMap(self.getContext(), map);
-            },
-            "The layout of the MemRef type as an affine map.")
-        .def_prop_ro(
-            "memory_space",
-            [](PyMemRefType &self)
-                -> std::optional<nb::typed<nb::object, PyAttribute>> {
-              MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
-              if (mlirAttributeIsNull(a))
-                return std::nullopt;
-              return PyAttribute(self.getContext(), a).maybeDownCast();
-            },
-            "Returns the memory space of the given MemRef type.");
-  }
-};
-
-/// Unranked MemRef Type subclass - UnrankedMemRefType.
-class PyUnrankedMemRefType
-    : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirUnrankedMemRefTypeGetTypeID;
-  static constexpr const char *pyClassName = "UnrankedMemRefType";
-  using PyConcreteType::PyConcreteType;
+    type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
+                                     scalableDimFlags.data(), elementType);
+  } else {
+    type = mlirVectorTypeGet(shape.size(), shape.data(), elementType);
+  }
+  if (mlirTypeIsNull(type))
+    throw MLIRError("Invalid type", errors.take());
+  return PyVectorType(elementType.getContext(), type);
+}
 
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-         "get",
-         [](PyType &elementType, PyAttribute *memorySpace,
-            DefaultingPyLocation loc) {
-           PyMlirContext::ErrorCapture errors(loc->getContext());
-           MlirAttribute memSpaceAttr = {};
-           if (memorySpace)
-             memSpaceAttr = *memorySpace;
+void PyRankedTensorType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](std::vector<int64_t> shape, PyType &elementType,
+         std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
+        PyMlirContext::ErrorCapture errors(loc->getContext());
+        MlirType t = mlirRankedTensorTypeGetChecked(
+            loc, shape.size(), shape.data(), elementType,
+            encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
+        if (mlirTypeIsNull(t))
+          throw MLIRError("Invalid type", errors.take());
+        return PyRankedTensorType(elementType.getContext(), t);
+      },
+      nb::arg("shape"), nb::arg("element_type"),
+      nb::arg("encoding") = nb::none(), nb::arg("loc") = nb::none(),
+      "Create a ranked tensor type");
+  c.def_static(
+      "get_unchecked",
+      [](std::vector<int64_t> shape, PyType &elementType,
+         std::optional<PyAttribute> &encodingAttr,
+         DefaultingPyMlirContext context) {
+        PyMlirContext::ErrorCapture errors(context->getRef());
+        MlirType t = mlirRankedTensorTypeGet(
+            shape.size(), shape.data(), elementType,
+            encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
+        if (mlirTypeIsNull(t))
+          throw MLIRError("Invalid type", errors.take());
+        return PyRankedTensorType(elementType.getContext(), t);
+      },
+      nb::arg("shape"), nb::arg("element_type"),
+      nb::arg("encoding") = nb::none(), nb::arg("context") = nb::none(),
+      "Create a ranked tensor type");
+  c.def_prop_ro(
+      "encoding",
+      [](PyRankedTensorType &self)
+          -> std::optional<nb::typed<nb::object, PyAttribute>> {
+        MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
+        if (mlirAttributeIsNull(encoding))
+          return std::nullopt;
+        return PyAttribute(self.getContext(), encoding).maybeDownCast();
+      });
+}
 
-           MlirType t =
-               mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
-           if (mlirTypeIsNull(t))
-             throw MLIRError("Invalid type", errors.take());
-           return PyUnrankedMemRefType(elementType.getContext(), t);
-         },
-         nb::arg("element_type"), nb::arg("memory_space").none(),
-         nb::arg("loc") = nb::none(), "Create a unranked memref type")
-        .def_static(
-            "get_unchecked",
-            [](PyType &elementType, PyAttribute *memorySpace,
-               DefaultingPyMlirContext context) {
-              PyMlirContext::ErrorCapture errors(context->getRef());
-              MlirAttribute memSpaceAttr = {};
-              if (memorySpace)
-                memSpaceAttr = *memorySpace;
+void PyUnrankedTensorType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](PyType &elementType, DefaultingPyLocation loc) {
+        PyMlirContext::ErrorCapture errors(loc->getContext());
+        MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
+        if (mlirTypeIsNull(t))
+          throw MLIRError("Invalid type", errors.take());
+        return PyUnrankedTensorType(elementType.getContext(), t);
+      },
+      nb::arg("element_type"), nb::arg("loc") = nb::none(),
+      "Create a unranked tensor type");
+  c.def_static(
+      "get_unchecked",
+      [](PyType &elementType, DefaultingPyMlirContext context) {
+        PyMlirContext::ErrorCapture errors(context->getRef());
+        MlirType t = mlirUnrankedTensorTypeGet(elementType);
+        if (mlirTypeIsNull(t))
+          throw MLIRError("Invalid type", errors.take());
+        return PyUnrankedTensorType(elementType.getContext(), t);
+      },
+      nb::arg("element_type"), nb::arg("context") = nb::none(),
+      "Create a unranked tensor type");
+}
 
-              MlirType t = mlirUnrankedMemRefTypeGet(elementType, memSpaceAttr);
-              if (mlirTypeIsNull(t))
-                throw MLIRError("Invalid type", errors.take());
-              return PyUnrankedMemRefType(elementType.getContext(), t);
-            },
-            nb::arg("element_type"), nb::arg("memory_space").none(),
-            nb::arg("context") = nb::none(), "Create a unranked memref type")
-        .def_prop_ro(
-            "memory_space",
-            [](PyUnrankedMemRefType &self)
-                -> std::optional<nb::typed<nb::object, PyAttribute>> {
-              MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
-              if (mlirAttributeIsNull(a))
-                return std::nullopt;
-              return PyAttribute(self.getContext(), a).maybeDownCast();
-            },
-            "Returns the memory space of the given Unranked MemRef type.");
-  }
-};
+void PyMemRefType::bindDerived(ClassTy &c) {
+  c.def_static(
+       "get",
+       [](std::vector<int64_t> shape, PyType &elementType, PyAttribute *layout,
+          PyAttribute *memorySpace, DefaultingPyLocation loc) {
+         PyMlirContext::ErrorCapture errors(loc->getContext());
+         MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull();
+         MlirAttribute memSpaceAttr =
+             memorySpace ? *memorySpace : mlirAttributeGetNull();
+         MlirType t =
+             mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
+                                      shape.data(), layoutAttr, memSpaceAttr);
+         if (mlirTypeIsNull(t))
+           throw MLIRError("Invalid type", errors.take());
+         return PyMemRefType(elementType.getContext(), t);
+       },
+       nb::arg("shape"), nb::arg("element_type"),
+       nb::arg("layout") = nb::none(), nb::arg("memory_space") = nb::none(),
+       nb::arg("loc") = nb::none(), "Create a memref type")
+      .def_static(
+          "get_unchecked",
+          [](std::vector<int64_t> shape, PyType &elementType,
+             PyAttribute *layout, PyAttribute *memorySpace,
+             DefaultingPyMlirContext context) {
+            PyMlirContext::ErrorCapture errors(context->getRef());
+            MlirAttribute layoutAttr =
+                layout ? *layout : mlirAttributeGetNull();
+            MlirAttribute memSpaceAttr =
+                memorySpace ? *memorySpace : mlirAttributeGetNull();
+            MlirType t =
+                mlirMemRefTypeGet(elementType, shape.size(), shape.data(),
+                                  layoutAttr, memSpaceAttr);
+            if (mlirTypeIsNull(t))
+              throw MLIRError("Invalid type", errors.take());
+            return PyMemRefType(elementType.getContext(), t);
+          },
+          nb::arg("shape"), nb::arg("element_type"),
+          nb::arg("layout") = nb::none(), nb::arg("memory_space") = nb::none(),
+          nb::arg("context") = nb::none(), "Create a memref type")
+      .def_prop_ro(
+          "layout",
+          [](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
+            return PyAttribute(self.getContext(), mlirMemRefTypeGetLayout(self))
+                .maybeDownCast();
+          },
+          "The layout of the MemRef type.")
+      .def(
+          "get_strides_and_offset",
+          [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
+            std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
+            int64_t offset;
+            if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset(
+                    self, strides.data(), &offset)))
+              throw std::runtime_error(
+                  "Failed to extract strides and offset from memref.");
+            return {strides, offset};
+          },
+          "The strides and offset of the MemRef type.")
+      .def_prop_ro(
+          "affine_map",
+          [](PyMemRefType &self) -> PyAffineMap {
+            MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
+            return PyAffineMap(self.getContext(), map);
+          },
+          "The layout of the MemRef type as an affine map.")
+      .def_prop_ro(
+          "memory_space",
+          [](PyMemRefType &self)
+              -> std::optional<nb::typed<nb::object, PyAttribute>> {
+            MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
+            if (mlirAttributeIsNull(a))
+              return std::nullopt;
+            return PyAttribute(self.getContext(), a).maybeDownCast();
+          },
+          "Returns the memory space of the given MemRef type.");
+}
 
-/// Tuple Type subclass - TupleType.
-class PyTupleType : public PyConcreteType<PyTupleType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirTupleTypeGetTypeID;
-  static constexpr const char *pyClassName = "TupleType";
-  using PyConcreteType::PyConcreteType;
+void PyUnrankedMemRefType::bindDerived(ClassTy &c) {
+  c.def_static(
+       "get",
+       [](PyType &elementType, PyAttribute *memorySpace,
+          DefaultingPyLocation loc) {
+         PyMlirContext::ErrorCapture errors(loc->getContext());
+         MlirAttribute memSpaceAttr = {};
+         if (memorySpace)
+           memSpaceAttr = *memorySpace;
+
+         MlirType t =
+             mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
+         if (mlirTypeIsNull(t))
+           throw MLIRError("Invalid type", errors.take());
+         return PyUnrankedMemRefType(elementType.getContext(), t);
+       },
+       nb::arg("element_type"), nb::arg("memory_space").none(),
+       nb::arg("loc") = nb::none(), "Create a unranked memref type")
+      .def_static(
+          "get_unchecked",
+          [](PyType &elementType, PyAttribute *memorySpace,
+             DefaultingPyMlirContext context) {
+            PyMlirContext::ErrorCapture errors(context->getRef());
+            MlirAttribute memSpaceAttr = {};
+            if (memorySpace)
+              memSpaceAttr = *memorySpace;
+
+            MlirType t = mlirUnrankedMemRefTypeGet(elementType, memSpaceAttr);
+            if (mlirTypeIsNull(t))
+              throw MLIRError("Invalid type", errors.take());
+            return PyUnrankedMemRefType(elementType.getContext(), t);
+          },
+          nb::arg("element_type"), nb::arg("memory_space").none(),
+          nb::arg("context") = nb::none(), "Create a unranked memref type")
+      .def_prop_ro(
+          "memory_space",
+          [](PyUnrankedMemRefType &self)
+              -> std::optional<nb::typed<nb::object, PyAttribute>> {
+            MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
+            if (mlirAttributeIsNull(a))
+              return std::nullopt;
+            return PyAttribute(self.getContext(), a).maybeDownCast();
+          },
+          "Returns the memory space of the given Unranked MemRef type.");
+}
 
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get_tuple",
-        [](const std::vector<PyType> &elements,
-           DefaultingPyMlirContext context) {
-          std::vector<MlirType> mlirElements;
-          mlirElements.reserve(elements.size());
-          for (const auto &element : elements)
-            mlirElements.push_back(element.get());
-          MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
-                                        mlirElements.data());
-          return PyTupleType(context->getRef(), t);
-        },
-        nb::arg("elements"), nb::arg("context") = nb::none(),
-        "Create a tuple type");
-    c.def_static(
-        "get_tuple",
-        [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
-          MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
-                                        elements.data());
-          return PyTupleType(context->getRef(), t);
-        },
-        nb::arg("elements"), nb::arg("context") = nb::none(),
-        // clang-format off
+void PyTupleType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get_tuple",
+      [](const std::vector<PyType> &elements, DefaultingPyMlirContext context) {
+        std::vector<MlirType> mlirElements;
+        mlirElements.reserve(elements.size());
+        for (const auto &element : elements)
+          mlirElements.push_back(element.get());
+        MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
+                                      mlirElements.data());
+        return PyTupleType(context->getRef(), t);
+      },
+      nb::arg("elements"), nb::arg("context") = nb::none(),
+      "Create a tuple type");
+  c.def_static(
+      "get_tuple",
+      [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
+        MlirType t =
+            mlirTupleTypeGet(context->get(), elements.size(), elements.data());
+        return PyTupleType(context->getRef(), t);
+      },
+      nb::arg("elements"), nb::arg("context") = nb::none(),
+      // clang-format off
         nb::sig("def get_tuple(elements: Sequence[Type], context: Context | None = None) -> TupleType"),
-        // clang-format on
-        "Create a tuple type");
-    c.def(
-        "get_type",
-        [](PyTupleType &self, intptr_t pos) -> nb::typed<nb::object, PyType> {
-          return PyType(self.getContext(), mlirTupleTypeGetType(self, pos))
-              .maybeDownCast();
-        },
-        nb::arg("pos"), "Returns the pos-th type in the tuple type.");
-    c.def_prop_ro(
-        "num_types",
-        [](PyTupleType &self) -> intptr_t {
-          return mlirTupleTypeGetNumTypes(self);
-        },
-        "Returns the number of types contained in a tuple.");
-  }
-};
-
-/// Function type.
-class PyFunctionType : public PyConcreteType<PyFunctionType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFunctionTypeGetTypeID;
-  static constexpr const char *pyClassName = "FunctionType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](std::vector<PyType> inputs, std::vector<PyType> results,
-           DefaultingPyMlirContext context) {
-          std::vector<MlirType> mlirInputs;
-          mlirInputs.reserve(inputs.size());
-          for (const auto &input : inputs)
-            mlirInputs.push_back(input.get());
-          std::vector<MlirType> mlirResults;
-          mlirResults.reserve(results.size());
-          for (const auto &result : results)
-            mlirResults.push_back(result.get());
+      // clang-format on
+      "Create a tuple type");
+  c.def(
+      "get_type",
+      [](PyTupleType &self, intptr_t pos) -> nb::typed<nb::object, PyType> {
+        return PyType(self.getContext(), mlirTupleTypeGetType(self, pos))
+            .maybeDownCast();
+      },
+      nb::arg("pos"), "Returns the pos-th type in the tuple type.");
+  c.def_prop_ro(
+      "num_types",
+      [](PyTupleType &self) -> intptr_t {
+        return mlirTupleTypeGetNumTypes(self);
+      },
+      "Returns the number of types contained in a tuple.");
+}
 
-          MlirType t = mlirFunctionTypeGet(context->get(), inputs.size(),
-                                           mlirInputs.data(), results.size(),
-                                           mlirResults.data());
-          return PyFunctionType(context->getRef(), t);
-        },
-        nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
-        "Gets a FunctionType from a list of input and result types");
-    c.def_static(
-        "get",
-        [](std::vector<MlirType> inputs, std::vector<MlirType> results,
-           DefaultingPyMlirContext context) {
-          MlirType t =
-              mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
-                                  results.size(), results.data());
-          return PyFunctionType(context->getRef(), t);
-        },
-        nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
-        // clang-format off
+void PyFunctionType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](std::vector<PyType> inputs, std::vector<PyType> results,
+         DefaultingPyMlirContext context) {
+        std::vector<MlirType> mlirInputs;
+        mlirInputs.reserve(inputs.size());
+        for (const auto &input : inputs)
+          mlirInputs.push_back(input.get());
+        std::vector<MlirType> mlirResults;
+        mlirResults.reserve(results.size());
+        for (const auto &result : results)
+          mlirResults.push_back(result.get());
+
+        MlirType t = mlirFunctionTypeGet(context->get(), inputs.size(),
+                                         mlirInputs.data(), results.size(),
+                                         mlirResults.data());
+        return PyFunctionType(context->getRef(), t);
+      },
+      nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
+      "Gets a FunctionType from a list of input and result types");
+  c.def_static(
+      "get",
+      [](std::vector<MlirType> inputs, std::vector<MlirType> results,
+         DefaultingPyMlirContext context) {
+        MlirType t =
+            mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
+                                results.size(), results.data());
+        return PyFunctionType(context->getRef(), t);
+      },
+      nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
+      // clang-format off
         nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: Context | None = None) -> FunctionType"),
-        // clang-format on
-        "Gets a FunctionType from a list of input and result types");
-    c.def_prop_ro(
-        "inputs",
-        [](PyFunctionType &self) {
-          MlirType t = self;
-          nb::list types;
-          for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
-               ++i) {
-            types.append(mlirFunctionTypeGetInput(t, i));
-          }
-          return types;
-        },
-        "Returns the list of input types in the FunctionType.");
-    c.def_prop_ro(
-        "results",
-        [](PyFunctionType &self) {
-          nb::list types;
-          for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
-               ++i) {
-            types.append(mlirFunctionTypeGetResult(self, i));
-          }
-          return types;
-        },
-        "Returns the list of result types in the FunctionType.");
-  }
-};
-
-/// Opaque Type subclass - OpaqueType.
-class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirOpaqueTypeGetTypeID;
-  static constexpr const char *pyClassName = "OpaqueType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](const std::string &dialectNamespace, const std::string &typeData,
-           DefaultingPyMlirContext context) {
-          MlirType type = mlirOpaqueTypeGet(context->get(),
-                                            toMlirStringRef(dialectNamespace),
-                                            toMlirStringRef(typeData));
-          return PyOpaqueType(context->getRef(), type);
-        },
-        nb::arg("dialect_namespace"), nb::arg("buffer"),
-        nb::arg("context") = nb::none(),
-        "Create an unregistered (opaque) dialect type.");
-    c.def_prop_ro(
-        "dialect_namespace",
-        [](PyOpaqueType &self) {
-          MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self);
-          return nb::str(stringRef.data, stringRef.length);
-        },
-        "Returns the dialect namespace for the Opaque type as a string.");
-    c.def_prop_ro(
-        "data",
-        [](PyOpaqueType &self) {
-          MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
-          return nb::str(stringRef.data, stringRef.length);
-        },
-        "Returns the data for the Opaque type as a string.");
-  }
-};
+      // clang-format on
+      "Gets a FunctionType from a list of input and result types");
+  c.def_prop_ro(
+      "inputs",
+      [](PyFunctionType &self) {
+        MlirType t = self;
+        nb::list types;
+        for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
+             ++i) {
+          types.append(mlirFunctionTypeGetInput(t, i));
+        }
+        return types;
+      },
+      "Returns the list of input types in the FunctionType.");
+  c.def_prop_ro(
+      "results",
+      [](PyFunctionType &self) {
+        nb::list types;
+        for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
+             ++i) {
+          types.append(mlirFunctionTypeGetResult(self, i));
+        }
+        return types;
+      },
+      "Returns the list of result types in the FunctionType.");
+}
 
-} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
-} // namespace python
-} // namespace mlir
+void PyOpaqueType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](const std::string &dialectNamespace, const std::string &typeData,
+         DefaultingPyMlirContext context) {
+        MlirType type =
+            mlirOpaqueTypeGet(context->get(), toMlirStringRef(dialectNamespace),
+                              toMlirStringRef(typeData));
+        return PyOpaqueType(context->getRef(), type);
+      },
+      nb::arg("dialect_namespace"), nb::arg("buffer"),
+      nb::arg("context") = nb::none(),
+      "Create an unregistered (opaque) dialect type.");
+  c.def_prop_ro(
+      "dialect_namespace",
+      [](PyOpaqueType &self) {
+        MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self);
+        return nb::str(stringRef.data, stringRef.length);
+      },
+      "Returns the dialect namespace for the Opaque type as a string.");
+  c.def_prop_ro(
+      "data",
+      [](PyOpaqueType &self) {
+        MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
+        return nb::str(stringRef.data, stringRef.length);
+      },
+      "Returns the data for the Opaque type as a string.");
+}
 
-namespace mlir {
-namespace python {
-namespace MLIR_BINDINGS_PYTHON_DOMAIN {
 void populateIRTypes(nb::module_ &m) {
   PyIntegerType::bind(m);
   PyFloatType::bind(m);
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index b2c9380bc1d73..88f58d45cdd75 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -9,7 +9,9 @@
 #include "Pass.h"
 #include "Rewrite.h"
 #include "mlir/Bindings/Python/Globals.h"
+#include "mlir/Bindings/Python/IRAttributes.h"
 #include "mlir/Bindings/Python/IRCore.h"
+#include "mlir/Bindings/Python/IRTypes.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 
 namespace nb = nanobind;
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 4a9fb127ee08c..003a06b16daac 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -533,9 +533,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
   SOURCES
     MainModule.cpp
     IRAffine.cpp
-    IRAttributes.cpp
     IRInterfaces.cpp
-    IRTypes.cpp
     Pass.cpp
     Rewrite.cpp
 
@@ -846,8 +844,10 @@ declare_mlir_python_extension(MLIRPythonExtension.MLIRPythonSupport
   ADD_TO_PARENT MLIRPythonSources.Core
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
   SOURCES
-    IRCore.cpp
     Globals.cpp
+    IRAttributes.cpp
+    IRCore.cpp
+    IRTypes.cpp
 )
 
 ################################################################################
diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
index 43573cbc305fa..b229c02ccf5e6 100644
--- a/mlir/test/python/lib/PythonTestModuleNanobind.cpp
+++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
@@ -15,6 +15,7 @@
 #include "mlir-c/IR.h"
 #include "mlir/Bindings/Python/Diagnostics.h"
 #include "mlir/Bindings/Python/IRCore.h"
+#include "mlir/Bindings/Python/IRTypes.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "nanobind/nanobind.h"
@@ -47,6 +48,49 @@ struct PyTestType
   }
 };
 
+struct PyTestIntegerRankedTensorType
+    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<
+          PyTestIntegerRankedTensorType,
+          mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyRankedTensorType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedIntegerTensor;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirRankedTensorTypeGetTypeID;
+  static constexpr const char *pyClassName = "TestIntegerRankedTensorType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](std::vector<int64_t> shape, unsigned width,
+           mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
+               ctx) {
+          MlirAttribute encoding = mlirAttributeGetNull();
+          return PyTestIntegerRankedTensorType(
+              ctx->getRef(),
+              mlirRankedTensorTypeGet(
+                  shape.size(), shape.data(),
+                  mlirIntegerTypeGet(ctx.get()->get(), width), encoding));
+        },
+        nb::arg("shape"), nb::arg("width"),
+        nb::arg("context").none() = nb::none());
+  }
+};
+
+struct PyTestTensorValue
+    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteValue<
+          PyTestTensorValue> {
+  static constexpr IsAFunctionTy isaFunction =
+      mlirTypeIsAPythonTestTestTensorValue;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirRankedTensorTypeGetTypeID;
+  static constexpr const char *pyClassName = "TestTensorValue";
+  using PyConcreteValue::PyConcreteValue;
+
+  static void bindDerived(ClassTy &c) {
+    c.def("is_null", [](MlirValue &self) { return mlirValueIsNull(self); });
+  }
+};
+
 class PyTestAttr
     : public mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteAttribute<
           PyTestAttr> {
@@ -73,18 +117,18 @@ class PyTestAttr
 NB_MODULE(_mlirPythonTestNanobind, m) {
   m.def(
       "register_python_test_dialect",
-      [](MlirContext context, bool load) {
+      [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
+             context,
+         bool load) {
         MlirDialectHandle pythonTestDialect =
             mlirGetDialectHandle__python_test__();
-        mlirDialectHandleRegisterDialect(pythonTestDialect, context);
+        mlirDialectHandleRegisterDialect(pythonTestDialect,
+                                         context.get()->get());
         if (load) {
-          mlirDialectHandleLoadDialect(pythonTestDialect, context);
+          mlirDialectHandleLoadDialect(pythonTestDialect, context.get()->get());
         }
       },
-      nb::arg("context"), nb::arg("load") = true,
-      // clang-format off
-      nb::sig("def register_python_test_dialect(context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") ", load: bool = True) -> None"));
-  // clang-format on
+      nb::arg("context").none() = nb::none(), nb::arg("load") = true);
 
   m.def(
       "register_dialect",
@@ -100,73 +144,16 @@ NB_MODULE(_mlirPythonTestNanobind, m) {
 
   m.def(
       "test_diagnostics_with_errors_and_notes",
-      [](MlirContext ctx) {
-        mlir::python::CollectDiagnosticsToStringScope handler(ctx);
-        mlirPythonTestEmitDiagnosticWithNote(ctx);
+      [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
+             ctx) {
+        mlir::python::CollectDiagnosticsToStringScope handler(ctx.get()->get());
+        mlirPythonTestEmitDiagnosticWithNote(ctx.get()->get());
         throw nb::value_error(handler.takeMessage().c_str());
       },
-      // clang-format off
-      nb::sig("def test_diagnostics_with_errors_and_notes(arg: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") ", /) -> None"));
-  // clang-format on
+      nb::arg("context").none() = nb::none());
 
   PyTestAttr::bind(m);
   PyTestType::bind(m);
-
-  auto typeCls =
-      mlir_type_subclass(m, "TestIntegerRankedTensorType",
-                         mlirTypeIsARankedIntegerTensor,
-                         nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
-                             .attr("RankedTensorType"))
-          .def_classmethod(
-              "get",
-              [](const nb::object &cls, std::vector<int64_t> shape,
-                 unsigned width, MlirContext ctx) {
-                MlirAttribute encoding = mlirAttributeGetNull();
-                return cls(mlirRankedTensorTypeGet(
-                    shape.size(), shape.data(), mlirIntegerTypeGet(ctx, width),
-                    encoding));
-              },
-              // clang-format off
-              nb::sig("def get(cls: object, shape: collections.abc.Sequence[int], width: int, context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> object"),
-              // clang-format on
-              nb::arg("cls"), nb::arg("shape"), nb::arg("width"),
-              nb::arg("context").none() = nb::none());
-
-  assert(nb::hasattr(typeCls.get_class(), "static_typeid") &&
-         "TestIntegerRankedTensorType has no static_typeid");
-
-  MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID();
-
-  nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
-      .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
-          mlirRankedTensorTypeID, nb::arg("replace") = true)(
-          nanobind::cpp_function([typeCls](const nb::object &mlirType) {
-            return typeCls.get_class()(mlirType);
-          }));
-
-  auto valueCls = mlir_value_subclass(m, "TestTensorValue",
-                                      mlirTypeIsAPythonTestTestTensorValue)
-                      .def("is_null", [](MlirValue &self) {
-                        return mlirValueIsNull(self);
-                      });
-
-  nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
-      .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)(
-          mlirRankedTensorTypeID)(
-          nanobind::cpp_function([valueCls](const nb::object &valueObj) {
-            std::optional<nb::object> capsule =
-                mlirApiObjectToCapsule(valueObj);
-            assert(capsule.has_value() && "capsule is not null");
-            MlirValue v = mlirPythonCapsuleToValue(capsule.value().ptr());
-            MlirType t = mlirValueGetType(v);
-            // This is hyper-specific in order to exercise/test registering a
-            // value caster from cpp (but only for a single test case; see
-            // testTensorValue python_test.py).
-            if (mlirShapedTypeHasStaticShape(t) &&
-                mlirShapedTypeGetDimSize(t, 0) == 1 &&
-                mlirShapedTypeGetDimSize(t, 1) == 2 &&
-                mlirShapedTypeGetDimSize(t, 2) == 3)
-              return valueCls.get_class()(valueObj);
-            return valueObj;
-          }));
-}
+  PyTestIntegerRankedTensorType::bind(m);
+  PyTestTensorValue::bind(m);
+}
\ No newline at end of file



More information about the llvm-branch-commits mailing list