[Mlir-commits] [mlir] 619fd8c - [mlir][python] Add python bindings for DenseArrayAttr

Jeff Niu llvmlistbot at llvm.org
Fri Aug 12 16:44:57 PDT 2022


Author: Jeff Niu
Date: 2022-08-12T19:44:49-04:00
New Revision: 619fd8c2ab505d8f79cbbbe3fd09b02f6640e1b1

URL: https://github.com/llvm/llvm-project/commit/619fd8c2ab505d8f79cbbbe3fd09b02f6640e1b1
DIFF: https://github.com/llvm/llvm-project/commit/619fd8c2ab505d8f79cbbbe3fd09b02f6640e1b1.diff

LOG: [mlir][python] Add python bindings for DenseArrayAttr

This patch adds python bindings for the dense array variants.

Fixes #56975

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D131801

Added: 
    

Modified: 
    mlir/include/mlir-c/BuiltinAttributes.h
    mlir/lib/Bindings/Python/IRAttributes.cpp
    mlir/lib/CAPI/IR/BuiltinAttributes.cpp
    mlir/test/CAPI/ir.c
    mlir/test/python/ir/attributes.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index 62ee31904acb..c75db95b470f 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -296,6 +296,61 @@ mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank, uint64_t *idxs);
 /// shaped type and use its sizes to build a multi-dimensional index.
 MLIR_CAPI_EXPORTED int64_t mlirElementsAttrGetNumElements(MlirAttribute attr);
 
+//===----------------------------------------------------------------------===//
+// Dense array attribute.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the attribute is a dense array of the named element type.
+MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseBoolArray(MlirAttribute attr);
+MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI8Array(MlirAttribute attr);
+MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI16Array(MlirAttribute attr);
+MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI32Array(MlirAttribute attr);
+MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI64Array(MlirAttribute attr);
+MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseF32Array(MlirAttribute attr);
+MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseF64Array(MlirAttribute attr);
+
+/// Creates a dense array from `values[0..size)` (bool: nonzero int is true).
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx,
+                                                       intptr_t size,
+                                                       int const *values);
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI8ArrayGet(MlirContext ctx,
+                                                     intptr_t size,
+                                                     int8_t const *values);
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI16ArrayGet(MlirContext ctx,
+                                                      intptr_t size,
+                                                      int16_t const *values);
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx,
+                                                      intptr_t size,
+                                                      int32_t const *values);
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI64ArrayGet(MlirContext ctx,
+                                                      intptr_t size,
+                                                      int64_t const *values);
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseF32ArrayGet(MlirContext ctx,
+                                                      intptr_t size,
+                                                      float const *values);
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx,
+                                                      intptr_t size,
+                                                      double const *values);
+
+/// Returns the number of elements in a dense array of any element type.
+MLIR_CAPI_EXPORTED intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr);
+
+/// Returns element `pos` (0 <= pos < size) of a matching-type dense array.
+MLIR_CAPI_EXPORTED bool mlirDenseBoolArrayGetElement(MlirAttribute attr,
+                                                     intptr_t pos);
+MLIR_CAPI_EXPORTED int8_t mlirDenseI8ArrayGetElement(MlirAttribute attr,
+                                                     intptr_t pos);
+MLIR_CAPI_EXPORTED int16_t mlirDenseI16ArrayGetElement(MlirAttribute attr,
+                                                       intptr_t pos);
+MLIR_CAPI_EXPORTED int32_t mlirDenseI32ArrayGetElement(MlirAttribute attr,
+                                                       intptr_t pos);
+MLIR_CAPI_EXPORTED int64_t mlirDenseI64ArrayGetElement(MlirAttribute attr,
+                                                       intptr_t pos);
+MLIR_CAPI_EXPORTED float mlirDenseF32ArrayGetElement(MlirAttribute attr,
+                                                     intptr_t pos);
+MLIR_CAPI_EXPORTED double mlirDenseF64ArrayGetElement(MlirAttribute attr,
+                                                      intptr_t pos);
+
 //===----------------------------------------------------------------------===//
 // Dense elements attribute.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 1093d50c8869..d8fc568b73f7 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -110,6 +110,161 @@ static T pyTryCast(py::handle object) {
   }
 }
 
