[Mlir-commits] [mlir] Handle large integers (> 64 bits) for the IntegerAttr C-API (PR #130539)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Mar 9 19:51:47 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Ryan Buchner (bababuck)

<details>
<summary>Changes</summary>

Address issue #<!-- -->128072.

@<!-- -->stellaraccident @<!-- -->teqdruid I created and used the `apint_interop_t` struct as suggested. As in APInt, the struct contains a union so that small values (64 bits or fewer) are stored inline rather than in a separate buffer.

For testing, I expanded `mlir/test/python/ir/attributes.py` to include large-integer cases (including the one mentioned in the issue). All `mlir/test` tests pass when running `llvm-lit` except one; see below.

One open issue: some dialect tests create a signless IntegerAttr with a negative value. Should this be allowed, or should negative values be permitted only when creating signed IntegerAttrs?

---
Full diff: https://github.com/llvm/llvm-project/pull/130539.diff


4 Files Affected:

- (modified) mlir/include/mlir-c/BuiltinAttributes.h (+16) 
- (modified) mlir/lib/Bindings/Python/IRAttributes.cpp (+65-7) 
- (modified) mlir/lib/CAPI/IR/BuiltinAttributes.cpp (+34) 
- (modified) mlir/test/python/ir/attributes.py (+45) 


``````````diff
diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index 231eb83b5e269..b1348255bd6eb 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -157,6 +157,22 @@ MLIR_CAPI_EXPORTED uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr);
 /// Returns the typeID of an Integer attribute.
 MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerAttrGetTypeID(void);
 
+typedef struct { // An integer of `numbits` bits passed across the C API.
+  size_t numbits; // Width in bits; decides which `data` member is used.
+  union {
+    uint64_t *pVAL; // numbits > 64: caller-owned, (numbits + 63) / 64 words.
+    uint64_t VAL;   // numbits <= 64: the value itself.
+  } data;
+} apint_interop_t;
+/// Copies an IntegerAttr's value into the caller-provided `interop`; returns 0.
+/// If interop->numbits is too small, stores the needed width and returns 1.
+MLIR_CAPI_EXPORTED int mlirIntegerAttrGetValueInterop(MlirAttribute attr,
+                                                      apint_interop_t *interop);
+
+/// Creates an integer attribute of `type` from `interop`, which is only read.
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirIntegerAttrFromInterop(MlirType type, apint_interop_t *interop);
+
 //===----------------------------------------------------------------------===//
 // Bool attribute.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index b4049bd7972d4..0892352299a8b 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -456,8 +456,30 @@ class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get",
