[Mlir-commits] [mlir] Handle large integers (> 64 bits) for the IntegerAttr C-API (PR #130539)
Ryan Buchner
llvmlistbot at llvm.org
Sun Mar 9 19:50:53 PDT 2025
https://github.com/bababuck created https://github.com/llvm/llvm-project/pull/130539
Address issue #128072.
@stellaraccident @teqdruid I created/used the `apint_interop_t` struct as suggested. Similar to APInt, I put a union member within the struct so that small values will be stored locally inside the struct.
For testing, I expanded `mlir/test/python/ir/attributes.py` to include large-integer cases (including the one mentioned in the issue). All tests under `mlir/test` pass when running `llvm-lit` except one; see below.
One issue that still needs to be resolved: some dialect tests create a signless IntegerAttr with a negative value. Should this be allowed, or should negative values be permitted only when creating signed IntegerAttrs?
>From 9a64fda4ee6ea3f9a203e89e955bea5ac002a2e5 Mon Sep 17 00:00:00 2001
From: bababuck <92571492+bababuck at users.noreply.github.com>
Date: Sun, 9 Mar 2025 12:57:29 -0700
Subject: [PATCH] Handle large integers (> 64 bits) for the IntegerAttr C-API
Fixes issue #128072.
Allows for arbitrarily sized integers to be requested via Python.
---
mlir/include/mlir-c/BuiltinAttributes.h | 16 +++++
mlir/lib/Bindings/Python/IRAttributes.cpp | 72 ++++++++++++++++++++---
mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 34 +++++++++++
mlir/test/python/ir/attributes.py | 45 ++++++++++++++
4 files changed, 160 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index 231eb83b5e269..b1348255bd6eb 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -157,6 +157,22 @@ MLIR_CAPI_EXPORTED uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr);
/// Returns the typeID of an Integer attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerAttrGetTypeID(void);
+/// Plain-C carrier for an integer value of arbitrary bit width, used to move
+/// such values across the C API boundary.
+///
+/// The functions that consume this struct select the active union member from
+/// `numbits`: `data.VAL` when `numbits` <= 64, `data.pVAL` otherwise.
+typedef struct {
+ /// Width of the value in bits.
+ size_t numbits;
+ union {
+ /// Used when `numbits` > 64: array of (numbits + 63) / 64 64-bit words.
+ /// The C API functions taking this struct neither allocate nor free it,
+ /// so the buffer is supplied and released by the caller.
+ uint64_t *pVAL;
+ /// Used when `numbits` <= 64: the value stored inline.
+ uint64_t VAL;
+ } data;
+} apint_interop_t;
+
+/// Copies the value of the integer attribute `attr` into `interop`.
+///
+/// On entry `interop->numbits` states how many bits the caller has room for.
+/// If that is smaller than the attribute's bit width, `interop->numbits` is
+/// overwritten with the required width, no value is written and 1 is
+/// returned. Otherwise the value is written to `data.VAL` (when `numbits` is
+/// at most 64) or into the caller-provided `data.pVAL` buffer (when `numbits`
+/// exceeds 64) and 0 is returned.
+MLIR_CAPI_EXPORTED int mlirIntegerAttrGetValueInterop(MlirAttribute attr,
+ apint_interop_t *interop);
+
+/// Creates an integer attribute of the given type from an APInt interop.
+///
+/// When `interop->numbits` is at most 64 the value is taken from `data.VAL`;
+/// otherwise (numbits + 63) / 64 words are read from `data.pVAL`.
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirIntegerAttrFromInterop(MlirType type, apint_interop_t *interop);
+
//===----------------------------------------------------------------------===//
// Bool attribute.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index b4049bd7972d4..0892352299a8b 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -456,8 +456,30 @@ class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](PyType &type, int64_t value) {
- MlirAttribute attr = mlirIntegerAttrGet(type, value);
+ // Builds an IntegerAttr of `type` from an arbitrary-precision Python int.
+ // The value is serialized with int.to_bytes() using as many little-endian
+ // bytes as the type's width requires, so Python itself raises
+ // OverflowError when the value does not fit in that many bytes.
+ // NOTE(review): signless and unsigned types are serialized with
+ // signed=False, so negative values are rejected for them -- confirm that
+ // this is the intended policy for signless types.
+ // NOTE(review): the byte-wise copies below assume a little-endian host.
+ [](PyType &type, py::int_ value) {
+   apint_interop_t interop;
+   if (mlirTypeIsAIndex(type))
+     interop.numbits = 64;
+   else
+     interop.numbits = mlirIntegerTypeGetWidth((MlirType)type);
+   const bool isWide = interop.numbits > 64;
+   const bool isSigned =
+       mlirTypeIsAIndex(type) || mlirIntegerTypeIsSigned(type);
+
+   // May throw (OverflowError); nothing has been allocated yet.
+   const size_t numBytes = (interop.numbits + 7) / 8;
+   py::object toBytes = value.attr("to_bytes");
+   py::bytes bytesObj =
+       toBytes(numBytes, "little", py::arg("signed") = isSigned);
+   const char *data = ((std::string_view)bytesObj).data();
+
+   if (!isWide) {
+     // Pre-fill the word with the sign so that the bytes beyond
+     // `numBytes` hold a proper 64-bit sign extension instead of
+     // indeterminate stack contents.
+     const bool isNegative =
+         isSigned && numBytes > 0 &&
+         (static_cast<unsigned char>(data[numBytes - 1]) & 0x80) != 0;
+     interop.data.VAL = isNegative ? ~static_cast<uint64_t>(0) : 0;
+     memcpy(&interop.data.VAL, data, numBytes);
+   } else {
+     // Zero-initialised so the unused tail of the last word is 0.
+     const size_t numWords = (interop.numbits + 63) / 64;
+     interop.data.pVAL =
+         static_cast<uint64_t *>(calloc(numWords, sizeof(uint64_t)));
+     memcpy(interop.data.pVAL, data, numBytes);
+   }
+
+   MlirAttribute attr = mlirIntegerAttrFromInterop(type, &interop);
+   // The attribute is built from an APInt constructed over the words,
+   // which keeps its own copy, so the temporary buffer can be released.
+   if (isWide)
+     free(interop.data.pVAL);
+   return PyIntegerAttribute(type.getContext(), attr);
+ },
py::arg("type"), py::arg("value"),
@@ -475,11 +497,47 @@ class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
private:
static py::int_ toPyInt(PyIntegerAttribute &self) {
MlirType type = mlirAttributeGetType(self);
- if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
- return mlirIntegerAttrGetValueInt(self);
- if (mlirIntegerTypeIsSigned(type))
- return mlirIntegerAttrGetValueSInt(self);
- return mlirIntegerAttrGetValueUInt(self);
+ apint_interop_t interop;
+ if (mlirTypeIsAIndex(type))
+ interop.numbits = 64;
+ else
+ interop.numbits = mlirIntegerTypeGetWidth((MlirType)type);
+ if (interop.numbits > 64) {
+ size_t required_doublewords = (interop.numbits + 63) / 64;
+ interop.data.pVAL =
+ (uint64_t *)calloc(required_doublewords, sizeof(uint64_t));
+ }
+ mlirIntegerAttrGetValueInterop(self, &interop);
+
+ // Need to sign extend the last byte for conversion to py::bytes
+ bool Signed = mlirTypeIsAIndex(type) || mlirIntegerTypeIsSigned(type);
+ if (Signed) {
+ size_t last_doubleword = (interop.numbits - 1) / 64;
+ size_t last_bit = interop.numbits - 1 - (64 * last_doubleword);
+ uint64_t sext_mask = -1 << last_bit;
+
+ if (interop.numbits > 64) {
+ if ((interop.data.pVAL[last_doubleword] >> last_bit) & 1) {
+ interop.data.pVAL[last_doubleword] |= sext_mask;
+ }
+ } else {
+ if ((interop.data.VAL >> last_bit) & 1) {
+ interop.data.VAL |= sext_mask;
+ }
+ }
+ }
+
+ py::int_ int_obj;
+ py::object from_bytes = int_obj.attr("from_bytes");
+ size_t numbytes = (interop.numbits + 7) / 8;
+ py::bytes bytes_obj;
+ if (interop.numbits > 64) {
+ bytes_obj = py::bytes((const char *)interop.data.pVAL, numbytes);
+ } else {
+ bytes_obj = py::bytes((const char *)&interop.data.VAL, numbytes);
+ }
+ int_obj = from_bytes(bytes_obj, "little", py::arg("signed") = Signed);
+ return int_obj;
}
};
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 726af884668b2..a86f0eacf304f 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -160,6 +160,40 @@ uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) {
return llvm::cast<IntegerAttr>(unwrap(attr)).getUInt();
}
+/// Copies the value of the integer attribute `attr` into `interop`.
+///
+/// `interop->numbits` is the capacity, in bits, offered by the caller. If it
+/// is smaller than the attribute's width, it is overwritten with the required
+/// width, nothing else is written, and 1 is returned. Otherwise 0 is returned
+/// and the value is stored in `data.VAL` (capacity of at most 64 bits) or in
+/// the first (numbits + 7) / 8 bytes of the caller-provided `data.pVAL`
+/// buffer, with any bytes beyond the attribute's own storage set to zero.
+int mlirIntegerAttrGetValueInterop(MlirAttribute attr,
+                                   apint_interop_t *interop) {
+  const llvm::APInt value = llvm::cast<IntegerAttr>(unwrap(attr)).getValue();
+  const size_t neededBits = value.getBitWidth();
+  if (interop->numbits < neededBits) {
+    interop->numbits = neededBits;
+    return 1;
+  }
+  if (interop->numbits <= 64) {
+    // The value fits in a single word here because neededBits <= numbits.
+    interop->data.VAL = value.getRawData()[0];
+    return 0;
+  }
+  // Never read past the APInt's own words: the caller's capacity may be
+  // larger than the attribute's width.
+  const size_t capacityBytes = (interop->numbits + 7) / 8;
+  const size_t storedBytes = value.getNumWords() * sizeof(uint64_t);
+  const size_t copyBytes =
+      storedBytes < capacityBytes ? storedBytes : capacityBytes;
+  char *dest = reinterpret_cast<char *>(interop->data.pVAL);
+  memcpy(dest, value.getRawData(), copyBytes);
+  memset(dest + copyBytes, 0, capacityBytes - copyBytes);
+  return 0;
+}
+
+/// Builds an integer attribute of type `type` from `interop`.
+///
+/// Widths above 64 bits are read as (numbits + 63) / 64 words from
+/// `data.pVAL`; narrower values are taken from the inline `data.VAL`.
+MlirAttribute mlirIntegerAttrFromInterop(MlirType type,
+                                         apint_interop_t *interop) {
+  const size_t bitWidth = interop->numbits;
+  if (bitWidth > 64) {
+    const size_t wordCount = (bitWidth + 63) / 64;
+    APInt wideValue(bitWidth,
+                    llvm::ArrayRef<uint64_t>(interop->data.pVAL, wordCount));
+    IntegerAttr wideAttr = IntegerAttr::get(unwrap(type), wideValue);
+    return wrap(wideAttr);
+  }
+  return wrap(IntegerAttr::get(unwrap(type), interop->data.VAL));
+}
+
MlirTypeID mlirIntegerAttrGetTypeID(void) {
return wrap(IntegerAttr::getTypeID());
}
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 4b475db634645..0369c00839589 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -221,6 +221,51 @@ def testIntegerAttr():
print("default_get:", IntegerAttr.get(IntegerType.get_signless(32), 42))
+ at run
+def testLargeIntegerAttr():
+ with Context() as ctx:
+ max_positive_64_val = 0x7fffffffffffffff
+ max_positive_64 = IntegerAttr.get(IntegerType.get_signed(64), max_positive_64_val)
+ # CHECK: max_positive_64: 9223372036854775807 : si64
+ print("max_positive_64:", max_positive_64)
+ assert(int(max_positive_64) == max_positive_64_val)
+
+ neg_one_64_val = -1
+ neg_one_64 = IntegerAttr.get(IntegerType.get_signed(64), neg_one_64_val)
+ # CHECK: neg_one_64: -1 : si64
+ print("neg_one_64:", neg_one_64)
+ assert(int(neg_one_64) == neg_one_64_val)
+
+ max_unsigned_64_val = 0xffffffffffffffff
+ max_unsigned_64 = IntegerAttr.get(IntegerType.get_signless(64), max_unsigned_64_val)
+ # CHECK: max_unsigned_64: -1 : i64
+ print("max_unsigned_64:", max_unsigned_64)
+ assert(int(max_unsigned_64) == max_unsigned_64_val)
+
+ random_64_val = 0x0123456789ABCDEF
+ random_64 = IntegerAttr.get(IntegerType.get_signless(64), random_64_val)
+ # CHECK: random_64: 81985529216486895 : i64
+ print("random_64:", random_64)
+ assert(int(random_64) == random_64_val)
+
+ max_unsigned_65_val = 0x1FFFFFFFFFFFFFFFF
+ max_unsigned_65 = IntegerAttr.get(IntegerType.get_unsigned(65), max_unsigned_65_val)
+ # CHECK: max_unsigned_65: 36893488147419103231 : ui65
+ print("max_unsigned_65:", max_unsigned_65)
+ assert(int(max_unsigned_65) == max_unsigned_65_val)
+
+ random_128_val = 0x0123456789ABCDEF0123456789ABCDEF
+ random_128 = IntegerAttr.get(IntegerType.get_signless(128), random_128_val)
+ # CHECK: random_128: 1512366075204170929049582354406559215 : i128
+ print("random_128:", random_128)
+ assert(int(random_128) == random_128_val)
+
+ random_92_val = 0x9ABCDEF0123456789ABCDEF
+ random_92 = IntegerAttr.get(IntegerType.get_signless(92), random_92_val)
+ # CHECK: random_92: -1958696259612506469130580497 : i92
+ print("random_92:", random_92)
+ assert(int(random_92) == random_92_val)
+
# CHECK-LABEL: TEST: testBoolAttr
@run
def testBoolAttr():
More information about the Mlir-commits
mailing list