+/// A Python-wrapped dense array attribute. `EltTy` is the C++ type exchanged
+/// with Python; `DerivedT` supplies the C API hooks and Python class names.
+template <typename EltTy, typename DerivedT>
+class PyDenseArrayAttribute
+    : public PyConcreteAttribute<PyDenseArrayAttribute<EltTy, DerivedT>> {
+public:
+  static constexpr typename PyConcreteAttribute<
+      PyDenseArrayAttribute<EltTy, DerivedT>>::IsAFunctionTy isaFunction =
+      DerivedT::isaFunction;
+  static constexpr const char *pyClassName = DerivedT::pyClassName;
+  using PyConcreteAttribute<
+      PyDenseArrayAttribute<EltTy, DerivedT>>::PyConcreteAttribute;
+
+  /// Iterator over the elements of a dense array.
+  class PyDenseArrayIterator {
+  public:
+    explicit PyDenseArrayIterator(PyAttribute attr) : attr(attr) {}
+
+    /// Return a copy of the iterator.
+    PyDenseArrayIterator dunderIter() { return *this; }
+
+    /// Return the next element, or raise StopIteration past the end.
+    EltTy dunderNext() {
+      // Throw if the index has reached the end.
+      if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
+        throw py::stop_iteration();
+      return DerivedT::getElement(attr.get(), nextIndex++);
+    }
+
+    /// Bind the iterator class.
+    static void bind(py::module &m) {
+      py::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName,
+                                       py::module_local())
+          .def("__iter__", &PyDenseArrayIterator::dunderIter)
+          .def("__next__", &PyDenseArrayIterator::dunderNext);
+    }
+
+  private:
+    /// The referenced dense array attribute.
+    PyAttribute attr;
+    /// The next index to read; intptr_t to match the C API element count.
+    intptr_t nextIndex = 0;
+  };
+
+  /// Get the element at index `i`, which must already be in [0, size).
+  EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
+
+  /// Bind the attribute class.
+  static void bindDerived(typename PyConcreteAttribute<
+                          PyDenseArrayAttribute<EltTy, DerivedT>>::ClassTy &c) {
+    // Bind the constructor.
+    c.def_static(
+        "get",
+        [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
+          MlirAttribute attr =
+              DerivedT::getAttribute(ctx->get(), values.size(), values.data());
+          return PyDenseArrayAttribute(ctx->getRef(), attr);
+        },
+        py::arg("values"), py::arg("context") = py::none(),
+        "Gets a uniqued dense array attribute");
+    // Bind the array methods. Indices follow Python semantics: negative values
+    // count from the end, and anything outside [-size, size) is an IndexError.
+    c.def("__getitem__", [](PyDenseArrayAttribute &arr, intptr_t i) {
+      intptr_t size = mlirDenseArrayGetNumElements(arr);
+      if (i < -size || i >= size)
+        throw py::index_error("DenseArray index out of range");
+      return arr.getItem(i < 0 ? i + size : i);
+    });
+    c.def("__len__", [](const PyDenseArrayAttribute &arr) {
+      return mlirDenseArrayGetNumElements(arr);
+    });
+    c.def("__iter__", [](const PyDenseArrayAttribute &arr) {
+      return PyDenseArrayIterator(arr);
+    });
+    // Bind a concat: `arr + [x, y]` returns a new array with extras appended.
+    c.def("__add__", [](PyDenseArrayAttribute &arr, py::list extras) {
+      std::vector<EltTy> values;
+      intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
+      values.reserve(numOldElements + py::len(extras));
+      for (intptr_t i = 0; i < numOldElements; ++i)
+        values.push_back(arr.getItem(i));
+      for (py::handle item : extras)
+        values.push_back(pyTryCast<EltTy>(item));
+      MlirAttribute attr = DerivedT::getAttribute(arr.getContext()->get(),
+                                                  values.size(), values.data());
+      return PyDenseArrayAttribute(arr.getContext(), attr);
+    });
+  }
+};
+
+/// Per-element-type bindings; the bool variant uses `int`, like the C API.
+struct PyDenseBoolArrayAttribute
+    : PyDenseArrayAttribute<int, PyDenseBoolArrayAttribute> {
+  static constexpr const char *pyClassName = "DenseBoolArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
+  static constexpr auto getAttribute = mlirDenseBoolArrayGet;
+  static constexpr auto getElement = mlirDenseBoolArrayGetElement;
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+struct PyDenseI8ArrayAttribute
+    : PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
+  static constexpr const char *pyClassName = "DenseI8ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
+  static constexpr auto getAttribute = mlirDenseI8ArrayGet;
+  static constexpr auto getElement = mlirDenseI8ArrayGetElement;
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+struct PyDenseI16ArrayAttribute
+    : PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
+  static constexpr const char *pyClassName = "DenseI16ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
+  static constexpr auto getAttribute = mlirDenseI16ArrayGet;
+  static constexpr auto getElement = mlirDenseI16ArrayGetElement;
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+struct PyDenseI32ArrayAttribute
+    : PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
+  static constexpr const char *pyClassName = "DenseI32ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
+  static constexpr auto getAttribute = mlirDenseI32ArrayGet;
+  static constexpr auto getElement = mlirDenseI32ArrayGetElement;
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+struct PyDenseI64ArrayAttribute
+    : PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
+  static constexpr const char *pyClassName = "DenseI64ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
+  static constexpr auto getAttribute = mlirDenseI64ArrayGet;
+  static constexpr auto getElement = mlirDenseI64ArrayGetElement;
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+struct PyDenseF32ArrayAttribute
+    : PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
+  static constexpr const char *pyClassName = "DenseF32ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
+  static constexpr auto getAttribute = mlirDenseF32ArrayGet;
+  static constexpr auto getElement = mlirDenseF32ArrayGetElement;
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+struct PyDenseF64ArrayAttribute
+    : PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
+  static constexpr const char *pyClassName = "DenseF64ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
+  static constexpr auto getAttribute = mlirDenseF64ArrayGet;
+  static constexpr auto getElement = mlirDenseF64ArrayGetElement;
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+
 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
@@ -891,6 +1046,22 @@ class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
 
 void mlir::python::populateIRAttributes(py::module &m) {
   PyAffineMapAttribute::bind(m);
+
+  PyDenseBoolArrayAttribute::bind(m);
+  PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
+  PyDenseI8ArrayAttribute::bind(m);
+  PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m);
+  PyDenseI16ArrayAttribute::bind(m);
+  PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m);
+  PyDenseI32ArrayAttribute::bind(m);
+  PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m);
+  PyDenseI64ArrayAttribute::bind(m);
+  PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m);
+  PyDenseF32ArrayAttribute::bind(m);
+  PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
+  PyDenseF64ArrayAttribute::bind(m);
+  PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
+
   PyArrayAttribute::bind(m);
   PyArrayAttribute::PyArrayAttributeIterator::bind(m);
   PyBoolAttribute::bind(m);

