[Mlir-commits] [mlir] 8dcb672 - [mlir][python] Make DenseBoolArrayAttr.get work with list of bools.
Ingo Müller
llvmlistbot at llvm.org
Mon Aug 28 08:15:13 PDT 2023
Author: Ingo Müller
Date: 2023-08-28T15:15:08Z
New Revision: 8dcb67225b2ce871b54f7a0f172b58f15f05f7fa
URL: https://github.com/llvm/llvm-project/commit/8dcb67225b2ce871b54f7a0f172b58f15f05f7fa
DIFF: https://github.com/llvm/llvm-project/commit/8dcb67225b2ce871b54f7a0f172b58f15f05f7fa.diff
LOG: [mlir][python] Make DenseBoolArrayAttr.get work with list of bools.
This patch makes the getter function of `DenseBoolArrayAttr` work more
intuitively. Until now, it was implemented with a `std::vector<int>`
argument, which works in the typical situation where you call the pybind
function with a list of Python bools (like `[True, False]`). However, it
does *not* work if the elements of the list first have to be converted to
`bool` (e.g., objects that only implement `__bool__`), even though such an
implicit conversion is the default behavior for lists of all other element
types. The patch thus changes the signature to `std::vector<bool>`, which
lets pybind convert each element to a bool as expected. Because
`std::vector<bool>` is bit-packed and has no `data()` member, the values
are copied into a `std::vector<int>` before being passed to the C API. The
tests now also contain a case where such a conversion happens. This change
also makes the conversion of `DenseBoolArrayAttr` back to Python more
intuitive: instead of converting to `0` and `1`, the elements are now
converted to `False` and `True`.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D158973
Added:
Modified:
mlir/lib/Bindings/Python/IRAttributes.cpp
mlir/test/python/ir/attributes.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 75d743f3a3962a..50cfc0624fccfc 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -162,9 +162,7 @@ class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
c.def_static(
"get",
[](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
- MlirAttribute attr =
- DerivedT::getAttribute(ctx->get(), values.size(), values.data());
- return DerivedT(ctx->getRef(), attr);
+ return getAttribute(values, ctx->getRef());
},
py::arg("values"), py::arg("context") = py::none(),
"Gets a uniqued dense array attribute");
@@ -187,16 +185,29 @@ class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
values.push_back(arr.getItem(i));
for (py::handle attr : extras)
values.push_back(pyTryCast<EltTy>(attr));
- MlirAttribute attr = DerivedT::getAttribute(arr.getContext()->get(),
- values.size(), values.data());
- return DerivedT(arr.getContext(), attr);
+ return getAttribute(values, arr.getContext());
});
}
+
+private:
+ static DerivedT getAttribute(const std::vector<EltTy> &values,
+ PyMlirContextRef ctx) {
+ if constexpr (std::is_same_v<EltTy, bool>) {
+ std::vector<int> intValues(values.begin(), values.end());
+ MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(),
+ intValues.data());
+ return DerivedT(ctx, attr);
+ } else {
+ MlirAttribute attr =
+ DerivedT::getAttribute(ctx->get(), values.size(), values.data());
+ return DerivedT(ctx, attr);
+ }
+ }
};
/// Instantiate the python dense array classes.
struct PyDenseBoolArrayAttribute
- : public PyDenseArrayAttribute<int, PyDenseBoolArrayAttribute> {
+ : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> {
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
static constexpr auto getAttribute = mlirDenseBoolArrayGet;
static constexpr auto getElement = mlirDenseBoolArrayGetElement;
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index d986cac17dd765..1a2ed7d6642b88 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -344,7 +344,7 @@ def print_item(attr_asm):
print(f"{len(attr)}: {attr[0]}, {attr[1]}")
with Context():
- # CHECK: 2: 0, 1
+ # CHECK: 2: False, True
print_item("array<i1: false, true>")
# CHECK: 2: 2, 3
print_item("array<i8: 2, 3>")
@@ -359,6 +359,13 @@ def print_item(attr_asm):
# CHECK: 2: 3.{{0+}}, 4.{{0+}}
print_item("array<f64: 3.0, 4.0>")
+ class MyBool:
+ def __bool__(self):
+ return True
+
+ # CHECK: myboolarray: array<i1: true>
+ print("myboolarray:", DenseBoolArrayAttr.get([MyBool()]))
+
# CHECK-LABEL: TEST: testDenseIntAttrGetItem
@run
More information about the Mlir-commits
mailing list