[Mlir-commits] [mlir] 4669ea3 - [mlir] Add initial Python bindings for DenseInt/FPElementsAttr
Alex Zinenko
llvmlistbot at llvm.org
Mon Nov 9 06:24:03 PST 2020
Author: Alex Zinenko
Date: 2020-11-09T15:23:54+01:00
New Revision: 4669ea3bd8cfce9fca22f3b18abc447c1d42f82a
URL: https://github.com/llvm/llvm-project/commit/4669ea3bd8cfce9fca22f3b18abc447c1d42f82a
DIFF: https://github.com/llvm/llvm-project/commit/4669ea3bd8cfce9fca22f3b18abc447c1d42f82a.diff
LOG: [mlir] Add initial Python bindings for DenseInt/FPElementsAttr
Enumerating elements in these classes is necessary to enable custom
operand accessors for variadic operands.
Depends On D90919
Reviewed By: stellaraccident
Differential Revision: https://reviews.llvm.org/D90923
Added:
Modified:
mlir/lib/Bindings/Python/IRModules.cpp
mlir/test/Bindings/Python/ir_attributes.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 8f71181b385d..2e02d775de3d 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1621,11 +1621,14 @@ class PyDenseElementsAttribute
return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
}
+  /// Backs the Python `__len__` binding: the element count of this attribute,
+  /// as returned by mlirElementsAttrGetNumElements.
+  intptr_t dunderLen() {
+    return mlirElementsAttrGetNumElements(attr);
+  }
+
static void bindDerived(ClassTy &c) {
- c.def_static("get", PyDenseElementsAttribute::getFromBuffer,
- py::arg("array"), py::arg("signless") = true,
- py::arg("context") = py::none(),
- "Gets from a buffer or ndarray")
+ c.def("__len__", &PyDenseElementsAttribute::dunderLen)
+ .def_static("get", PyDenseElementsAttribute::getFromBuffer,
+ py::arg("array"), py::arg("signless") = true,
+ py::arg("context") = py::none(),
+ "Gets from a buffer or ndarray")
.def_static("get_splat", PyDenseElementsAttribute::getSplat,
py::arg("shaped_type"), py::arg("element_attr"),
"Gets a DenseElementsAttr where all values are the same")
@@ -1651,6 +1654,101 @@ class PyDenseElementsAttribute
}
};
+/// Refinement of the PyDenseElementsAttribute for attributes containing integer
+/// (and boolean) values. Supports element access.
+class PyDenseIntElementsAttribute
+    : public PyConcreteAttribute<PyDenseIntElementsAttribute,
+                                 PyDenseElementsAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
+  static constexpr const char *pyClassName = "DenseIntElementsAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  /// Returns the element at the given linear position as a Python int.
+  /// Raises a Python IndexError if `pos` is outside [0, len(self)) (negative
+  /// indices are not supported). Raising IndexError is also what ends Python's
+  /// fallback iteration protocol, so `for x in attr` works with only
+  /// `__getitem__` bound. Raises a Python TypeError if the element bitwidth is
+  /// not 1, 32 or 64.
+  py::int_ dunderGetItem(intptr_t pos) {
+    if (pos < 0 || pos >= dunderLen()) {
+      throw SetPyError(PyExc_IndexError,
+                       "attempt to access out of bounds element");
+    }
+
+    MlirType type = mlirAttributeGetType(attr);
+    type = mlirShapedTypeGetElementType(type);
+    assert(mlirTypeIsAInteger(type) &&
+           "expected integer element type in dense int elements attribute");
+    // Dispatch element extraction to an appropriate C function based on the
+    // elemental type of the attribute. py::int_ is implicitly constructible
+    // from any C++ integral type and handles bitwidth correctly.
+    // TODO: consider caching the type properties in the constructor to avoid
+    // querying them on each element access.
+    unsigned width = mlirIntegerTypeGetWidth(type);
+
+    // 1-bit elements are read as booleans whatever their signedness.
+    if (width == 1) {
+      return mlirDenseElementsAttrGetBoolValue(attr, pos);
+    }
+
+    // Anything not marked unsigned (i.e. signed or signless) is read through
+    // the signed accessors. Separate `if`s are used on purpose: a ternary
+    // between the signed and unsigned accessors would convert both operands
+    // to the unsigned common type and corrupt negative values.
+    bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
+    if (width == 32) {
+      if (isUnsigned)
+        return mlirDenseElementsAttrGetUInt32Value(attr, pos);
+      return mlirDenseElementsAttrGetInt32Value(attr, pos);
+    }
+    if (width == 64) {
+      if (isUnsigned)
+        return mlirDenseElementsAttrGetUInt64Value(attr, pos);
+      return mlirDenseElementsAttrGetInt64Value(attr, pos);
+    }
+    throw SetPyError(PyExc_TypeError, "Unsupported integer type");
+  }
+
+  /// Adds element access (`__getitem__`); `__len__` comes from the
+  /// DenseElementsAttr base binding.
+  static void bindDerived(ClassTy &c) {
+    c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
+  }
+};
+
+/// Refinement of PyDenseElementsAttribute for attributes whose elements are
+/// floating-point values; exposes indexed element access to Python.
+class PyDenseFPElementsAttribute
+    : public PyConcreteAttribute<PyDenseFPElementsAttribute,
+                                 PyDenseElementsAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
+  static constexpr const char *pyClassName = "DenseFPElementsAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  /// Backs Python `__getitem__`: returns the element at linear position `pos`
+  /// as a Python float. Raises a Python IndexError when `pos` is outside
+  /// [0, len(self)) and a Python TypeError when the element type is neither
+  /// f32 nor f64.
+  py::float_ dunderGetItem(intptr_t pos) {
+    bool inBounds = pos >= 0 && pos < dunderLen();
+    if (!inBounds)
+      throw SetPyError(PyExc_IndexError,
+                       "attempt to access out of bounds element");
+
+    // The element type is looked up again on every access.
+    // TODO: consider caching the type properties in the constructor to avoid
+    // querying them on each element access.
+    MlirType elementType =
+        mlirShapedTypeGetElementType(mlirAttributeGetType(attr));
+
+    // Pick the C accessor matching the element type; py::float_ converts
+    // implicitly from both float and double.
+    if (mlirTypeIsAF64(elementType))
+      return mlirDenseElementsAttrGetDoubleValue(attr, pos);
+    if (mlirTypeIsAF32(elementType))
+      return mlirDenseElementsAttrGetFloatValue(attr, pos);
+
+    throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
+  }
+
+  /// Registers `__getitem__` on the Python class.
+  static void bindDerived(ClassTy &c) {
+    c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
+  }
+};
+
} // namespace
//------------------------------------------------------------------------------
@@ -2754,6 +2852,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
PyBoolAttribute::bind(m);
PyStringAttribute::bind(m);
PyDenseElementsAttribute::bind(m);
+ PyDenseIntElementsAttribute::bind(m);
+ PyDenseFPElementsAttribute::bind(m);
//----------------------------------------------------------------------------
// Mapping of PyType.
diff --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py
index 39d69483d1b7..11ad735f054b 100644
--- a/mlir/test/Bindings/Python/ir_attributes.py
+++ b/mlir/test/Bindings/Python/ir_attributes.py
@@ -181,3 +181,63 @@ def testNamedAttr():
print("named:", named)
run(testNamedAttr)
+
+
+# CHECK-LABEL: TEST: testDenseIntAttr
+def testDenseIntAttr():
+ # Exercises DenseIntElementsAttr: len(), element iteration, and the element
+ # type, for a signless i32 vector and for an i1 (boolean) vector.
+ with Context():
+ raw = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>")
+ # CHECK: attr: dense<[{{\[}}0, 1, 2], [3, 4, 5]]>
+ print("attr:", raw)
+
+ a = DenseIntElementsAttr(raw)
+ # len() counts all elements of the 2x3 shape, not just the outer dimension.
+ assert len(a) == 6
+
+ # CHECK: 0 1 2 3 4 5
+ # Iteration uses Python's __getitem__ fallback protocol: the binding only
+ # defines __getitem__, which raises IndexError past the last element.
+ # Elements come back in linear (row-major flattened) order.
+ for value in a:
+ print(value, end=" ")
+ print()
+
+ # CHECK: i32
+ print(ShapedType(a.type).element_type)
+
+ raw = Attribute.parse("dense<[true,false,true,false]> : vector<4xi1>")
+ # CHECK: attr: dense<[true, false, true, false]>
+ print("attr:", raw)
+
+ a = DenseIntElementsAttr(raw)
+ assert len(a) == 4
+
+ # CHECK: 1 0 1 0
+ # i1 elements are returned as Python ints (1/0), not True/False.
+ for value in a:
+ print(value, end=" ")
+ print()
+
+ # CHECK: i1
+ print(ShapedType(a.type).element_type)
+
+
+run(testDenseIntAttr)
+
+
+# CHECK-LABEL: TEST: testDenseFPAttr
+def testDenseFPAttr():
+ # Exercises DenseFPElementsAttr: len(), element iteration and the element
+ # type, for an f32 vector.
+ with Context():
+ raw = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>")
+ # CHECK: attr: dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
+
+ print("attr:", raw)
+
+ a = DenseFPElementsAttr(raw)
+ assert len(a) == 4
+
+ # CHECK: 0.0 1.0 2.0 3.0
+ # Iteration relies on __getitem__ raising IndexError at the end; f32
+ # elements are converted to Python floats, hence the "0.0" formatting.
+ for value in a:
+ print(value, end=" ")
+ print()
+
+ # CHECK: f32
+ print(ShapedType(a.type).element_type)
+
+
+run(testDenseFPAttr)
More information about the Mlir-commits
mailing list