[llvm] [APInt] Assert correct values in APInt constructor (PR #80309)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 1 09:10:52 PST 2024


https://github.com/nikic created https://github.com/llvm/llvm-project/pull/80309

If the uint64_t constructor is used, assert that the value is actually a signed or unsigned N-bit integer depending on whether the isSigned flag is set.

Currently, we allow values to be silently truncated, which is a constant source of subtle bugs -- a particularly common mistake is to create -1 values without setting the isSigned flag, which works fine for all common bit widths (<= 64 bits) but miscompiles for wider integers.

(This is a draft because I have not fixed up all uses failing the new assertion yet.)

>From f56c096f01861faaec2fffb0ad2a090c5c2dfdcf Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Thu, 1 Feb 2024 17:27:46 +0100
Subject: [PATCH] [APInt] Assert correct values in APInt constructor

If the uint64_t constructor is used, assert that the value is
actually a signed or unsigned N-bit integer depending on whether
the isSigned flag is set.

Currently, we allow values to be silently truncated, which is
a constant source of subtle bugs -- a particularly common mistake
is to create -1 values without setting the isSigned flag, which
will work fine for all common bit widths (<= 64-bit) and miscompile
for larger integers.
---
 llvm/include/llvm/ADT/APInt.h                        | 12 +++++++++++-
 llvm/lib/Support/APInt.cpp                           |  6 +++---
 .../Transforms/InstCombine/InstCombineCompares.cpp   |  2 +-
 .../lib/Transforms/InstCombine/InstCombineSelect.cpp |  8 ++++----
 llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp       |  8 +++++---
 5 files changed, 24 insertions(+), 12 deletions(-)

diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index 6f2f25548cc84..3a9fb6032f4cb 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -108,9 +108,19 @@ class [[nodiscard]] APInt {
   /// \param isSigned how to treat signedness of val
   APInt(unsigned numBits, uint64_t val, bool isSigned = false)
       : BitWidth(numBits) {
+    if (BitWidth == 0) {
+      assert(val == 0 && "Value must be zero for 0-bit APInt");
+    } else if (isSigned) {
+      assert(llvm::isIntN(BitWidth, val) &&
+             "Value is not an N-bit signed value");
+    } else {
+      assert(llvm::isUIntN(BitWidth, val) &&
+             "Value is not an N-bit unsigned value");
+    }
     if (isSingleWord()) {
       U.VAL = val;
-      clearUnusedBits();
+      if (isSigned)
+        clearUnusedBits();
     } else {
       initSlowCase(val, isSigned);
     }
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index 05b1526da95ff..313c874f9caf1 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -234,7 +234,7 @@ APInt& APInt::operator-=(uint64_t RHS) {
 APInt APInt::operator*(const APInt& RHS) const {
   assert(BitWidth == RHS.BitWidth && "Bit widths must be the same");
   if (isSingleWord())
-    return APInt(BitWidth, U.VAL * RHS.U.VAL);
+    return APInt(BitWidth, (U.VAL * RHS.U.VAL) & maxUIntN(BitWidth));
 
   APInt Result(getMemory(getNumWords()), getBitWidth());
   tcMultiply(Result.U.pVal, U.pVal, RHS.U.pVal, getNumWords());
@@ -907,7 +907,7 @@ APInt APInt::trunc(unsigned width) const {
   assert(width <= BitWidth && "Invalid APInt Truncate request");
 
   if (width <= APINT_BITS_PER_WORD)
-    return APInt(width, getRawData()[0]);
+    return APInt(width, getRawData()[0] & maxUIntN(width));
 
   if (width == BitWidth)
     return *this;
@@ -955,7 +955,7 @@ APInt APInt::sext(unsigned Width) const {
   assert(Width >= BitWidth && "Invalid APInt SignExtend request");
 
   if (Width <= APINT_BITS_PER_WORD)
-    return APInt(Width, SignExtend64(U.VAL, BitWidth));
+    return APInt(Width, SignExtend64(U.VAL, BitWidth), /*isSigned*/ true);
 
   if (Width == BitWidth)
     return *this;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index d295853798b80..b3b425cc1ac42 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -308,7 +308,7 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal(
       DL.getTypeAllocSize(Init->getType()->getArrayElementType());
   auto MaskIdx = [&](Value *Idx) {
     if (!GEP->isInBounds() && llvm::countr_zero(ElementSize) != 0) {
-      Value *Mask = ConstantInt::get(Idx->getType(), -1);
+      Value *Mask = Constant::getAllOnesValue(Idx->getType());
       Mask = Builder.CreateLShr(Mask, llvm::countr_zero(ElementSize));
       Idx = Builder.CreateAnd(Idx, Mask);
     }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 453e4d788705f..efc3b99e83606 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -665,11 +665,11 @@ static Value *foldSelectICmpLshrAshr(const ICmpInst *IC, Value *TrueVal,
   Value *X, *Y;
   unsigned Bitwidth = CmpRHS->getType()->getScalarSizeInBits();
   if ((Pred != ICmpInst::ICMP_SGT ||
-       !match(CmpRHS,
-              m_SpecificInt_ICMP(ICmpInst::ICMP_SGE, APInt(Bitwidth, -1)))) &&
+       !match(CmpRHS, m_SpecificInt_ICMP(ICmpInst::ICMP_SGE,
+                                         APInt::getAllOnes(Bitwidth)))) &&
       (Pred != ICmpInst::ICMP_SLT ||
-       !match(CmpRHS,
-              m_SpecificInt_ICMP(ICmpInst::ICMP_SGE, APInt(Bitwidth, 0)))))
+       !match(CmpRHS, m_SpecificInt_ICMP(ICmpInst::ICMP_SGE,
+                                         APInt::getZero(Bitwidth)))))
     return nullptr;
 
   // Canonicalize so that ashr is in FalseVal.
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index 52eef9ab58a4d..f24d7a1295c10 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -553,7 +553,8 @@ Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilderBase &B) {
   // strcmp(x, y)  -> cnst  (if both x and y are constant strings)
   if (HasStr1 && HasStr2)
     return ConstantInt::get(CI->getType(),
-                            std::clamp(Str1.compare(Str2), -1, 1));
+                            std::clamp(Str1.compare(Str2), -1, 1),
+                            /*isSigned*/ true);
 
   if (HasStr1 && Str1.empty()) // strcmp("", x) -> -*x
     return B.CreateNeg(B.CreateZExt(
@@ -638,7 +639,8 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilderBase &B) {
     StringRef SubStr1 = substr(Str1, Length);
     StringRef SubStr2 = substr(Str2, Length);
     return ConstantInt::get(CI->getType(),
-                            std::clamp(SubStr1.compare(SubStr2), -1, 1));
+                            std::clamp(SubStr1.compare(SubStr2), -1, 1),
+                            /*isSigned*/ true);
   }
 
   if (HasStr1 && Str1.empty()) // strncmp("", x, n) -> -*x
@@ -1518,7 +1520,7 @@ static Value *optimizeMemCmpVarSize(CallInst *CI, Value *LHS, Value *RHS,
   int IRes = UChar(LStr[Pos]) < UChar(RStr[Pos]) ? -1 : 1;
   Value *MaxSize = ConstantInt::get(Size->getType(), Pos);
   Value *Cmp = B.CreateICmp(ICmpInst::ICMP_ULE, Size, MaxSize);
-  Value *Res = ConstantInt::get(CI->getType(), IRes);
+  Value *Res = ConstantInt::get(CI->getType(), IRes, /*isSigned*/ true);
   return B.CreateSelect(Cmp, Zero, Res);
 }
 



More information about the llvm-commits mailing list