[llvm] [mlir] [MLIR][Python][NFC] move Py* types (PR #155719)

Rolf Morel via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 28 02:56:36 PDT 2025


================
@@ -294,1012 +195,879 @@ static T pyTryCast(nb::handle object) {
   }
 }
 
-/// A python-wrapped dense array attribute with an element type and a derived
-/// implementation class.
-template <typename EltTy, typename DerivedT>
-class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
-public:
-  using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;
-
-  /// Iterator over the integer elements of a dense array.
-  class PyDenseArrayIterator {
-  public:
-    PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
-
-    /// Return a copy of the iterator.
-    PyDenseArrayIterator dunderIter() { return *this; }
-
-    /// Return the next element.
-    EltTy dunderNext() {
-      // Throw if the index has reached the end.
-      if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
-        throw nb::stop_iteration();
-      return DerivedT::getElement(attr.get(), nextIndex++);
-    }
-
-    /// Bind the iterator class.
-    static void bind(nb::module_ &m) {
-      nb::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName)
-          .def("__iter__", &PyDenseArrayIterator::dunderIter)
-          .def("__next__", &PyDenseArrayIterator::dunderNext);
-    }
-
-  private:
-    /// The referenced dense array attribute.
-    PyAttribute attr;
-    /// The next index to read.
-    int nextIndex = 0;
-  };
-
-  /// Get the element at the given index.
-  EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
-
-  /// Bind the attribute class.
-  static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
-    // Bind the constructor.
-    if constexpr (std::is_same_v<EltTy, bool>) {
-      c.def_static(
-          "get",
-          [](const nb::sequence &py_values, DefaultingPyMlirContext ctx) {
-            std::vector<bool> values;
-            for (nb::handle py_value : py_values) {
-              int is_true = PyObject_IsTrue(py_value.ptr());
-              if (is_true < 0) {
-                throw nb::python_error();
-              }
-              values.push_back(is_true);
-            }
-            return getAttribute(values, ctx->getRef());
-          },
-          nb::arg("values"), nb::arg("context").none() = nb::none(),
-          "Gets a uniqued dense array attribute");
-    } else {
-      c.def_static(
-          "get",
-          [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
-            return getAttribute(values, ctx->getRef());
-          },
-          nb::arg("values"), nb::arg("context").none() = nb::none(),
-          "Gets a uniqued dense array attribute");
-    }
-    // Bind the array methods.
-    c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
-      if (i >= mlirDenseArrayGetNumElements(arr))
-        throw nb::index_error("DenseArray index out of range");
-      return arr.getItem(i);
-    });
-    c.def("__len__", [](const DerivedT &arr) {
-      return mlirDenseArrayGetNumElements(arr);
-    });
-    c.def("__iter__",
-          [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
-    c.def("__add__", [](DerivedT &arr, const nb::list &extras) {
-      std::vector<EltTy> values;
-      intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
-      values.reserve(numOldElements + nb::len(extras));
-      for (intptr_t i = 0; i < numOldElements; ++i)
-        values.push_back(arr.getItem(i));
-      for (nb::handle attr : extras)
-        values.push_back(pyTryCast<EltTy>(attr));
-      return getAttribute(values, arr.getContext());
-    });
-  }
+} // namespace
 
-private:
-  static DerivedT getAttribute(const std::vector<EltTy> &values,
-                               PyMlirContextRef ctx) {
-    if constexpr (std::is_same_v<EltTy, bool>) {
-      std::vector<int> intValues(values.begin(), values.end());
-      MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(),
-                                                  intValues.data());
-      return DerivedT(ctx, attr);
-    } else {
-      MlirAttribute attr =
-          DerivedT::getAttribute(ctx->get(), values.size(), values.data());
-      return DerivedT(ctx, attr);
-    }
-  }
-};
+namespace mlir::python {
+/// Binds the Python-visible members of PyAffineMapAttribute:
+///   - static `get(affine_map)`: builds the attribute from a PyAffineMap via
+///     mlirAffineMapAttrGet;
+///   - read-only property `value`: returned by mlirAffineMapAttrGetValue.
+void PyAffineMapAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](PyAffineMap &affineMap) {
+        MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
+        // The wrapper is tied to the same context as the input affine map.
+        return PyAffineMapAttribute(affineMap.getContext(), attr);
+      },
+      nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
+  c.def_prop_ro("value", mlirAffineMapAttrGetValue,
+                "Returns the value of the AffineMap attribute");
+}
 
