[Mlir-commits] [mlir] e9db306 - [mlir][python] Support more types in IntegerAttr.value

Alex Zinenko llvmlistbot at llvm.org
Thu Feb 24 01:26:37 PST 2022


Author: rkayaith
Date: 2022-02-24T10:26:31+01:00
New Revision: e9db306dcd53f33b982d772793ffe7326d40c018

URL: https://github.com/llvm/llvm-project/commit/e9db306dcd53f33b982d772793ffe7326d40c018
DIFF: https://github.com/llvm/llvm-project/commit/e9db306dcd53f33b982d772793ffe7326d40c018.diff

LOG: [mlir][python] Support more types in IntegerAttr.value

Previously, reading `IntegerAttr.value` worked only for `index` and signless
integer types; signed and unsigned integers hit an assertion in
`IntegerAttr::getInt`. This change exposes `IntegerAttr::getSInt` and
`IntegerAttr::getUInt` through the C API and calls the appropriate one from the
Python bindings.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D120194

Added: 
    

Modified: 
    mlir/include/mlir-c/BuiltinAttributes.h
    mlir/lib/Bindings/Python/IRAttributes.cpp
    mlir/lib/CAPI/IR/BuiltinAttributes.cpp
    mlir/test/CAPI/ir.c
    mlir/test/python/ir/attributes.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index 973b7e99469c0..bb4431f7b3ef7 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -125,9 +125,17 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirIntegerAttrGet(MlirType type,
                                                     int64_t value);
 
 /// Returns the value stored in the given integer attribute, assuming the value
-/// fits into a 64-bit integer.
+/// is of signless type and fits into a signed 64-bit integer.
 MLIR_CAPI_EXPORTED int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr);
 
+/// Returns the value stored in the given integer attribute, assuming the value
+/// is of signed type and fits into a signed 64-bit integer.
+MLIR_CAPI_EXPORTED int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr);
+
+/// Returns the value stored in the given integer attribute, assuming the value
+/// is of unsigned type and fits into an unsigned 64-bit integer.
+MLIR_CAPI_EXPORTED uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr);
+
 //===----------------------------------------------------------------------===//
 // Bool attribute.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 5d87641c379d8..bef3b95a2487a 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -258,8 +258,13 @@ class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
         "Gets an uniqued integer attribute associated to a type");
     c.def_property_readonly(
         "value",
-        [](PyIntegerAttribute &self) {
-          return mlirIntegerAttrGetValueInt(self);
+        [](PyIntegerAttribute &self) -> py::int_ {
+          MlirType type = mlirAttributeGetType(self);
+          if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
+            return mlirIntegerAttrGetValueInt(self);
+          if (mlirIntegerTypeIsSigned(type))
+            return mlirIntegerAttrGetValueSInt(self);
+          return mlirIntegerAttrGetValueUInt(self);
         },
         "Returns the value of the integer attribute");
   }

diff  --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 7b718da88ceef..9ea277b746d61 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -129,6 +129,14 @@ int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) {
   return unwrap(attr).cast<IntegerAttr>().getInt();
 }
 
+int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr) {
+  return unwrap(attr).cast<IntegerAttr>().getSInt();
+}
+
+uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) {
+  return unwrap(attr).cast<IntegerAttr>().getUInt();
+}
+
 //===----------------------------------------------------------------------===//
 // Bool attribute.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 7ac7a19a579d3..c8d27398d6fc5 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -813,11 +813,21 @@ int printBuiltinAttributes(MlirContext ctx) {
   // CHECK: f64
 
   MlirAttribute integer = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 32), 42);
+  MlirAttribute signedInteger =
+      mlirIntegerAttrGet(mlirIntegerTypeSignedGet(ctx, 8), -1);
+  MlirAttribute unsignedInteger =
+      mlirIntegerAttrGet(mlirIntegerTypeUnsignedGet(ctx, 8), 255);
   if (!mlirAttributeIsAInteger(integer) ||
-      mlirIntegerAttrGetValueInt(integer) != 42)
+      mlirIntegerAttrGetValueInt(integer) != 42 ||
+      mlirIntegerAttrGetValueSInt(signedInteger) != -1 ||
+      mlirIntegerAttrGetValueUInt(unsignedInteger) != 255)
     return 2;
   mlirAttributeDump(integer);
+  mlirAttributeDump(signedInteger);
+  mlirAttributeDump(unsignedInteger);
   // CHECK: 42 : i32
+  // CHECK: -1 : si8
+  // CHECK: 255 : ui8
 
   MlirAttribute boolean = mlirBoolAttrGet(ctx, 1);
   if (!mlirAttributeIsABool(boolean) || !mlirBoolAttrGetValue(boolean))

diff  --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 48f2d4b3df067..53d246b397528 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -189,11 +189,20 @@ def testFloatAttr():
 @run
 def testIntegerAttr():
   with Context() as ctx:
-    iattr = IntegerAttr(Attribute.parse("42"))
-    # CHECK: iattr value: 42
-    print("iattr value:", iattr.value)
-    # CHECK: iattr type: i64
-    print("iattr type:", iattr.type)
+    i_attr = IntegerAttr(Attribute.parse("42"))
+    # CHECK: i_attr value: 42
+    print("i_attr value:", i_attr.value)
+    # CHECK: i_attr type: i64
+    print("i_attr type:", i_attr.type)
+    si_attr = IntegerAttr(Attribute.parse("-1 : si8"))
+    # CHECK: si_attr value: -1
+    print("si_attr value:", si_attr.value)
+    ui_attr = IntegerAttr(Attribute.parse("255 : ui8"))
+    # CHECK: ui_attr value: 255
+    print("ui_attr value:", ui_attr.value)
+    idx_attr = IntegerAttr(Attribute.parse("-1 : index"))
+    # CHECK: idx_attr value: -1
+    print("idx_attr value:", idx_attr.value)
 
     # Test factory methods.
     # CHECK: default_get: 42 : i32


        


More information about the Mlir-commits mailing list