diff  --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index afab9458d443..c50096bb1c1b 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -311,6 +311,106 @@ int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) {
   return unwrap(attr).cast<ElementsAttr>().getNumElements();
 }
 
+//===----------------------------------------------------------------------===//
+// Dense array attribute.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// IsA support; each predicate matches exactly one element type.
+
+bool mlirAttributeIsADenseBoolArray(MlirAttribute attr) {
+  return unwrap(attr).isa<DenseBoolArrayAttr>();
+}
+bool mlirAttributeIsADenseI8Array(MlirAttribute attr) {
+  return unwrap(attr).isa<DenseI8ArrayAttr>();
+}
+bool mlirAttributeIsADenseI16Array(MlirAttribute attr) {
+  return unwrap(attr).isa<DenseI16ArrayAttr>();
+}
+bool mlirAttributeIsADenseI32Array(MlirAttribute attr) {
+  return unwrap(attr).isa<DenseI32ArrayAttr>();
+}
+bool mlirAttributeIsADenseI64Array(MlirAttribute attr) {
+  return unwrap(attr).isa<DenseI64ArrayAttr>();
+}
+bool mlirAttributeIsADenseF32Array(MlirAttribute attr) {
+  return unwrap(attr).isa<DenseF32ArrayAttr>();
+}
+bool mlirAttributeIsADenseF64Array(MlirAttribute attr) {
+  return unwrap(attr).isa<DenseF64ArrayAttr>();
+}
+
+//===----------------------------------------------------------------------===//
+// Constructors.
+
+MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx, intptr_t size,
+                                    int const *values) {
+  return wrap(DenseBoolArrayAttr::get(
+      unwrap(ctx), SmallVector<bool, 4>(values, values + size)));
+}
+MlirAttribute mlirDenseI8ArrayGet(MlirContext ctx, intptr_t size,
+                                  int8_t const *values) {
+  ArrayRef<int8_t> elements(values, size);
+  return wrap(DenseI8ArrayAttr::get(unwrap(ctx), elements));
+}
+MlirAttribute mlirDenseI16ArrayGet(MlirContext ctx, intptr_t size,
+                                   int16_t const *values) {
+  ArrayRef<int16_t> elements(values, size);
+  return wrap(DenseI16ArrayAttr::get(unwrap(ctx), elements));
+}
+MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size,
+                                   int32_t const *values) {
+  ArrayRef<int32_t> elements(values, size);
+  return wrap(DenseI32ArrayAttr::get(unwrap(ctx), elements));
+}
+MlirAttribute mlirDenseI64ArrayGet(MlirContext ctx, intptr_t size,
+                                   int64_t const *values) {
+  ArrayRef<int64_t> elements(values, size);
+  return wrap(DenseI64ArrayAttr::get(unwrap(ctx), elements));
+}
+MlirAttribute mlirDenseF32ArrayGet(MlirContext ctx, intptr_t size,
+                                   float const *values) {
+  ArrayRef<float> elements(values, size);
+  return wrap(DenseF32ArrayAttr::get(unwrap(ctx), elements));
+}
+MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, intptr_t size,
+                                   double const *values) {
+  ArrayRef<double> elements(values, size);
+  return wrap(DenseF64ArrayAttr::get(unwrap(ctx), elements));
+}
+
+//===----------------------------------------------------------------------===//
+// Accessors (shared by every dense array element type).
+
+intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) {
+  return unwrap(attr).cast<DenseArrayBaseAttr>().size();
+}
+
+//===----------------------------------------------------------------------===//
+// Indexed accessors; `pos` must be in [0, size) and the type must match.
+
+bool mlirDenseBoolArrayGetElement(MlirAttribute attr, intptr_t pos) {
+  return unwrap(attr).cast<DenseBoolArrayAttr>()[pos];
+}
+int8_t mlirDenseI8ArrayGetElement(MlirAttribute attr, intptr_t pos) {
+  return unwrap(attr).cast<DenseI8ArrayAttr>()[pos];
+}
+int16_t mlirDenseI16ArrayGetElement(MlirAttribute attr, intptr_t pos) {
+  return unwrap(attr).cast<DenseI16ArrayAttr>()[pos];
+}
+int32_t mlirDenseI32ArrayGetElement(MlirAttribute attr, intptr_t pos) {
+  return unwrap(attr).cast<DenseI32ArrayAttr>()[pos];
+}
+int64_t mlirDenseI64ArrayGetElement(MlirAttribute attr, intptr_t pos) {
+  return unwrap(attr).cast<DenseI64ArrayAttr>()[pos];
+}
+float mlirDenseF32ArrayGetElement(MlirAttribute attr, intptr_t pos) {
+  return unwrap(attr).cast<DenseF32ArrayAttr>()[pos];
+}
+double mlirDenseF64ArrayGetElement(MlirAttribute attr, intptr_t pos) {
+  return unwrap(attr).cast<DenseF64ArrayAttr>()[pos];
+}
+
 //===----------------------------------------------------------------------===//
 // Dense elements attribute.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 79921dec640b..6f1764b4f987 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1186,6 +1186,40 @@ int printBuiltinAttributes(MlirContext ctx) {
   mlirAttributeDump(sparseAttr);
   // CHECK: sparse<{{\[}}[0, 1]], 0.000000e+00> : tensor<1x2xf32>
 
