[Mlir-commits] [mlir] Handle large integers (> 64 bits) for the IntegerAttr C-API (PR #130539)
Ryan Buchner
llvmlistbot at llvm.org
Thu Mar 20 06:29:28 PDT 2025
https://github.com/bababuck updated https://github.com/llvm/llvm-project/pull/130539
>From 4ba791e8076a6ae6a6ea9fed99c4fbdbef01fe9d Mon Sep 17 00:00:00 2001
From: bababuck <92571492+bababuck at users.noreply.github.com>
Date: Sun, 9 Mar 2025 12:57:29 -0700
Subject: [PATCH] Handle large integers (> 64 bits) for the IntegerAttr C-API
Fixes issue #128072.
Allows for arbitrarily sized integers to be requested via Python.
---
mlir/include/mlir-c/BuiltinAttributes.h | 19 ++++++
mlir/lib/Bindings/Python/IRAttributes.cpp | 74 ++++++++++++++++++++---
mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 34 +++++++++++
mlir/test/python/ir/attributes.py | 45 ++++++++++++++
4 files changed, 165 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index 1d0edf9ea809d..29053df0a55ae 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -158,6 +158,25 @@ MLIR_CAPI_EXPORTED uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr);
/// Returns the typeID of an Integer attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerAttrGetTypeID(void);
+/// Plain-C carrier for an integer value of arbitrary bit width, used to create
+/// and read IntegerAttr values (including those wider than 64 bits) through
+/// the C API. See
+/// https://github.com/llvm/llvm-project/issues/128072#issuecomment-2672767777
+typedef struct {
+ /// Width of the value in bits. Selects which member of `data` is active.
+ size_t numbits;
+ union {
+ /// Active when numbits > 64: pointer to a caller-owned array holding the
+ /// value as 64-bit words ((numbits + 63) / 64 words are read on creation).
+ /// NOTE(review): presumably least-significant word first, matching the
+ /// raw storage of llvm::APInt -- confirm.
+ uint64_t *pVAL;
+ /// Active when numbits <= 64: the value stored inline.
+ uint64_t VAL;
+ } data;
+} apint_interop_t;
+
+/// Reads the value of an IntegerAttr into `interop`.
+///
+/// On entry `interop->numbits` is the width the caller has provided room for.
+/// If it is smaller than the bit width of the attribute's value, the required
+/// width is stored in `interop->numbits` and 1 is returned without writing a
+/// value. Otherwise 0 is returned and the value is written to
+/// `interop->data.VAL` (numbits <= 64) or copied into the caller-provided
+/// buffer `interop->data.pVAL` (numbits > 64, (numbits + 7) / 8 bytes).
+MLIR_CAPI_EXPORTED int mlirIntegerAttrGetValueInterop(MlirAttribute attr,
+ apint_interop_t *interop);
+
+/// Creates an integer attribute of the given type from an APInt interop.
+///
+/// The value is taken from `interop->data.VAL` when `interop->numbits` <= 64,
+/// and from the (numbits + 63) / 64 words at `interop->data.pVAL` otherwise.
+/// The buffer is only read; it remains owned by the caller.
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirIntegerAttrFromInterop(MlirType type, apint_interop_t *interop);
+
//===----------------------------------------------------------------------===//
// Bool attribute.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 12725a0ed0939..d16380727d251 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -601,8 +601,31 @@ class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](PyType &type, int64_t value) {
- MlirAttribute attr = mlirIntegerAttrGet(type, value);
+          // Builds an IntegerAttr of `type` from an arbitrary-size Python int.
+          // The int is serialized with int.to_bytes (little-endian, signed for
+          // index and signed integer types) and handed to the C API through
+          // an apint_interop_t. to_bytes raises OverflowError when the value
+          // does not fit in the type's width.
+          [](PyType &type, py::int_ value) {
+            const bool isIndex = mlirTypeIsAIndex(type);
+            // Zero-initialize: memcpy below only writes `numBytes` bytes, so
+            // the remaining bytes of VAL would otherwise be indeterminate.
+            apint_interop_t interop = {};
+            interop.numbits = isIndex ? 64 : mlirIntegerTypeGetWidth(type);
+
+            const bool isSigned = isIndex || mlirIntegerTypeIsSigned(type);
+            const size_t numBytes = (interop.numbits + 7) / 8;
+            py::object toBytes = value.attr("to_bytes");
+            py::bytes bytesObj =
+                toBytes(numBytes, "little", py::arg("signed") = isSigned);
+            const char *data = bytesObj.data();
+
+            if (interop.numbits <= 64) {
+              memcpy(&interop.data.VAL, data, numBytes);
+            } else {
+              // calloc (not malloc) takes a count and an element size, and
+              // zeroes the tail of the last word that memcpy does not reach.
+              const size_t numWords = (interop.numbits + 63) / 64;
+              interop.data.pVAL = static_cast<uint64_t *>(
+                  calloc(numWords, sizeof(uint64_t)));
+              memcpy(interop.data.pVAL, data, numBytes);
+            }
+            MlirAttribute attr = mlirIntegerAttrFromInterop(type, &interop);
+            // Only the wide representation owns heap storage.
+            if (interop.numbits > 64)
+              free(interop.data.pVAL);
return PyIntegerAttribute(type.getContext(), attr);
},
nb::arg("type"), nb::arg("value"),
@@ -620,11 +643,48 @@ class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
private:
static int64_t toPyInt(PyIntegerAttribute &self) {
MlirType type = mlirAttributeGetType(self);
- if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
- return mlirIntegerAttrGetValueInt(self);
- if (mlirIntegerTypeIsSigned(type))
- return mlirIntegerAttrGetValueSInt(self);
- return mlirIntegerAttrGetValueUInt(self);
+ apint_interop_t interop;
+ if (mlirTypeIsAIndex(type))
+ interop.numbits = 64;
+ else
+ interop.numbits = mlirIntegerTypeGetWidth((MlirType)type);
+ if (interop.numbits > 64) {
+ size_t required_doublewords = (interop.numbits + 63) / 64;
+ interop.data.pVAL =
+ (uint64_t *)malloc(required_doublewords, sizeof(uint64_t));
+ }
+ mlirIntegerAttrGetValueInterop(self, &interop);
+
+ // Need to sign extend the last byte for conversion to py::bytes
+ bool Signed = mlirTypeIsAIndex(type) || mlirIntegerTypeIsSigned(type);
+ if (Signed) {
+ size_t last_doubleword = (interop.numbits - 1) / 64;
+ size_t last_bit = interop.numbits - 1 - (64 * last_doubleword);
+ uint64_t sext_mask = -1 << last_bit;
+
+ if (interop.numbits > 64) {
+ if ((interop.data.pVAL[last_doubleword] >> last_bit) & 1) {
+ interop.data.pVAL[last_doubleword] |= sext_mask;
+ }
+ } else {
+ if ((interop.data.VAL >> last_bit) & 1) {
+ interop.data.VAL |= sext_mask;
+ }
+ }
+ }
+
+ py::int_ int_obj;
+ py::object from_bytes = int_obj.attr("from_bytes");
+ size_t numbytes = (interop.numbits + 7) / 8;
+ py::bytes bytes_obj;
+ if (interop.numbits > 64) {
+ bytes_obj = py::bytes((const char *)interop.data.pVAL, numbytes);
+ free(interop.data.pVAL);
+ } else {
+ bytes_obj = py::bytes((const char *)&interop.data.VAL, numbytes);
+ }
+ int_obj = from_bytes(bytes_obj, "little", py::arg("signed") = Signed);
+ return int_obj;
}
};
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 8d57ab6b59e79..0adb9fc277daa 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -161,6 +161,40 @@ uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) {
return llvm::cast<IntegerAttr>(unwrap(attr)).getUInt();
}
+/// Copies the value of the IntegerAttr `attr` into `interop`.
+///
+/// Returns 1, after storing the required width in `interop->numbits`, when
+/// the caller-provided width is smaller than the attribute's bit width.
+/// Otherwise returns 0 and writes the value to `interop->data.VAL`
+/// (numbits <= 64) or to the first (numbits + 7) / 8 bytes of the
+/// caller-provided buffer `interop->data.pVAL` (numbits > 64).
+int mlirIntegerAttrGetValueInterop(MlirAttribute attr,
+                                   apint_interop_t *interop) {
+  APInt value = llvm::cast<IntegerAttr>(unwrap(attr)).getValue();
+  size_t neededBitWidth = value.getBitWidth();
+  if (interop->numbits < neededBitWidth) {
+    interop->numbits = neededBitWidth;
+    return 1;
+  }
+  // The byte count below is derived from the caller's width. Zero-extend the
+  // value to that width first so the copy never reads past the APInt's own
+  // storage when the caller provided more room than the attribute needs.
+  if (interop->numbits > neededBitWidth)
+    value = value.zext(static_cast<unsigned>(interop->numbits));
+  if (interop->numbits <= 64) {
+    interop->data.VAL = value.getRawData()[0];
+    return 0;
+  }
+  size_t numBytes = (interop->numbits + 7) / 8;
+  memcpy(interop->data.pVAL, value.getRawData(), numBytes);
+  return 0;
+}
+
+/// Creates an IntegerAttr of `type` from the raw words held in `interop`.
+///
+/// The words are read from `interop->data.VAL` when numbits <= 64 and from
+/// `interop->data.pVAL` ((numbits + 63) / 64 words) otherwise. Both cases go
+/// through the word-array APInt constructor, which masks off bits above
+/// `numbits`; passing the raw word to IntegerAttr::get(Type, int64_t) would
+/// instead reinterpret it as a signed 64-bit number and mis-handle narrow
+/// signed values such as an 8-bit 0xFF.
+/// NOTE(review): `interop->numbits` is expected to equal the storage width of
+/// `type` (64 for index) -- confirm callers guarantee this.
+MlirAttribute mlirIntegerAttrFromInterop(MlirType type,
+                                         apint_interop_t *interop) {
+  const bool isInline = interop->numbits <= 64;
+  const uint64_t *words = isInline ? &interop->data.VAL : interop->data.pVAL;
+  size_t numWords = isInline ? 1 : (interop->numbits + 63) / 64;
+  APInt apInt(interop->numbits, llvm::ArrayRef<uint64_t>(words, numWords));
+  return wrap(IntegerAttr::get(unwrap(type), apInt));
+}
+
MlirTypeID mlirIntegerAttrGetTypeID(void) {
return wrap(IntegerAttr::getTypeID());
}
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 2f3c4460d3f59..da45e7b4b58bf 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -239,6 +239,51 @@ def testIntegerAttr():
print("default_get:", IntegerAttr.get(IntegerType.get_signless(32), 42))
+ at run
+def testLargeIntegerAttr():
+ with Context() as ctx:
+ max_positive_64_val = 0x7fffffffffffffff
+ max_positive_64 = IntegerAttr.get(IntegerType.get_signed(64), max_positive_64_val)
+ # CHECK: max_positive_64: 9223372036854775807 : si64
+ print("max_positive_64:", max_positive_64)
+ assert(int(max_positive_64) == max_positive_64_val)
+
+ neg_one_64_val = -1
+ neg_one_64 = IntegerAttr.get(IntegerType.get_signed(64), neg_one_64_val)
+ # CHECK: neg_one_64: -1 : si64
+ print("neg_one_64:", neg_one_64)
+ assert(int(neg_one_64) == neg_one_64_val)
+
+ max_unsigned_64_val = 0xffffffffffffffff
+ max_unsigned_64 = IntegerAttr.get(IntegerType.get_signless(64), max_unsigned_64_val)
+ # CHECK: max_unsigned_64: -1 : i64
+ print("max_unsigned_64:", max_unsigned_64)
+ assert(int(max_unsigned_64) == max_unsigned_64_val)
+
+ random_64_val = 0x0123456789ABCDEF
+ random_64 = IntegerAttr.get(IntegerType.get_signless(64), random_64_val)
+ # CHECK: random_64: 81985529216486895 : i64
+ print("random_64:", random_64)
+ assert(int(random_64) == random_64_val)
+
+ max_unsigned_65_val = 0x1FFFFFFFFFFFFFFFF
+ max_unsigned_65 = IntegerAttr.get(IntegerType.get_unsigned(65), max_unsigned_65_val)
+ # CHECK: max_unsigned_65: 36893488147419103231 : ui65
+ print("max_unsigned_65:", max_unsigned_65)
+ assert(int(max_unsigned_65) == max_unsigned_65_val)
+
+ random_128_val = 0x0123456789ABCDEF0123456789ABCDEF
+ random_128 = IntegerAttr.get(IntegerType.get_signless(128), random_128_val)
+ # CHECK: random_128: 1512366075204170929049582354406559215 : i128
+ print("random_128:", random_128)
+ assert(int(random_128) == random_128_val)
+
+ random_92_val = 0x9ABCDEF0123456789ABCDEF
+ random_92 = IntegerAttr.get(IntegerType.get_signless(92), random_92_val)
+ # CHECK: random_92: -1958696259612506469130580497 : i92
+ print("random_92:", random_92)
+ assert(int(random_92) == random_92_val)
+
# CHECK-LABEL: TEST: testBoolAttr
@run
def testBoolAttr():
More information about the Mlir-commits
mailing list