[Mlir-commits] [mlir] Handle large integers (> 64 bits) for the IntegerAttr C-API (PR #130539)

Ryan Buchner llvmlistbot at llvm.org
Thu Mar 20 06:29:28 PDT 2025


https://github.com/bababuck updated https://github.com/llvm/llvm-project/pull/130539

>From 4ba791e8076a6ae6a6ea9fed99c4fbdbef01fe9d Mon Sep 17 00:00:00 2001
From: bababuck <92571492+bababuck at users.noreply.github.com>
Date: Sun, 9 Mar 2025 12:57:29 -0700
Subject: [PATCH] Handle large integers (> 64 bits) for the IntegerAttr C-API

Fixes issue #128072.

Allows IntegerAttr values of arbitrary bit width to be created and read back from Python.
---
 mlir/include/mlir-c/BuiltinAttributes.h   | 19 ++++++
 mlir/lib/Bindings/Python/IRAttributes.cpp | 74 ++++++++++++++++++++---
 mlir/lib/CAPI/IR/BuiltinAttributes.cpp    | 34 +++++++++++
 mlir/test/python/ir/attributes.py         | 45 ++++++++++++++
 4 files changed, 165 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index 1d0edf9ea809d..29053df0a55ae 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -158,6 +158,36 @@ MLIR_CAPI_EXPORTED uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr);
 /// Returns the typeID of an Integer attribute.
 MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerAttrGetTypeID(void);
 
+/// Plain-data representation of an integer of arbitrary bit width, used to
+/// pass IntegerAttr values wider than 64 bits through the C API.
+/// `numbits` is the bit width. Values of at most 64 bits are stored inline in
+/// `data.VAL`; wider values live in a caller-owned array of
+/// `(numbits + 63) / 64` 64-bit words pointed to by `data.pVAL`,
+/// least-significant word first. See
+/// https://github.com/llvm/llvm-project/issues/128072#issuecomment-2672767777
+typedef struct {
+  size_t numbits;
+  union {
+    uint64_t *pVAL;
+    uint64_t VAL;
+  } data;
+} apint_interop_t;
+
+/// Copies the value of the integer attribute `attr` into `interop`.
+/// On input, `interop->numbits` is the capacity of the provided storage; when
+/// it exceeds 64, `interop->data.pVAL` must point to enough 64-bit words to
+/// hold that many bits. Returns 0 on success. If the attribute is wider than
+/// `interop->numbits`, nothing is copied: the required width is stored in
+/// `interop->numbits` and 1 is returned so the caller can retry.
+MLIR_CAPI_EXPORTED int mlirIntegerAttrGetValueInterop(MlirAttribute attr,
+                                                      apint_interop_t *interop);
+
+/// Creates an integer attribute of the given type from `interop`. Values
+/// wider than 64 bits are read from `interop->data.pVAL`, which stays owned
+/// by the caller.
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirIntegerAttrFromInterop(MlirType type, apint_interop_t *interop);
+
 //===----------------------------------------------------------------------===//
 // Bool attribute.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 12725a0ed0939..d16380727d251 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -601,8 +601,38 @@ class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get",