-        [](PyType &type, int64_t value) {
-          MlirAttribute attr = mlirIntegerAttrGet(type, value);
+        [](PyType &type, py::int_ value) {
+          apint_interop_t interop;
+          bool isIndex = mlirTypeIsAIndex(type);
+          interop.numbits = isIndex ? 64 : mlirIntegerTypeGetWidth(type);
+          // Negative values are encoded in two's complement for every type
+          // (signless ones included, as the old int64_t overload allowed);
+          // `to_bytes` raises OverflowError if `numbytes` bytes are too few.
+          bool negative = value < py::int_(0);
+          bool Signed = negative || isIndex || mlirIntegerTypeIsSigned(type);
+          size_t numbytes = (interop.numbits + 7) / 8;
+          py::bytes bytes_obj = value.attr("to_bytes")(
+              numbytes, "little", py::arg("signed") = Signed);
+          std::string_view bytes = static_cast<std::string_view>(bytes_obj);
+          // The vector owns the buffer (nothing leaks) and is pre-filled with
+          // the sign so bytes past `numbytes` are never left uninitialized;
+          // `numbits / 64 + 1` words always covers (numbits + 63) / 64.
+          std::vector<uint64_t> words(interop.numbits / 64 + 1,
+                                      negative ? ~uint64_t(0) : uint64_t(0));
+          memcpy(words.data(), bytes.data(), numbytes);
+          if (interop.numbits <= 64)
+            interop.data.VAL = words[0];
+          else
+            interop.data.pVAL = words.data();
+          MlirAttribute attr = mlirIntegerAttrFromInterop(type, &interop);
           return PyIntegerAttribute(type.getContext(), attr);
         },
         py::arg("type"), py::arg("value"),
@@ -475,11 +497,47 @@ class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
 private:
   static py::int_ toPyInt(PyIntegerAttribute &self) {
     MlirType type = mlirAttributeGetType(self);
-    if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
-      return mlirIntegerAttrGetValueInt(self);
-    if (mlirIntegerTypeIsSigned(type))
-      return mlirIntegerAttrGetValueSInt(self);
-    return mlirIntegerAttrGetValueUInt(self);
+    // Returns the value as an arbitrary-precision Python int. Index and
+    // signed types are read as signed; signless and unsigned types as
+    // unsigned (so an i64 holding -1 yields 2**64 - 1).
+    bool isIndex = mlirTypeIsAIndex(type);
+    apint_interop_t interop;
+    interop.numbits = isIndex ? 64 : mlirIntegerTypeGetWidth(type);
+
+    // The vector owns the wide-value buffer, so it is released on every
+    // exit path (including exceptions); `numbits / 64 + 1` words always
+    // covers the (numbits + 63) / 64 words the C API may write.
+    std::vector<uint64_t> words(interop.numbits / 64 + 1, 0);
+    if (interop.numbits > 64)
+      interop.data.pVAL = words.data();
+    // A non-zero result means the buffer is narrower than the stored value,
+    // which would leave `words` incomplete; report it instead of guessing.
+    if (mlirIntegerAttrGetValueInterop(self, &interop) != 0)
+      throw py::value_error("IntegerAttr value is wider than its type");
+    if (interop.numbits <= 64)
+      words[0] = interop.data.VAL;
+
+    // Sign-extend above the top bit so the last byte passed to `from_bytes`
+    // carries the sign. The mask is shifted as an unsigned 64-bit value:
+    // `-1 << last_bit` shifts a negative int (undefined behavior) and, with
+    // a 32-bit int, is wrong for last_bit >= 32. The numbits > 0 guard
+    // keeps i0 from underflowing the index computation.
+    bool Signed = isIndex || mlirIntegerTypeIsSigned(type);
+    if (Signed && interop.numbits > 0) {
+      size_t last_word = (interop.numbits - 1) / 64;
+      size_t last_bit = (interop.numbits - 1) % 64;
+      if ((words[last_word] >> last_bit) & 1)
+        words[last_word] |= ~uint64_t(0) << last_bit;
+    }
+
+    // `words` holds the value least-significant byte first (assumes a
+    // little-endian host, as the rest of this conversion does).
+    size_t numbytes = (interop.numbits + 7) / 8;
+    py::bytes bytes_obj(reinterpret_cast<const char *>(words.data()),
+                        numbytes);
+    py::object from_bytes = py::int_().attr("from_bytes");
+    return py::int_(
+        from_bytes(bytes_obj, "little", py::arg("signed") = Signed));
   }
 };
 
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 726af884668b2..a86f0eacf304f 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -160,6 +160,40 @@ uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) {
   return llvm::cast<IntegerAttr>(unwrap(attr)).getUInt();
 }
 
