[Mlir-commits] [mlir] ac88f7b - [mlir][python] Support Arbitrary Precision Integers in MLIR C API and Python Bindings (#177733)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Jan 24 23:05:08 PST 2026


Author: Ryan Kim
Date: 2026-01-24T23:05:03-08:00
New Revision: ac88f7bcd41a358a7d47946e2896b03c1aade44f

URL: https://github.com/llvm/llvm-project/commit/ac88f7bcd41a358a7d47946e2896b03c1aade44f
DIFF: https://github.com/llvm/llvm-project/commit/ac88f7bcd41a358a7d47946e2896b03c1aade44f.diff

LOG: [mlir][python] Support Arbitrary Precision Integers in MLIR C API and Python Bindings (#177733)

This PR extends the MLIR C API and Python bindings to support
**arbitrary-precision integers (`APInt`)**, overcoming the previous
limitation where `IntegerAttr` values were restricted to 64 bits.

Cryptographic applications often require integer types much larger than
standard machine words (e.g., the 254-bit BN254 scalar field modulus, which is
stored in a 256-bit integer).
Previously, attempting to bind these values resulted in truncation or
errors. This PR exposes the underlying word-based `APInt` structure via
the C API and updates the Python bindings to seamlessly handle Python's
arbitrary-precision integers.

Added: 
    

Modified: 
    mlir/include/mlir-c/BuiltinAttributes.h
    mlir/include/mlir/Bindings/Python/IRAttributes.h
    mlir/lib/Bindings/Python/IRAttributes.cpp
    mlir/lib/CAPI/IR/BuiltinAttributes.cpp
    mlir/test/python/ir/attributes.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index eab732365f6b8..69a50942e8ee6 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -165,6 +165,32 @@ MLIR_CAPI_EXPORTED int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr);
 /// is of unsigned type and fits into an unsigned 64-bit integer.
 MLIR_CAPI_EXPORTED uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr);
 
+/// Returns the bit width of the integer attribute's underlying APInt value.
+/// This is useful for determining the size of the integer, especially for
+/// values larger than 64 bits.
+MLIR_CAPI_EXPORTED unsigned mlirIntegerAttrGetValueBitWidth(MlirAttribute attr);
+
+/// Returns the number of 64-bit words that make up the integer attribute's
+/// underlying APInt value, i.e. ceil(bitWidth / 64) (1 for widths 1 to 64).
+MLIR_CAPI_EXPORTED unsigned mlirIntegerAttrGetValueNumWords(MlirAttribute attr);
+
+/// Copies the 64-bit words making up the integer attribute's APInt value into
+/// the provided buffer. The buffer must have space for at least
+/// mlirIntegerAttrGetValueNumWords(attr) elements. Words are stored in
+/// little-endian order (least significant word first). The words hold the
+/// raw two's-complement bit pattern; use the type's signedness to decide
+/// whether the most significant bit of the value is a sign bit.
+MLIR_CAPI_EXPORTED void mlirIntegerAttrGetValueWords(MlirAttribute attr,
+                                                     uint64_t *words);
+
+/// Creates an integer attribute of the given type from an array of 64-bit
+/// words, e.g. for values whose width exceeds 64 bits. Words are in
+/// little-endian order (least significant word first). The number of words
+/// should match the bit width of the type, numWords = ceil(bitWidth / 64);
+/// the implementation does not validate this.
+MLIR_CAPI_EXPORTED MlirAttribute mlirIntegerAttrGetFromWords(
+    MlirType type, unsigned numWords, const uint64_t *words);
+
 /// Returns the typeID of an Integer attribute.
 MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerAttrGetTypeID(void);
 

diff  --git a/mlir/include/mlir/Bindings/Python/IRAttributes.h b/mlir/include/mlir/Bindings/Python/IRAttributes.h
index 6175710d76dd0..5ff9afd0875f1 100644
--- a/mlir/include/mlir/Bindings/Python/IRAttributes.h
+++ b/mlir/include/mlir/Bindings/Python/IRAttributes.h
@@ -341,7 +341,7 @@ class MLIR_PYTHON_API_EXPORTED PyIntegerAttribute
   static void bindDerived(ClassTy &c);
 
 private:
-  static int64_t toPyInt(PyIntegerAttribute &self);
+  static nanobind::object toPyInt(PyIntegerAttribute &self);
 };
 
 /// Bool Attribute subclass - BoolAttr.

diff  --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index b2e9d9887e098..05c0c5e825df3 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include <cmath>
 #include <cstdint>
 #include <optional>
 #include <string>
