[Mlir-commits] [mlir] [MLIR] Validate APInt bitwidth in IntegerAttr::get(Type, APInt) (PR #188725)

llvmlistbot at llvm.org
Thu Mar 26 04:26:00 PDT 2026


llvmbot wrote:


@llvm/pr-subscribers-mlir-ods

@llvm/pr-subscribers-mlir

Author: Mehdi Amini (joker-eph)

<details>
<summary>Changes</summary>

IntegerAttr::get(Type, APInt) did not validate that the APInt's bit width matched the expected bit width for the given type. For integer types, the APInt width must equal the integer type's width. For index types, the APInt width must equal IndexType::kInternalStorageBitWidth (64 bits).
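
As a rough illustration of the width rule above, here is a standalone C++ sketch that does not depend on MLIR. `TypeKind`, `requiredAPIntWidth`, and `widthMatches` are hypothetical names, and the 64-bit constant only mirrors `IndexType::kInternalStorageBitWidth`. The sketch also ignores the early `i1` to `BoolAttr` return that precedes the check in the real builder.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-ins for the type kinds relevant to IntegerAttr.
enum class TypeKind { Integer, Index, Other };

// Mirrors IndexType::kInternalStorageBitWidth (64 bits per the PR text).
constexpr unsigned kIndexStorageBitWidth = 64;

// Returns the APInt bit width the builder expects for a type, or
// std::nullopt when the new check does not constrain the width.
std::optional<unsigned> requiredAPIntWidth(TypeKind kind, unsigned intWidth) {
  switch (kind) {
  case TypeKind::Integer:
    return intWidth;              // e.g. i32 requires a 32-bit APInt
  case TypeKind::Index:
    return kIndexStorageBitWidth; // index always requires a 64-bit APInt
  case TypeKind::Other:
    return std::nullopt;          // not covered by the new assertions
  }
  return std::nullopt;
}

// True if an APInt of width `apintWidth` satisfies the rule for the type.
bool widthMatches(TypeKind kind, unsigned intWidth, unsigned apintWidth) {
  std::optional<unsigned> required = requiredAPIntWidth(kind, intWidth);
  return !required || *required == apintWidth;
}
```

For example, `widthMatches(TypeKind::Integer, 32, 8)` is false, matching the `APInt(8, 1)` versus `i32` case in the death test.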

Passing an APInt with the wrong bit width could cause a non-deterministic crash in StorageUniquer when comparing two IntegerAttr instances for the same type but with different APInt widths.
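
A minimal toy model of that failure mode, assuming a simplified uniquer key of (type, value): `llvm::APInt::operator==` asserts that both operands have the same bit width, so comparing two keys with the same type but differently sized values is itself invalid. `ToyAPInt`, `ToyIntegerAttrKey`, and `keysEqual` are hypothetical names; the toy `operator==` throws instead of asserting so the failure is observable. Hashing is not modeled, which is presumably why the real crash only appears when two such keys happen to be compared.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Hypothetical toy model of an APInt: a value plus an explicit bit width.
struct ToyAPInt {
  unsigned bitWidth;
  uint64_t value;
};

// llvm::APInt::operator== asserts equal bit widths; this toy version throws
// instead, so a mismatched comparison can be detected at runtime.
bool operator==(const ToyAPInt &lhs, const ToyAPInt &rhs) {
  if (lhs.bitWidth != rhs.bitWidth)
    throw std::logic_error("toy APInt comparison with unequal bit widths");
  return lhs.value == rhs.value;
}

// Hypothetical uniquing key for an integer attribute: (type id, value).
struct ToyIntegerAttrKey {
  int typeId;
  ToyAPInt value;
};

// Compares the type first; only keys with the same type reach the value
// comparison, which is where mismatched widths fail.
bool keysEqual(const ToyIntegerAttrKey &a, const ToyIntegerAttrKey &b) {
  return a.typeId == b.typeId && a.value == b.value;
}
```

Keys with different types never reach the width check, which matches the PR's description of the crash as specific to two attributes of the same type.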

This commit adds assertions in the get(Type, APInt) builder to catch such misuse early in builds with assertions enabled, providing a clear error message at the call site rather than a cryptic crash in the storage uniquer.

Fixes #56401

Assisted-by: Claude Code

---
Full diff: https://github.com/llvm/llvm-project/pull/188725.diff


2 Files Affected:

- (modified) mlir/include/mlir/IR/BuiltinAttributes.td (+9) 
- (modified) mlir/unittests/IR/AttributeTest.cpp (+42) 


``````````diff
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 0cc556ef5d852..a0d72d082b5fb 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -720,6 +720,15 @@ def Builtin_IntegerAttr : Builtin_Attr<"Integer", "integer",
                                         "const APInt &":$value), [{
       if (type.isSignlessInteger(1))
         return BoolAttr::get(type.getContext(), value.getBoolValue());
+      // Validate that the APInt has the correct bit width for the given type.
+      if (auto intTy = ::llvm::dyn_cast<IntegerType>(type)) {
+        assert(value.getBitWidth() == intTy.getWidth() &&
+               "IntegerAttr::get: APInt bit width must match integer type width");
+      } else if (::llvm::isa<IndexType>(type)) {
+        assert(value.getBitWidth() == IndexType::kInternalStorageBitWidth &&
+               "IntegerAttr::get: APInt bit width must match IndexType internal "
+               "storage bit width");
+      }
       return $_get(type.getContext(), type, value);
     }]>,
     AttrBuilder<(ins "const APSInt &":$value), [{
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 404aa8c0dcf3d..900cacabd592e 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -523,4 +523,46 @@ TEST(CopyCountAttr, PrintStripped) {
   EXPECT_EQ(str, "|#test.copy_count<hello>|[copy_count<hello>]");
 }
 
+//===----------------------------------------------------------------------===//
+// IntegerAttr
+//===----------------------------------------------------------------------===//
+
+TEST(IntegerAttrTest, CorrectBitWidths) {
+  MLIRContext context;
+
+  // Correct APInt width for i32.
+  IntegerType i32 = IntegerType::get(&context, 32);
+  IntegerAttr attr32 = IntegerAttr::get(i32, APInt(32, 42));
+  EXPECT_EQ(attr32.getType(), i32);
+  EXPECT_EQ(attr32.getValue().getBitWidth(), 32u);
+  EXPECT_EQ(attr32.getInt(), 42);
+
+  // Correct APInt width for index type.
+  IndexType indexTy = IndexType::get(&context);
+  IntegerAttr attrIdx =
+      IntegerAttr::get(indexTy, APInt(IndexType::kInternalStorageBitWidth, 5));
+  EXPECT_EQ(attrIdx.getType(), indexTy);
+  EXPECT_EQ(attrIdx.getValue().getBitWidth(),
+            (unsigned)IndexType::kInternalStorageBitWidth);
+}
+
+#ifndef NDEBUG
+TEST(IntegerAttrDeathTest, WrongBitWidthForIntegerType) {
+  MLIRContext context;
+  IntegerType i32 = IntegerType::get(&context, 32);
+  // APInt(8, 1) has bit width 8, but i32 requires 32.
+  EXPECT_DEATH(IntegerAttr::get(i32, APInt(8, 1)),
+               "APInt bit width must match integer type width");
+}
+
+TEST(IntegerAttrDeathTest, WrongBitWidthForIndexType) {
+  MLIRContext context;
+  IndexType indexTy = IndexType::get(&context);
+  // APInt(1, 1) has bit width 1, but index type requires 64.
+  EXPECT_DEATH(
+      IntegerAttr::get(indexTy, APInt(1, 1)),
+      "APInt bit width must match IndexType internal storage bit width");
+}
+#endif // NDEBUG
+
 } // namespace

``````````

</details>


https://github.com/llvm/llvm-project/pull/188725