-/// Instantiate the python dense array classes.
-struct PyDenseBoolArrayAttribute
-    : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> {
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
-  static constexpr auto getAttribute = mlirDenseBoolArrayGet;
-  static constexpr auto getElement = mlirDenseBoolArrayGetElement;
-  static constexpr const char *pyClassName = "DenseBoolArrayAttr";
-  static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
-  using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseI8ArrayAttribute
-    : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
-  static constexpr auto getAttribute = mlirDenseI8ArrayGet;
-  static constexpr auto getElement = mlirDenseI8ArrayGetElement;
-  static constexpr const char *pyClassName = "DenseI8ArrayAttr";
-  static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
-  using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseI16ArrayAttribute
-    : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
-  static constexpr auto getAttribute = mlirDenseI16ArrayGet;
-  static constexpr auto getElement = mlirDenseI16ArrayGetElement;
-  static constexpr const char *pyClassName = "DenseI16ArrayAttr";
-  static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
-  using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseI32ArrayAttribute
-    : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
-  static constexpr auto getAttribute = mlirDenseI32ArrayGet;
-  static constexpr auto getElement = mlirDenseI32ArrayGetElement;
-  static constexpr const char *pyClassName = "DenseI32ArrayAttr";
-  static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
-  using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseI64ArrayAttribute
-    : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
-  static constexpr auto getAttribute = mlirDenseI64ArrayGet;
-  static constexpr auto getElement = mlirDenseI64ArrayGetElement;
-  static constexpr const char *pyClassName = "DenseI64ArrayAttr";
-  static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
-  using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseF32ArrayAttribute
-    : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
-  static constexpr auto getAttribute = mlirDenseF32ArrayGet;
-  static constexpr auto getElement = mlirDenseF32ArrayGetElement;
-  static constexpr const char *pyClassName = "DenseF32ArrayAttr";
-  static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
-  using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseF64ArrayAttribute
-    : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
-  static constexpr auto getAttribute = mlirDenseF64ArrayGet;
-  static constexpr auto getElement = mlirDenseF64ArrayGetElement;
-  static constexpr const char *pyClassName = "DenseF64ArrayAttr";
-  static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
-  using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
+/// Binds the Python-visible members of PyIntegerSetAttribute. Only a static
+/// `get(integer_set)` is bound; it builds the attribute from a PyIntegerSet via
+/// mlirIntegerSetAttrGet.
+void PyIntegerSetAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](PyIntegerSet &integerSet) {
+        MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get());
+        // The wrapper is tied to the same context as the input integer set.
+        return PyIntegerSetAttribute(integerSet.getContext(), attr);
+      },
+      nb::arg("integer_set"), "Gets an attribute wrapping an IntegerSet.");
+}
 
-class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
-  static constexpr const char *pyClassName = "ArrayAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirArrayAttrGetTypeID;
-
-  class PyArrayAttributeIterator {
-  public:
-    PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
-
-    PyArrayAttributeIterator &dunderIter() { return *this; }
-
-    MlirAttribute dunderNext() {
-      // TODO: Throw is an inefficient way to stop iteration.
-      if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
-        throw nb::stop_iteration();
-      return mlirArrayAttrGetElement(attr.get(), nextIndex++);
-    }
+/// Implements Python `__iter__` for the dense-array iterator.
+/// Returns a copy of this iterator (by value), not a reference to it.
+template <typename EltTy, typename DerivedT>
+typename PyDenseArrayAttribute<EltTy, DerivedT>::PyDenseArrayIterator
+PyDenseArrayAttribute<EltTy, DerivedT>::PyDenseArrayIterator::dunderIter() {
+  return *this;
+}
 
-    static void bind(nb::module_ &m) {
-      nb::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
-          .def("__iter__", &PyArrayAttributeIterator::dunderIter)
-          .def("__next__", &PyArrayAttributeIterator::dunderNext);
-    }
+/// Implements Python `__next__` for the dense-array iterator.
+/// Returns the element at the current position, read via DerivedT::getElement,
+/// and then advances the position by one. Throws nb::stop_iteration once the
+/// position reaches mlirDenseArrayGetNumElements of the referenced attribute.
+template <typename EltTy, typename DerivedT>
+EltTy PyDenseArrayAttribute<EltTy,
+                            DerivedT>::PyDenseArrayIterator::dunderNext() {
+  // Throw if the index has reached the end.
+  if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
+    throw nb::stop_iteration();
+  // Post-increment: read the current slot, then move past it.
+  return DerivedT::getElement(attr.get(), nextIndex++);
+}
 
-  private:
-    PyAttribute attr;
-    int nextIndex = 0;
-  };
+/// Registers the iterator as a Python class in module `m`.
+/// The class is named DerivedT::pyIteratorName and exposes `__iter__` and
+/// `__next__`.
+template <typename EltTy, typename DerivedT>
+void PyDenseArrayAttribute<EltTy, DerivedT>::PyDenseArrayIterator::bind(
+    nb::module_ &m) {
+  nb::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName)
+      .def("__iter__", &PyDenseArrayIterator::dunderIter)
+      .def("__next__", &PyDenseArrayIterator::dunderNext);
+}
 
-  MlirAttribute getItem(intptr_t i) {
-    return mlirArrayAttrGetElement(*this, i);
-  }
+/// Returns the element at index `i`, read via DerivedT::getElement.
+/// No bounds check is performed here.
+/// NOTE(review): callers (e.g. the `__getitem__` binding) are expected to
+/// validate `i`, including negative values. Confirm that they do.
+template <typename EltTy, typename DerivedT>
+EltTy PyDenseArrayAttribute<EltTy, DerivedT>::getItem(intptr_t i) {
+  return DerivedT::getElement(*this, i);
+}
 
-  static void bindDerived(ClassTy &c) {
+template <typename EltTy, typename DerivedT>
+void PyDenseArrayAttribute<EltTy, DerivedT>::bindDerived(
+    typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
+  // Bind the constructor.
+  if constexpr (std::is_same_v<EltTy, bool>) {
     c.def_static(
         "get",
-        [](const nb::list &attributes, DefaultingPyMlirContext context) {
-          SmallVector<MlirAttribute> mlirAttributes;
-          mlirAttributes.reserve(nb::len(attributes));
-          for (auto attribute : attributes) {
-            mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
+        [](const nb::sequence &py_values, DefaultingPyMlirContext ctx) {
+          std::vector<bool> values;
+          for (nb::handle py_value : py_values) {
+            int is_true = PyObject_IsTrue(py_value.ptr());
+            if (is_true < 0) {
+              throw nb::python_error();
+            }
+            values.push_back(is_true);
----------------
rolfmorel wrote:

Random spot check: you're pushing `is_true` here, which I believe is a Boolean value and _not_ the attribute element itself. I think.

https://github.com/llvm/llvm-project/pull/155719


More information about the llvm-commits mailing list