[Mlir-commits] [mlir] 4669ea3 - [mlir] Add initial Python bindings for DenseInt/FPElementsAttr

Alex Zinenko llvmlistbot at llvm.org
Mon Nov 9 06:24:03 PST 2020


Author: Alex Zinenko
Date: 2020-11-09T15:23:54+01:00
New Revision: 4669ea3bd8cfce9fca22f3b18abc447c1d42f82a

URL: https://github.com/llvm/llvm-project/commit/4669ea3bd8cfce9fca22f3b18abc447c1d42f82a
DIFF: https://github.com/llvm/llvm-project/commit/4669ea3bd8cfce9fca22f3b18abc447c1d42f82a.diff

LOG: [mlir] Add initial Python bindings for DenseInt/FPElementsAttr

Enumerating elements in these classes is necessary to enable custom
operand accessors for variadic operands.

Depends On D90919

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D90923

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/test/Bindings/Python/ir_attributes.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 8f71181b385d..2e02d775de3d 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1621,11 +1621,14 @@ class PyDenseElementsAttribute
     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
   }
 
+  intptr_t dunderLen() { return mlirElementsAttrGetNumElements(attr); }
+
   static void bindDerived(ClassTy &c) {
-    c.def_static("get", PyDenseElementsAttribute::getFromBuffer,
-                 py::arg("array"), py::arg("signless") = true,
-                 py::arg("context") = py::none(),
-                 "Gets from a buffer or ndarray")
+    c.def("__len__", &PyDenseElementsAttribute::dunderLen)
+        .def_static("get", PyDenseElementsAttribute::getFromBuffer,
+                    py::arg("array"), py::arg("signless") = true,
+                    py::arg("context") = py::none(),
+                    "Gets from a buffer or ndarray")
         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
                     py::arg("shaped_type"), py::arg("element_attr"),
                     "Gets a DenseElementsAttr where all values are the same")
@@ -1651,6 +1654,101 @@ class PyDenseElementsAttribute
   }
 };
 
+/// Refinement of PyDenseElementsAttribute for attributes with integer (incl.
+/// i1/boolean) elements; exposes element access to Python as __getitem__.
+class PyDenseIntElementsAttribute
+    : public PyConcreteAttribute<PyDenseIntElementsAttribute,
+                                 PyDenseElementsAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
+  static constexpr const char *pyClassName = "DenseIntElementsAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  /// Returns the element at the given linear position. Raises IndexError if
+  /// the index is out of range (negative indices are rejected, not wrapped).
+  py::int_ dunderGetItem(intptr_t pos) {
+    if (pos < 0 || pos >= dunderLen()) {
+      throw SetPyError(PyExc_IndexError,
+                       "attempt to access out of bounds element");
+    }
+
+    MlirType type = mlirAttributeGetType(attr);
+    type = mlirShapedTypeGetElementType(type);
+    assert(mlirTypeIsAInteger(type) &&
+           "expected integer element type in dense int elements attribute");
+    // Dispatch to the C accessor matching the element type. Only widths 1, 32
+    // and 64 are handled; any other width (e.g. i8, i16) raises TypeError.
+    // py::int_ is implicitly constructible from any C++ integral type.
+    // TODO: consider caching the type properties in the constructor to avoid
+    // querying them on each element access.
+    unsigned width = mlirIntegerTypeGetWidth(type);
+    bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
+    if (isUnsigned) {
+      if (width == 1) {
+        return mlirDenseElementsAttrGetBoolValue(attr, pos);
+      }
+      if (width == 32) {
+        return mlirDenseElementsAttrGetUInt32Value(attr, pos);
+      }
+      if (width == 64) {
+        return mlirDenseElementsAttrGetUInt64Value(attr, pos);
+      }
+    } else { // Signless (e.g. i32) and signed types are read as signed.
+      if (width == 1) {
+        return mlirDenseElementsAttrGetBoolValue(attr, pos);
+      }
+      if (width == 32) {
+        return mlirDenseElementsAttrGetInt32Value(attr, pos);
+      }
+      if (width == 64) {
+        return mlirDenseElementsAttrGetInt64Value(attr, pos);
+      }
+    }
+    throw SetPyError(PyExc_TypeError, "Unsupported integer type");
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
+  }
+};
+
+/// Refinement of PyDenseElementsAttribute for attributes with floating-point
+/// elements; exposes element access to Python as __getitem__ (f32/f64 only).
+class PyDenseFPElementsAttribute
+    : public PyConcreteAttribute<PyDenseFPElementsAttribute,
+                                 PyDenseElementsAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
+  static constexpr const char *pyClassName = "DenseFPElementsAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+  /// Returns the element at linear position pos; IndexError if out of range.
+  py::float_ dunderGetItem(intptr_t pos) {
+    if (pos < 0 || pos >= dunderLen()) {
+      throw SetPyError(PyExc_IndexError,
+                       "attempt to access out of bounds element");
+    }
+
+    MlirType type = mlirAttributeGetType(attr);
+    type = mlirShapedTypeGetElementType(type);
+    // Dispatch to the C accessor matching the element type. Only f32 and f64
+    // are handled; other FP types (e.g. f16, bf16) raise TypeError below.
+    // py::float_ is implicitly constructible from both float and double.
+    // TODO: consider caching the type properties in the constructor to avoid
+    // querying them on each element access.
+    if (mlirTypeIsAF32(type)) {
+      return mlirDenseElementsAttrGetFloatValue(attr, pos);
+    }
+    if (mlirTypeIsAF64(type)) {
+      return mlirDenseElementsAttrGetDoubleValue(attr, pos);
+    }
+    throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
+  }
+};
+
 } // namespace
 
 //------------------------------------------------------------------------------
