[llvm] 0b81ff3 - [KnownBits] Handle shifts over wide types

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Tue May 16 02:26:49 PDT 2023


Author: Nikita Popov
Date: 2023-05-16T11:26:39+02:00
New Revision: 0b81ff3ac50ac4900033c0dde7518a7199d101ae

URL: https://github.com/llvm/llvm-project/commit/0b81ff3ac50ac4900033c0dde7518a7199d101ae
DIFF: https://github.com/llvm/llvm-project/commit/0b81ff3ac50ac4900033c0dde7518a7199d101ae.diff

LOG: [KnownBits] Handle shifts over wide types

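Do not assert when the bit width of the shift amount is larger than
64 bits. The shift-amount known-bits masks were previously extracted
with APInt::getZExtValue(), which asserts if the value does not fit
in 64 bits; they are now truncated to 64 bits with zextOrTrunc(64)
first. This case is currently hidden from the IR layer by other
checks, but will be exposed by future changes.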

Added: 
    

Modified: 
    llvm/lib/Support/KnownBits.cpp
    llvm/unittests/Support/KnownBitsTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 0fd3e5a5ad64a..ddeb6a4961601 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -195,8 +195,8 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) {
   // possible shifts.
   APInt MaxShiftAmount = RHS.getMaxValue();
   if (!LHS.isUnknown()) {
-    uint64_t ShiftAmtZeroMask = (~RHS.Zero).getZExtValue();
-    uint64_t ShiftAmtOneMask = RHS.One.getZExtValue();
+    uint64_t ShiftAmtZeroMask = (~RHS.Zero).zextOrTrunc(64).getZExtValue();
+    uint64_t ShiftAmtOneMask = RHS.One.zextOrTrunc(64).getZExtValue();
     assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
     Known.Zero.setAllBits();
     Known.One.setAllBits();
@@ -251,8 +251,8 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS) {
   // possible shifts.
   APInt MaxShiftAmount = RHS.getMaxValue();
   if (!LHS.isUnknown()) {
-    uint64_t ShiftAmtZeroMask = (~RHS.Zero).getZExtValue();
-    uint64_t ShiftAmtOneMask = RHS.One.getZExtValue();
+    uint64_t ShiftAmtZeroMask = (~RHS.Zero).zextOrTrunc(64).getZExtValue();
+    uint64_t ShiftAmtOneMask = RHS.One.zextOrTrunc(64).getZExtValue();
     assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
     Known.Zero.setAllBits();
     Known.One.setAllBits();
@@ -312,8 +312,8 @@ KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS) {
   // possible shifts.
   APInt MaxShiftAmount = RHS.getMaxValue();
   if (!LHS.isUnknown()) {
-    uint64_t ShiftAmtZeroMask = (~RHS.Zero).getZExtValue();
-    uint64_t ShiftAmtOneMask = RHS.One.getZExtValue();
+    uint64_t ShiftAmtZeroMask = (~RHS.Zero).zextOrTrunc(64).getZExtValue();
+    uint64_t ShiftAmtOneMask = RHS.One.zextOrTrunc(64).getZExtValue();
     assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
     Known.Zero.setAllBits();
     Known.One.setAllBits();

diff  --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 28f904e5b5e32..ece7e80147db8 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -352,6 +352,20 @@ TEST(KnownBitsTest, UnaryExhaustive) {
       [](const APInt &N) { return N * N; }, checkCorrectnessOnlyUnary);
 }
 
+TEST(KnownBitsTest, WideShifts) {
+  unsigned BitWidth = 128;
+  KnownBits Unknown(BitWidth);
+  KnownBits AllOnes = KnownBits::makeConstant(APInt::getAllOnes(BitWidth));
+
+  KnownBits ShlResult(BitWidth);
+  ShlResult.makeNegative();
+  EXPECT_EQ(KnownBits::shl(AllOnes, Unknown), ShlResult);
+  KnownBits LShrResult(BitWidth);
+  LShrResult.One.setBit(0);
+  EXPECT_EQ(KnownBits::lshr(AllOnes, Unknown), LShrResult);
+  EXPECT_EQ(KnownBits::ashr(AllOnes, Unknown), AllOnes);
+}
+
 TEST(KnownBitsTest, ICmpExhaustive) {
   unsigned Bits = 4;
   ForeachKnownBits(Bits, [&](const KnownBits &Known1) {


        


More information about the llvm-commits mailing list