+int mlirIntegerAttrGetValueInterop(MlirAttribute attr,
+                                   apint_interop_t *interop) {
+  const APInt &value = llvm::cast<IntegerAttr>(unwrap(attr)).getValue();
+  size_t needed_bit_width = value.getBitWidth();
+  // Too small a buffer: report the required width and copy nothing.
+  if (interop->numbits < needed_bit_width) {
+    interop->numbits = needed_bit_width;
+    return 1;
+  }
+  if (interop->numbits <= 64) {
+    interop->data.VAL = value.getZExtValue();
+    return 0;
+  }
+  // Copy exactly the words the APInt owns (the caller's buffer may be wider
+  // than the value) and zero-fill the remainder of the caller's buffer.
+  size_t caller_words = (interop->numbits + 63) / 64;
+  size_t value_words = value.getNumWords();
+  memcpy(interop->data.pVAL, value.getRawData(),
+         value_words * sizeof(uint64_t));
+  memset(interop->data.pVAL + value_words, 0,
+         (caller_words - value_words) * sizeof(uint64_t));
+  return 0;
+}
+
+MlirAttribute mlirIntegerAttrFromInterop(MlirType type,
+                                         apint_interop_t *interop) {
+  if (interop->numbits <= 64)
+    return wrap(IntegerAttr::get(unwrap(type), interop->data.VAL));
+  APInt apInt(interop->numbits,
+              llvm::ArrayRef<uint64_t>(interop->data.pVAL,
+                                       (interop->numbits + 63) / 64));
+  return wrap(IntegerAttr::get(unwrap(type), apInt));
+}
+
 MlirTypeID mlirIntegerAttrGetTypeID(void) {
   return wrap(IntegerAttr::getTypeID());
 }
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 4b475db634645..0369c00839589 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -221,6 +221,51 @@ def testIntegerAttr():
         print("default_get:", IntegerAttr.get(IntegerType.get_signless(32), 42))
 
 
+# CHECK-LABEL: TEST: testLargeIntegerAttr
+@run
+def testLargeIntegerAttr():
+    with Context():
+        # Each (label, type, value) must round-trip through IntegerAttr.get and
+        # int(). Signless types print as signed but convert back unsigned.
+        cases = [
+            ("max_positive_64", IntegerType.get_signed(64), 0x7FFFFFFFFFFFFFFF),
+            ("min_signed_64", IntegerType.get_signed(64), -(1 << 63)),
+            ("neg_one_64", IntegerType.get_signed(64), -1),
+            ("neg_index", IndexType.get(), -(1 << 40)),
+            ("max_unsigned_64", IntegerType.get_signless(64), 0xFFFFFFFFFFFFFFFF),
+            ("random_64", IntegerType.get_signless(64), 0x0123456789ABCDEF),
+            ("max_unsigned_65", IntegerType.get_unsigned(65), 0x1FFFFFFFFFFFFFFFF),
+            ("min_signed_65", IntegerType.get_signed(65), -(1 << 64)),
+            (
+                "random_128",
+                IntegerType.get_signless(128),
+                0x0123456789ABCDEF0123456789ABCDEF,
+            ),
+            (
+                "neg_random_128",
+                IntegerType.get_signed(128),
+                -0x0123456789ABCDEF0123456789ABCDEF,
+            ),
+            ("random_92", IntegerType.get_signless(92), 0x9ABCDEF0123456789ABCDEF),
+        ]
+        for label, int_type, value in cases:
+            attr = IntegerAttr.get(int_type, value)
+            print(f"{label}:", attr)
+            assert int(attr) == value, label
+
+        # CHECK: max_positive_64: 9223372036854775807 : si64
+        # CHECK: min_signed_64: -9223372036854775808 : si64
+        # CHECK: neg_one_64: -1 : si64
+        # CHECK: neg_index: -1099511627776 : index
+        # CHECK: max_unsigned_64: -1 : i64
+        # CHECK: random_64: 81985529216486895 : i64
+        # CHECK: max_unsigned_65: 36893488147419103231 : ui65
+        # CHECK: min_signed_65: -18446744073709551616 : si65
+        # CHECK: random_128: 1512366075204170929049582354406559215 : i128
+        # CHECK: neg_random_128: -1512366075204170929049582354406559215 : si128
+        # CHECK: random_92: -1958696259612506469130580497 : i92
+
+
 # CHECK-LABEL: TEST: testBoolAttr
 @run
 def testBoolAttr():

``````````

</details>


https://github.com/llvm/llvm-project/pull/130539


More information about the Mlir-commits mailing list