@@ -2754,6 +2852,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   PyBoolAttribute::bind(m);
   PyStringAttribute::bind(m);
   PyDenseElementsAttribute::bind(m);
+  PyDenseIntElementsAttribute::bind(m);
+  PyDenseFPElementsAttribute::bind(m);
 
   //----------------------------------------------------------------------------
   // Mapping of PyType.

diff  --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py
index 39d69483d1b7..11ad735f054b 100644
--- a/mlir/test/Bindings/Python/ir_attributes.py
+++ b/mlir/test/Bindings/Python/ir_attributes.py
@@ -181,3 +181,63 @@ def testNamedAttr():
     print("named:", named)
 
 run(testNamedAttr)
+
+
+# CHECK-LABEL: TEST: testDenseIntAttr
+def testDenseIntAttr():
+  with Context():
+    raw = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>")
+    # CHECK: attr: dense<[{{\[}}0, 1, 2], [3, 4, 5]]>
+    print("attr:", raw)
+
+    a = DenseIntElementsAttr(raw)  # Wrap as the dense-int subclass.
+    assert len(a) == 6  # __len__ counts all elements of the 2x3 vector.
+
+    # CHECK: 0 1 2 3 4 5
+    for value in a:  # Iterates via __getitem__ until it raises IndexError.
+      print(value, end=" ")
+    print()
+
+    # CHECK: i32
+    print(ShapedType(a.type).element_type)
+
+    raw = Attribute.parse("dense<[true,false,true,false]> : vector<4xi1>")
+    # CHECK: attr: dense<[true, false, true, false]>
+    print("attr:", raw)
+
+    a = DenseIntElementsAttr(raw)  # i1 elements use the same class.
+    assert len(a) == 4
+
+    # CHECK: 1 0 1 0
+    for value in a:  # i1 elements come back as Python ints 1/0.
+      print(value, end=" ")
+    print()
+
+    # CHECK: i1
+    print(ShapedType(a.type).element_type)
+
+
+run(testDenseIntAttr)
+
+
+# CHECK-LABEL: TEST: testDenseFPAttr
+def testDenseFPAttr():
+  with Context():
+    raw = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>")
+    # CHECK: attr: dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
+    # The attribute prints in exponent notation; elements print as 0.0 etc.
+    print("attr:", raw)
+
+    a = DenseFPElementsAttr(raw)
+    assert len(a) == 4
+
+    # CHECK: 0.0 1.0 2.0 3.0
+    for value in a:  # f32 elements come back as Python floats.
+      print(value, end=" ")
+    print()
+
+    # CHECK: f32
+    print(ShapedType(a.type).element_type)
+
+
+run(testDenseFPAttr)


        


More information about the Mlir-commits mailing list