-        [](PyType &type, int64_t value) {
-          MlirAttribute attr = mlirIntegerAttrGet(type, value);
+        [](PyType &type, nb::int_ value) {
+          apint_interop_t interop;
+          interop.numbits =
+              mlirTypeIsAIndex(type) ? 64 : mlirIntegerTypeGetWidth(type);
+          size_t numwords =
+              interop.numbits <= 64 ? 1 : (interop.numbits + 63) / 64;
+
+          // Serialize `value` into whole 64-bit words. Negative values are
+          // encoded in two's complement and non-negative ones as unsigned,
+          // independently of the signedness of `type`, so every value the
+          // previous int64_t-based binding accepted (e.g.
+          // `IntegerAttr.get(i32, -1)`) is still accepted.
+          bool negative = nb::cast<bool>(value.attr("__lt__")(0));
+          nb::bytes bytesObj = nb::cast<nb::bytes>(
+              value.attr("to_bytes")(numwords * 8, "little",
+                                     nb::arg("signed") = negative));
+          const char *raw = bytesObj.c_str();
+
+          // Assemble the words byte by byte so the result does not depend
+          // on the endianness of the host.
+          std::vector<uint64_t> words(numwords, 0);
+          for (size_t i = 0, e = numwords * 8; i < e; ++i) {
+            uint64_t byte = static_cast<unsigned char>(raw[i]);
+            words[i / 8] |= byte << (8 * (i % 8));
+          }
+
+          // `words` owns the storage; the C API only reads from it.
+          if (interop.numbits <= 64)
+            interop.data.VAL = words[0];
+          else
+            interop.data.pVAL = words.data();
+          MlirAttribute attr = mlirIntegerAttrFromInterop(type, &interop);
           return PyIntegerAttribute(type.getContext(), attr);
         },
         nb::arg("type"), nb::arg("value"),
@@ -620,11 +643,42 @@ class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
 private:
-  static int64_t toPyInt(PyIntegerAttribute &self) {
+  /// Returns the value of `self` as a Python int of arbitrary size.
+  static nb::int_ toPyInt(PyIntegerAttribute &self) {
     MlirType type = mlirAttributeGetType(self);
-    if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
-      return mlirIntegerAttrGetValueInt(self);
-    if (mlirIntegerTypeIsSigned(type))
-      return mlirIntegerAttrGetValueSInt(self);
-    return mlirIntegerAttrGetValueUInt(self);
+    apint_interop_t interop;
+    interop.numbits =
+        mlirTypeIsAIndex(type) ? 64 : mlirIntegerTypeGetWidth(type);
+    size_t numwords =
+        interop.numbits <= 64 ? 1 : (interop.numbits + 63) / 64;
+
+    // `words` owns the storage the C API copies wide values into.
+    std::vector<uint64_t> words(numwords, 0);
+    if (interop.numbits > 64)
+      interop.data.pVAL = words.data();
+    if (mlirIntegerAttrGetValueInterop(self, &interop) != 0)
+      throw nb::value_error("integer attribute is wider than its type");
+    if (interop.numbits <= 64)
+      words[0] = interop.data.VAL;
+
+    // Index and signed types are two's complement: sign-extend the top
+    // word so the whole buffer can be decoded as a signed integer.
+    // NOTE(review): signless types are decoded as unsigned here, while
+    // the previous implementation returned them as signed values.
+    bool isSigned = mlirTypeIsAIndex(type) || mlirIntegerTypeIsSigned(type);
+    if (isSigned && interop.numbits > 0) {
+      unsigned signBit = (interop.numbits - 1) % 64;
+      if ((words.back() >> signBit) & 1)
+        words.back() |= ~uint64_t(0) << signBit;
+    }
+
+    // Serialize byte by byte (little endian) so the result does not
+    // depend on the endianness of the host.
+    std::string raw(numwords * 8, '\0');
+    for (size_t i = 0, e = raw.size(); i < e; ++i)
+      raw[i] = static_cast<char>((words[i / 8] >> (8 * (i % 8))) & 0xFF);
+    nb::object result = nb::int_(0).attr("from_bytes")(
+        nb::bytes(raw.data(), raw.size()), "little",
+        nb::arg("signed") = isSigned);
+    return nb::cast<nb::int_>(result);
   }
 };
 
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 8d57ab6b59e79..0adb9fc277daa 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -161,6 +161,44 @@ uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) {
   return llvm::cast<IntegerAttr>(unwrap(attr)).getUInt();
 }
 
