[Mlir-commits] [mlir] 8dcb672 - [mlir][python] Make DenseBoolArrayAttr.get work with list of bools.

Ingo Müller llvmlistbot at llvm.org
Mon Aug 28 08:15:13 PDT 2023


Author: Ingo Müller
Date: 2023-08-28T15:15:08Z
New Revision: 8dcb67225b2ce871b54f7a0f172b58f15f05f7fa

URL: https://github.com/llvm/llvm-project/commit/8dcb67225b2ce871b54f7a0f172b58f15f05f7fa
DIFF: https://github.com/llvm/llvm-project/commit/8dcb67225b2ce871b54f7a0f172b58f15f05f7fa.diff

LOG: [mlir][python] Make DenseBoolArrayAttr.get work with list of bools.

This patch makes the getter function of `DenseBoolArrayAttr` work more
intuitively. Until now, it was implemented with a `std::vector<int>`
argument, which works in the typical situation where you call the pybind
function with a list of Python bools (like `[True, False]`). However, it
does *not* work if the elements of the list first have to be converted to
`bool` (such implicit conversion is the default behavior for lists of all
other element types).
The patch thus changes the signature to `std::vector<bool>`, which lets
pybind apply the expected `bool` conversion to each element. The tests now
also include a case where such a conversion happens. This also makes the
conversion of `DenseBoolArrayAttr` back to Python more intuitive:
instead of converting to `0` and `1`, the elements are now converted to
`False` and `True`.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D158973

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRAttributes.cpp
    mlir/test/python/ir/attributes.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 75d743f3a3962a..50cfc0624fccfc 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -162,9 +162,7 @@ class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
     c.def_static(
         "get",
         [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
-          MlirAttribute attr =
-              DerivedT::getAttribute(ctx->get(), values.size(), values.data());
-          return DerivedT(ctx->getRef(), attr);
+          return getAttribute(values, ctx->getRef());
         },
         py::arg("values"), py::arg("context") = py::none(),
         "Gets a uniqued dense array attribute");
@@ -187,16 +185,29 @@ class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
         values.push_back(arr.getItem(i));
       for (py::handle attr : extras)
         values.push_back(pyTryCast<EltTy>(attr));
-      MlirAttribute attr = DerivedT::getAttribute(arr.getContext()->get(),
-                                                  values.size(), values.data());
-      return DerivedT(arr.getContext(), attr);
+      return getAttribute(values, arr.getContext());
     });
   }
+
+private:
+  static DerivedT getAttribute(const std::vector<EltTy> &values,
+                               PyMlirContextRef ctx) {
+    if constexpr (std::is_same_v<EltTy, bool>) {
+      std::vector<int> intValues(values.begin(), values.end());
+      MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(),
+                                                  intValues.data());
+      return DerivedT(ctx, attr);
+    } else {
+      MlirAttribute attr =
+          DerivedT::getAttribute(ctx->get(), values.size(), values.data());
+      return DerivedT(ctx, attr);
+    }
+  }
 };
 
 /// Instantiate the python dense array classes.
 struct PyDenseBoolArrayAttribute
-    : public PyDenseArrayAttribute<int, PyDenseBoolArrayAttribute> {
+    : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> {
   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
   static constexpr auto getAttribute = mlirDenseBoolArrayGet;
   static constexpr auto getElement = mlirDenseBoolArrayGetElement;

diff  --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index d986cac17dd765..1a2ed7d6642b88 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -344,7 +344,7 @@ def print_item(attr_asm):
         print(f"{len(attr)}: {attr[0]}, {attr[1]}")
 
     with Context():
-        # CHECK: 2: 0, 1
+        # CHECK: 2: False, True
         print_item("array<i1: false, true>")
         # CHECK: 2: 2, 3
         print_item("array<i8: 2, 3>")
@@ -359,6 +359,13 @@ def print_item(attr_asm):
         # CHECK: 2: 3.{{0+}}, 4.{{0+}}
         print_item("array<f64: 3.0, 4.0>")
 
+        class MyBool:
+            def __bool__(self):
+                return True
+
+        # CHECK: myboolarray: array<i1: true>
+        print("myboolarray:", DenseBoolArrayAttr.get([MyBool()]))
+
 
 # CHECK-LABEL: TEST: testDenseIntAttrGetItem
 @run


        


More information about the Mlir-commits mailing list