[Mlir-commits] [mlir] 2a5d497 - [mlir][python] Add __{bool, float, int, str}__ to bindings of attributes.

Ingo Müller llvmlistbot at llvm.org
Tue Aug 29 07:53:32 PDT 2023


Author: Ingo Müller
Date: 2023-08-29T14:53:26Z
New Revision: 2a5d497494c24425e99655b85e2277dd3f15a400

URL: https://github.com/llvm/llvm-project/commit/2a5d497494c24425e99655b85e2277dd3f15a400
DIFF: https://github.com/llvm/llvm-project/commit/2a5d497494c24425e99655b85e2277dd3f15a400.diff

LOG: [mlir][python] Add __{bool,float,int,str}__ to bindings of attributes.

This makes it possible to use Python's `bool(.)`, `float(.)`, `int(.)`,
and `str(.)` to convert pybound attributes to the corresponding native
Python types. In particular, pybind11 uses these functions to
automatically cast objects to the corresponding primitive types wherever
they are required by pybound functions, e.g., arguments are converted to
Python's `int` if the C++ signature requires a C++ `int`. With this
patch, pybound attributes can be used wherever the corresponding native
types are expected. New tests showcase this behavior in the
constructors of `Dense*ArrayAttr`.

Note that this changes the output of Python's `str` on `StringAttr` from
`"hello"` to `hello`. Arguably, this is still in line with `str`'s goal
of producing a readable interpretation of the value, even if the output
is no longer unambiguously a string (`print(ir.Attribute.parse('"42"'))`
now outputs `42`). However, this is consistent with instances of
Python's `str` (`print("42")` outputs `42`), and `repr` still provides
an unambiguous representation if one is required.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D158974

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRAttributes.cpp
    mlir/test/python/ir/attributes.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 50cfc0624fccfc..6531d6276ae67f 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -389,12 +389,10 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
         },
         py::arg("value"), py::arg("context") = py::none(),
         "Gets an uniqued float point attribute associated to a f64 type");
-    c.def_property_readonly(
-        "value",
-        [](PyFloatAttribute &self) {
-          return mlirFloatAttrGetValueDouble(self);
-        },
-        "Returns the value of the float point attribute");
+    c.def_property_readonly("value", mlirFloatAttrGetValueDouble,
+                            "Returns the value of the float attribute");
+    c.def("__float__", mlirFloatAttrGetValueDouble,
+          "Converts the value of the float attribute to a Python float");
   }
 };
 
@@ -414,22 +412,25 @@ class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
         },
         py::arg("type"), py::arg("value"),
         "Gets an uniqued integer attribute associated to a type");
-    c.def_property_readonly(
-        "value",
-        [](PyIntegerAttribute &self) -> py::int_ {
-          MlirType type = mlirAttributeGetType(self);
-          if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
-            return mlirIntegerAttrGetValueInt(self);
-          if (mlirIntegerTypeIsSigned(type))
-            return mlirIntegerAttrGetValueSInt(self);
-          return mlirIntegerAttrGetValueUInt(self);
-        },
-        "Returns the value of the integer attribute");
+    c.def_property_readonly("value", toPyInt,
+                            "Returns the value of the integer attribute");
+    c.def("__int__", toPyInt,
+          "Converts the value of the integer attribute to a Python int");
     c.def_property_readonly_static("static_typeid",
                                    [](py::object & /*class*/) -> MlirTypeID {
                                      return mlirIntegerAttrGetTypeID();
                                    });
   }
+
+private:
+  static py::int_ toPyInt(PyIntegerAttribute &self) {
+    MlirType type = mlirAttributeGetType(self);
+    if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
+      return mlirIntegerAttrGetValueInt(self);
+    if (mlirIntegerTypeIsSigned(type))
+      return mlirIntegerAttrGetValueSInt(self);
+    return mlirIntegerAttrGetValueUInt(self);
+  }
 };
 
 /// Bool Attribute subclass - BoolAttr.
@@ -448,10 +449,10 @@ class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
         },
         py::arg("value"), py::arg("context") = py::none(),
         "Gets an uniqued bool attribute");
-    c.def_property_readonly(
-        "value",
-        [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
-        "Returns the value of the bool attribute");
+    c.def_property_readonly("value", mlirBoolAttrGetValue,
+                            "Returns the value of the bool attribute");
+    c.def("__bool__", mlirBoolAttrGetValue,
+          "Converts the value of the bool attribute to a Python bool");
   }
 };
 
@@ -595,13 +596,8 @@ class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
         },
         py::arg("type"), py::arg("value"),
         "Gets a uniqued string attribute associated to a type");