@@ -343,8 +344,50 @@ void PyFloatAttribute::bindDerived(ClassTy &c) {
 void PyIntegerAttribute::bindDerived(ClassTy &c) {
   c.def_static(
       "get",
-      [](PyType &type, int64_t value) {
-        MlirAttribute attr = mlirIntegerAttrGet(type, value);
+      [](PyType &type, nb::object value) {
+        // Handle IndexType - it doesn't have a bit width or signedness.
+        if (mlirTypeIsAIndex(type)) {
+          int64_t intValue = nb::cast<int64_t>(value);
+          MlirAttribute attr = mlirIntegerAttrGet(type, intValue);
+          return PyIntegerAttribute(type.getContext(), attr);
+        }
+
+        unsigned bitWidth = mlirIntegerTypeGetWidth(type);
+        bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
+
+        // Fast path via int64_t; ui64 values may exceed int64_t's range.
+        if (bitWidth < 64 || (bitWidth == 64 && !isUnsigned)) {
+          int64_t intValue = nb::cast<int64_t>(value);
+          MlirAttribute attr = mlirIntegerAttrGet(type, intValue);
+          return PyIntegerAttribute(type.getContext(), attr);
+        }
+
+        // Otherwise convert the Python int to an array of 64-bit words.
+        unsigned numWords = (bitWidth + 63) / 64;
+        std::vector<uint64_t> words(numWords, 0);
+
+        // Extract words from Python integer (little-endian order).
+        nb::object mask = nb::int_(0xFFFFFFFFFFFFFFFFULL);
+        nb::object shift = nb::int_(64);
+        nb::object current = value;
+
+        // `&` and `>>` on Python ints see an infinitely sign-extended two's
+        // complement value, so negatives need no special handling. Accept
+        // [-2^(w-1), 2^(w-1)) for signed/signless and [-2^(w-1), 2^w) for
+        // unsigned types instead of silently truncating other values.
+        nb::object half = nb::int_(1) << nb::int_(bitWidth - 1);
+        nb::object lowest = nb::int_(0) - half;
+        nb::object bound = isUnsigned ? half + half : half;
+        if (current < lowest || !(current < bound))
+          throw nb::value_error("integer value does not fit in the type");
+
+        for (unsigned i = 0; i < numWords; ++i) {
+          words[i] = nb::cast<uint64_t>(current & mask);
+          current = current >> shift;
+        }
+
+        MlirAttribute attr =
+            mlirIntegerAttrGetFromWords(type, numWords, words.data());
         return PyIntegerAttribute(type.getContext(), attr);
       },
       nb::arg("type"), nb::arg("value"),
@@ -360,13 +403,44 @@ void PyIntegerAttribute::bindDerived(ClassTy &c) {
       nb::sig("def static_typeid(/) -> TypeID"));
 }
 
-int64_t PyIntegerAttribute::toPyInt(PyIntegerAttribute &self) {
+nb::object PyIntegerAttribute::toPyInt(PyIntegerAttribute &self) {
   MlirType type = mlirAttributeGetType(self);
-  if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
-    return mlirIntegerAttrGetValueInt(self);
-  if (mlirIntegerTypeIsSigned(type))
-    return mlirIntegerAttrGetValueSInt(self);
-  return mlirIntegerAttrGetValueUInt(self);
+  unsigned bitWidth = mlirIntegerAttrGetValueBitWidth(self);
+
+  // For integers that fit in 64 bits, use the fast path.
+  if (bitWidth <= 64) {
+    if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
+      return nb::int_(mlirIntegerAttrGetValueInt(self));
+    if (mlirIntegerTypeIsSigned(type))
+      return nb::int_(mlirIntegerAttrGetValueSInt(self));
+    return nb::int_(mlirIntegerAttrGetValueUInt(self));
+  }
+
+  // For larger integers, reconstruct the value from raw words.
+  unsigned numWords = mlirIntegerAttrGetValueNumWords(self);
+  std::vector<uint64_t> words(numWords);
+  mlirIntegerAttrGetValueWords(self, words.data());
+
+  // Build the Python integer by shifting and ORing the words together.
+  // Words are in little-endian order (least significant first).
+  nb::object result = nb::int_(0);
+  nb::object shift = nb::int_(64);
+  for (unsigned i = numWords; i > 0; --i) {
+    result = result << shift;
+    result = result | nb::int_(words[i - 1]);
+  }
+
+  // Signed and signless types: if the sign bit is set, subtract 2^bitWidth.
+  if (!mlirIntegerTypeIsUnsigned(type)) {
+    // Signless is read as signed to match mlirIntegerAttrGetValueInt above.
+    bool signBitSet = (words[numWords - 1] >> ((bitWidth - 1) % 64)) & 1;
+    if (signBitSet) {
+      nb::object twoToTheBitWidth = nb::int_(1) << nb::int_(bitWidth);
+      result = result - twoToTheBitWidth;
+    }
+  }
+
+  return result;
 }
 
 void PyBoolAttribute::bindDerived(ClassTy &c) {

diff  --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index f7172c21a0cb9..f1d95afd31faa 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -175,6 +175,33 @@ uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) {
   return llvm::cast<IntegerAttr>(unwrap(attr)).getUInt();
 }
 
+unsigned mlirIntegerAttrGetValueBitWidth(MlirAttribute attr) {
+  return llvm::cast<IntegerAttr>(unwrap(attr)).getValue().getBitWidth();
+}
+
+unsigned mlirIntegerAttrGetValueNumWords(MlirAttribute attr) {
+  return llvm::cast<IntegerAttr>(unwrap(attr)).getValue().getNumWords();
+}
+
+void mlirIntegerAttrGetValueWords(MlirAttribute attr, uint64_t *words) {
+  const APInt &value = llvm::cast<IntegerAttr>(unwrap(attr)).getValue();
+  unsigned numWords = value.getNumWords();
+  const uint64_t *rawData = value.getRawData();
+  std::copy(rawData, rawData + numWords, words);
+}
+
+MlirAttribute mlirIntegerAttrGetFromWords(MlirType type, unsigned numWords,
+                                          const uint64_t *words) {
+  Type mlirType = unwrap(type);
+  // getIntOrFloatBitWidth() asserts on index types; IntegerAttr stores index
+  // values with IndexType's fixed internal storage width instead.
+  unsigned bitWidth = mlirType.isIndex() ? IndexType::kInternalStorageBitWidth
+                                         : mlirType.getIntOrFloatBitWidth();
+  // numWords is not validated against bitWidth here.
+  APInt value(bitWidth, ArrayRef<uint64_t>(words, numWords));
+  return wrap(IntegerAttr::get(mlirType, value));
+}
+
 MlirTypeID mlirIntegerAttrGetTypeID(void) {
   return wrap(IntegerAttr::getTypeID());
 }

diff  --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 5590834999261..3ba3788023293 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -749,3 +749,60 @@ def testAttrNames():
         print(StringAttr.attr_name)
         # CHECK: builtin.float
         print(FloatAttr.attr_name)
+
+
+# CHECK-LABEL: TEST: testLargeIntegerAttr
+@run
+def testLargeIntegerAttr():
+    with Context():
+        # Test 128-bit unsigned integer
+        i128_type = IntegerType.get_unsigned(128)
+        large_value_128 = (1 << 127) + 12345
+        attr_128 = IntegerAttr.get(i128_type, large_value_128)
+        # CHECK: 128-bit value matches: True
+        print("128-bit value matches:", int(attr_128) == large_value_128)
+
+        # Test 256-bit unsigned integer (BN254 field modulus example)
+        i256_type = IntegerType.get_unsigned(256)
+        bn254_modulus = 21888242871839275222246405745257275088548364400416034343698204186575808495617
+        attr_256 = IntegerAttr.get(i256_type, bn254_modulus)
+        # CHECK: 256-bit value matches: True
+        print("256-bit value matches:", int(attr_256) == bn254_modulus)
+
+        # Test 256-bit signed integer (positive value)
+        si256_type = IntegerType.get_signed(256)
+        positive_signed = (1 << 200) + 999
+        attr_si256_pos = IntegerAttr.get(si256_type, positive_signed)
+        # CHECK: 256-bit signed positive matches: True
+        print(
+            "256-bit signed positive matches:", int(attr_si256_pos) == positive_signed
+        )
+
+        # Test 256-bit signed integer (negative value)
+        negative_signed = -((1 << 200) + 12345)
+        attr_si256_neg = IntegerAttr.get(si256_type, negative_signed)
+        # CHECK: 256-bit signed negative matches: True
+        print(
+            "256-bit signed negative matches:", int(attr_si256_neg) == negative_signed
+        )
+
+        # Test 64-bit boundary (should still work with fast path)
+        i64_type = IntegerType.get_signless(64)
+        value_64 = (1 << 63) - 1  # max signed 64-bit
+        attr_64 = IntegerAttr.get(i64_type, value_64)
+        # CHECK: 64-bit value matches: True
+        print("64-bit value matches:", int(attr_64) == value_64)
+
+        # Test edge case: 65-bit integer (just over 64-bit boundary)
+        i65_type = IntegerType.get_unsigned(65)
+        value_65 = (1 << 64) + 1
+        attr_65 = IntegerAttr.get(i65_type, value_65)
+        # CHECK: 65-bit value matches: True
+        print("65-bit value matches:", int(attr_65) == value_65)
+
+        # Test very large integer (512-bit)
+        i512_type = IntegerType.get_unsigned(512)
+        value_512 = (1 << 500) + (1 << 300) + (1 << 100) + 42
+        attr_512 = IntegerAttr.get(i512_type, value_512)
+        # CHECK: 512-bit value matches: True
+        print("512-bit value matches:", int(attr_512) == value_512)


        


More information about the Mlir-commits mailing list