+  MlirAttribute boolArray = mlirDenseBoolArrayGet(ctx, 2, bools);
+  MlirAttribute int8Array = mlirDenseI8ArrayGet(ctx, 2, ints8);
+  MlirAttribute int16Array = mlirDenseI16ArrayGet(ctx, 2, ints16);
+  MlirAttribute int32Array = mlirDenseI32ArrayGet(ctx, 2, ints32);
+  MlirAttribute int64Array = mlirDenseI64ArrayGet(ctx, 2, ints64);
+  MlirAttribute floatArray = mlirDenseF32ArrayGet(ctx, 2, floats);
+  MlirAttribute doubleArray = mlirDenseF64ArrayGet(ctx, 2, doubles);
+  if (!mlirAttributeIsADenseBoolArray(boolArray) ||
+      !mlirAttributeIsADenseI8Array(int8Array) ||
+      !mlirAttributeIsADenseI16Array(int16Array) ||
+      !mlirAttributeIsADenseI32Array(int32Array) ||
+      !mlirAttributeIsADenseI64Array(int64Array) ||
+      !mlirAttributeIsADenseF32Array(floatArray) ||
+      !mlirAttributeIsADenseF64Array(doubleArray))
+    return 19;
+
+  if (mlirDenseArrayGetNumElements(boolArray) != 2 ||
+      mlirDenseArrayGetNumElements(int8Array) != 2 ||
+      mlirDenseArrayGetNumElements(int16Array) != 2 ||
+      mlirDenseArrayGetNumElements(int32Array) != 2 ||
+      mlirDenseArrayGetNumElements(int64Array) != 2 ||
+      mlirDenseArrayGetNumElements(floatArray) != 2 ||
+      mlirDenseArrayGetNumElements(doubleArray) != 2)
+    return 20;
+
+  if (mlirDenseBoolArrayGetElement(boolArray, 1) != 1 ||
+      mlirDenseI8ArrayGetElement(int8Array, 1) != 1 ||
+      mlirDenseI16ArrayGetElement(int16Array, 1) != 1 ||
+      mlirDenseI32ArrayGetElement(int32Array, 1) != 1 ||
+      mlirDenseI64ArrayGetElement(int64Array, 1) != 1 ||
+      fabsf(mlirDenseF32ArrayGetElement(floatArray, 1) - 1.0f) > 1E-6f ||
+      fabs(mlirDenseF64ArrayGetElement(doubleArray, 1) - 1.0) > 1E-6)
+    return 21;
+
   return 0;
 }
 