-    c.def_property_readonly(
-        "value",
-        [](PyStringAttribute &self) {
-          MlirStringRef stringRef = mlirStringAttrGetValue(self);
-          return py::str(stringRef.data, stringRef.length);
-        },
-        "Returns the value of the string attribute");
+    c.def_property_readonly("value", toPyStr,
+                            "Returns the value of the string attribute");
     c.def_property_readonly(
         "value_bytes",
         [](PyStringAttribute &self) {
@@ -609,6 +605,14 @@ class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
           return py::bytes(stringRef.data, stringRef.length);
         },
         "Returns the value of the string attribute as `bytes`");
+    c.def("__str__", toPyStr,
+          "Converts the value of the string attribute to a Python str");
+  }
+
+private:
+  static py::str toPyStr(PyStringAttribute &self) {
+    MlirStringRef stringRef = mlirStringAttrGetValue(self);
+    return py::str(stringRef.data, stringRef.length);
   }
 };
 

diff  --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 1a2ed7d6642b88..8c0c01d253baa7 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -21,7 +21,7 @@ def testParsePrint():
     assert t.context is ctx
     ctx = None
     gc.collect()
-    # CHECK: "hello"
+    # CHECK: hello
     print(str(t))
     # CHECK: StringAttr("hello")
     print(repr(t))
@@ -169,6 +169,8 @@ def testFloatAttr():
         fattr = FloatAttr(Attribute.parse("42.0 : f32"))
         # CHECK: fattr value: 42.0
         print("fattr value:", fattr.value)
+        # CHECK: fattr float: 42.0 <class 'float'>
+        print("fattr float:", float(fattr), type(float(fattr)))
 
         # Test factory methods.
         # CHECK: default_get: 4.200000e+01 : f32
@@ -196,15 +198,23 @@ def testIntegerAttr():
         print("i_attr value:", i_attr.value)
         # CHECK: i_attr type: i64
         print("i_attr type:", i_attr.type)
+        # CHECK: i_attr int: 42 <class 'int'>
+        print("i_attr int:", int(i_attr), type(int(i_attr)))
         si_attr = IntegerAttr(Attribute.parse("-1 : si8"))
         # CHECK: si_attr value: -1
         print("si_attr value:", si_attr.value)
+        # CHECK: si_attr int: -1 <class 'int'>
+        print("si_attr int:", int(si_attr), type(int(si_attr)))
         ui_attr = IntegerAttr(Attribute.parse("255 : ui8"))
         # CHECK: ui_attr value: 255
         print("ui_attr value:", ui_attr.value)
+        # CHECK: ui_attr int: 255 <class 'int'>
+        print("ui_attr int:", int(ui_attr), type(int(ui_attr)))
         idx_attr = IntegerAttr(Attribute.parse("-1 : index"))
         # CHECK: idx_attr value: -1
         print("idx_attr value:", idx_attr.value)
+        # CHECK: idx_attr int: -1 <class 'int'>
+        print("idx_attr int:", int(idx_attr), type(int(idx_attr)))
 
         # Test factory methods.
         # CHECK: default_get: 42 : i32
@@ -218,6 +228,8 @@ def testBoolAttr():
         battr = BoolAttr(Attribute.parse("true"))
         # CHECK: iattr value: True
         print("iattr value:", battr.value)
+        # CHECK: iattr bool: True <class 'bool'>
+        print("iattr bool:", bool(battr), type(bool(battr)))
 
         # Test factory methods.
         # CHECK: default_get: true
@@ -278,14 +290,25 @@ def testStringAttr():
         sattr = StringAttr(Attribute.parse('"stringattr"'))
         # CHECK: sattr value: stringattr
         print("sattr value:", sattr.value)
-        # CHECK: sattr value: b'stringattr'
-        print("sattr value:", sattr.value_bytes)
+        # CHECK: sattr value_bytes: b'stringattr'
+        print("sattr value_bytes:", sattr.value_bytes)
+        # CHECK: sattr str: stringattr
+        print("sattr str:", str(sattr))
+
+        typed_sattr = StringAttr(Attribute.parse('"stringattr" : i32'))
+        # CHECK: typed_sattr value: stringattr
+        print("typed_sattr value:", typed_sattr.value)
+        # CHECK: typed_sattr str: stringattr
+        print("typed_sattr str:", str(typed_sattr))
 
         # Test factory methods.
-        # CHECK: default_get: "foobar"
-        print("default_get:", StringAttr.get("foobar"))
-        # CHECK: typed_get: "12345" : i32
-        print("typed_get:", StringAttr.get_typed(IntegerType.get_signless(32), "12345"))
+        # CHECK: default_get: StringAttr("foobar")
+        print("default_get:", repr(StringAttr.get("foobar")))
+        # CHECK: typed_get: StringAttr("12345" : i32)
+        print(
+            "typed_get:",
+            repr(StringAttr.get_typed(IntegerType.get_signless(32), "12345")),
+        )
 
 
 # CHECK-LABEL: TEST: testNamedAttr