+int mlirIntegerAttrGetValueInterop(MlirAttribute attr,
+                                   apint_interop_t *interop) {
+  APInt value = llvm::cast<IntegerAttr>(unwrap(attr)).getValue();
+  size_t neededBitWidth = value.getBitWidth();
+  // Report the required width instead of truncating the value.
+  if (interop->numbits < neededBitWidth) {
+    interop->numbits = neededBitWidth;
+    return 1;
+  }
+  if (interop->numbits <= 64) {
+    interop->data.VAL = value.getRawData()[0];
+    return 0;
+  }
+  // The caller is expected to provide (numbits + 63) / 64 words, the same
+  // amount mlirIntegerAttrFromInterop reads. That can exceed the APInt's
+  // own storage, so copy only the words the APInt has (never reading past
+  // them) and zero the rest.
+  size_t dstWords = (interop->numbits + 63) / 64;
+  size_t srcWords = value.getNumWords(); // <= dstWords: numbits >= width.
+  std::copy_n(value.getRawData(), srcWords, interop->data.pVAL);
+  std::fill(interop->data.pVAL + srcWords, interop->data.pVAL + dstWords,
+            uint64_t(0));
+  return 0;
+}
+
+MlirAttribute mlirIntegerAttrFromInterop(MlirType type,
+                                         apint_interop_t *interop) {
+  // Widths up to 64 bits are carried inline in `data.VAL`.
+  if (interop->numbits <= 64)
+    return wrap(IntegerAttr::get(unwrap(type), interop->data.VAL));
+  // Wider values: (numbits + 63) / 64 caller-owned words,
+  // least-significant word first, copied into the APInt.
+  APInt apInt(interop->numbits,
+              llvm::ArrayRef<uint64_t>(interop->data.pVAL,
+                                       (interop->numbits + 63) / 64));
+  return wrap(IntegerAttr::get(unwrap(type), apInt));
+}
+
 MlirTypeID mlirIntegerAttrGetTypeID(void) {
   return wrap(IntegerAttr::getTypeID());
 }
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 2f3c4460d3f59..da45e7b4b58bf 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -239,6 +239,60 @@ def testIntegerAttr():
         print("default_get:", IntegerAttr.get(IntegerType.get_signless(32), 42))


+# CHECK-LABEL: TEST: testLargeIntegerAttr
+@run
+def testLargeIntegerAttr():
+    with Context() as ctx:
+        max_positive_64_val = 0x7FFFFFFFFFFFFFFF
+        max_positive_64 = IntegerAttr.get(
+            IntegerType.get_signed(64), max_positive_64_val
+        )
+        # CHECK: max_positive_64: 9223372036854775807 : si64
+        print("max_positive_64:", max_positive_64)
+        assert int(max_positive_64) == max_positive_64_val
+
+        neg_one_64_val = -1
+        neg_one_64 = IntegerAttr.get(IntegerType.get_signed(64), neg_one_64_val)
+        # CHECK: neg_one_64: -1 : si64
+        print("neg_one_64:", neg_one_64)
+        assert int(neg_one_64) == neg_one_64_val
+
+        # Signless values print as signed but convert back to an unsigned int.
+        max_unsigned_64_val = 0xFFFFFFFFFFFFFFFF
+        max_unsigned_64 = IntegerAttr.get(
+            IntegerType.get_signless(64), max_unsigned_64_val
+        )
+        # CHECK: max_unsigned_64: -1 : i64
+        print("max_unsigned_64:", max_unsigned_64)
+        assert int(max_unsigned_64) == max_unsigned_64_val
+
+        random_64_val = 0x0123456789ABCDEF
+        random_64 = IntegerAttr.get(IntegerType.get_signless(64), random_64_val)
+        # CHECK: random_64: 81985529216486895 : i64
+        print("random_64:", random_64)
+        assert int(random_64) == random_64_val
+
+        max_unsigned_65_val = 0x1FFFFFFFFFFFFFFFF
+        max_unsigned_65 = IntegerAttr.get(
+            IntegerType.get_unsigned(65), max_unsigned_65_val
+        )
+        # CHECK: max_unsigned_65: 36893488147419103231 : ui65
+        print("max_unsigned_65:", max_unsigned_65)
+        assert int(max_unsigned_65) == max_unsigned_65_val
+
+        random_128_val = 0x0123456789ABCDEF0123456789ABCDEF
+        random_128 = IntegerAttr.get(IntegerType.get_signless(128), random_128_val)
+        # CHECK: random_128: 1512366075204170929049582354406559215 : i128
+        print("random_128:", random_128)
+        assert int(random_128) == random_128_val
+
+        random_92_val = 0x9ABCDEF0123456789ABCDEF
+        random_92 = IntegerAttr.get(IntegerType.get_signless(92), random_92_val)
+        # CHECK: random_92: -1958696259612506469130580497 : i92
+        print("random_92:", random_92)
+        assert int(random_92) == random_92_val
+
+
 # CHECK-LABEL: TEST: testBoolAttr
 @run
 def testBoolAttr():



More information about the Mlir-commits mailing list