diff  --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 97ebdd323fe4..a958abfc9e75 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -1,8 +1,10 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 import gc
+
 from mlir.ir import *
 
+
 def run(f):
   print("\nTEST:", f.__name__)
   f()
@@ -319,6 +321,31 @@ def testDenseIntAttr():
     print(ShapedType(a.type).element_type)
 
 
+# CHECK-LABEL: TEST: testDenseArrayGetItem
+@run
+def testDenseArrayGetItem():
+  """Checks len() and indexing on every DenseArrayAttr element type."""
+  def print_item(AttrClass, attr_asm):
+    attr = AttrClass(Attribute.parse(attr_asm))
+    print(f"{len(attr)}: {attr[0]}, {attr[1]}")
+
+  with Context():
+    # CHECK: 2: 0, 1
+    print_item(DenseBoolArrayAttr, "array<i1: false, true>")
+    # CHECK: 2: 2, 3
+    print_item(DenseI8ArrayAttr, "array<i8: 2, 3>")
+    # CHECK: 2: 4, 5
+    print_item(DenseI16ArrayAttr, "array<i16: 4, 5>")
+    # CHECK: 2: 6, 7
+    print_item(DenseI32ArrayAttr, "array<i32: 6, 7>")
+    # CHECK: 2: 8, 9
+    print_item(DenseI64ArrayAttr, "array<i64: 8, 9>")
+    # CHECK: 2: 1.{{0+}}, 2.{{0+}}
+    print_item(DenseF32ArrayAttr, "array<f32: 1.0, 2.0>")
+    # CHECK: 2: 3.{{0+}}, 4.{{0+}}
+    print_item(DenseF64ArrayAttr, "array<f64: 3.0, 4.0>")
+
+
 # CHECK-LABEL: TEST: testDenseIntAttrGetItem
 @run
 def testDenseIntAttrGetItem():


        


More information about the Mlir-commits mailing list