@@ -294,8 +317,8 @@ def testNamedAttr():
     with Context():
         a = Attribute.parse('"stringattr"')
         named = a.get_named("foobar")  # Note: under the small object threshold
-        # CHECK: attr: "stringattr"
-        print("attr:", named.attr)
+        # CHECK: attr: StringAttr("stringattr")
+        print("attr:", repr(named.attr))
         # CHECK: name: foobar
         print("name:", named.name)
         # CHECK: named: NamedAttribute(foobar="stringattr")
@@ -367,6 +390,65 @@ def __bool__(self):
         print("myboolarray:", DenseBoolArrayAttr.get([MyBool()]))
 
 
+# CHECK-LABEL: TEST: testDenseArrayAttrConstruction
+@run
+def testDenseArrayAttrConstruction():
+    with Context(), Location.unknown():
+        # Build `cls` from `x`; print the resulting attribute or the error raised.
+        def create_and_print(cls, x):
+            try:
+                darr = cls.get(x)
+                print(f"input: {x} ({type(x)}), result: {darr}")
+            except Exception as ex:
+                print(f"input: {x} ({type(x)}), error: {ex}")
+
+        # CHECK: input: [4, 2] (<class 'list'>),
+        # CHECK-SAME: result: array<i8: 4, 2>
+        create_and_print(DenseI8ArrayAttr, [4, 2])
+
+        # CHECK: input: [4, 2.0] (<class 'list'>),
+        # CHECK-SAME: error: get(): incompatible function arguments
+        create_and_print(DenseI8ArrayAttr, [4, 2.0])
+
+        # CHECK: input: [40000, 2] (<class 'list'>),
+        # CHECK-SAME: error: get(): incompatible function arguments
+        create_and_print(DenseI8ArrayAttr, [40000, 2])
+
+        # CHECK: input: range(0, 4) (<class 'range'>),
+        # CHECK-SAME: result: array<i8: 0, 1, 2, 3>
+        create_and_print(DenseI8ArrayAttr, range(4))
+
+        # CHECK: input: [IntegerAttr(4 : i64), IntegerAttr(2 : i64)] (<class 'list'>),
+        # CHECK-SAME: result: array<i8: 4, 2>
+        create_and_print(DenseI8ArrayAttr, [Attribute.parse(f"{x}") for x in [4, 2]])
+
+        # CHECK: input: [IntegerAttr(4000 : i64), IntegerAttr(2 : i64)] (<class 'list'>),
+        # CHECK-SAME: error: get(): incompatible function arguments
+        create_and_print(DenseI8ArrayAttr, [Attribute.parse(f"{x}") for x in [4000, 2]])
+
+        # CHECK: input: [IntegerAttr(4 : i64), FloatAttr(2.000000e+00 : f64)] (<class 'list'>),
+        # CHECK-SAME: error: get(): incompatible function arguments
+        create_and_print(DenseI8ArrayAttr, [Attribute.parse(f"{x}") for x in [4, 2.0]])
+
+        # CHECK: input: [IntegerAttr(4 : i8), IntegerAttr(2 : ui16)] (<class 'list'>),
+        # CHECK-SAME: result: array<i8: 4, 2>
+        create_and_print(
+            DenseI8ArrayAttr, [Attribute.parse(s) for s in ["4 : i8", "2 : ui16"]]
+        )
+
+        # CHECK: input: [FloatAttr(4.000000e+00 : f64), FloatAttr(2.000000e+00 : f64)] (<class 'list'>),
+        # CHECK-SAME: result: array<f32: 4.000000e+00, 2.000000e+00>
+        create_and_print(
+            DenseF32ArrayAttr, [Attribute.parse(f"{x}") for x in [4.0, 2.0]]
+        )
+
+        # CHECK: input: [BoolAttr(true), BoolAttr(false)] (<class 'list'>),
+        # CHECK-SAME: result: array<i1: true, false>
+        create_and_print(
+            DenseBoolArrayAttr, [Attribute.parse(f"{x}") for x in ["true", "false"]]
+        )
+
+
 # CHECK-LABEL: TEST: testDenseIntAttrGetItem
 @run
 def testDenseIntAttrGetItem():
@@ -620,7 +702,6 @@ def print_container_item(attr_asm):
 @run
 def testConcreteAttributesRoundTrip():
     with Context(), Location.unknown():
-
         # CHECK: FloatAttr(4.200000e+01 : f32)
         print(repr(Attribute.parse("42.0 : f32")))
 


        


More information about the Mlir-commits mailing list