[llvm] 9d73a8b - [KnownBits] Make shl/lshr/ashr implementations optimal

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Tue May 16 00:44:41 PDT 2023


Author: Nikita Popov
Date: 2023-05-16T09:44:26+02:00
New Revision: 9d73a8bdc66496b673c11e991fd9cf0cba0a1bff

URL: https://github.com/llvm/llvm-project/commit/9d73a8bdc66496b673c11e991fd9cf0cba0a1bff
DIFF: https://github.com/llvm/llvm-project/commit/9d73a8bdc66496b673c11e991fd9cf0cba0a1bff.diff

LOG: [KnownBits] Make shl/lshr/ashr implementations optimal

The shift implementations were suboptimal in the case where the
maximum shift amount was >= the bitwidth. In that case, we should
still run the usual code with the shift amount clamped to BitWidth-1,
rather than giving up entirely.
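
To illustrate, here is a minimal sketch (a hypothetical standalone
test, not part of this patch, using the KnownBits API from
llvm/Support/KnownBits.h) of the case being improved: the maximum
possible shift amount is out of range, but the minimum is not, so the
in-range amounts can still be enumerated:

    #include "llvm/Support/KnownBits.h"
    #include <cassert>
    using namespace llvm;

    int main() {
      unsigned BitWidth = 8;
      // LHS is the constant 4 (0b00000100).
      KnownBits LHS = KnownBits::makeConstant(APInt(BitWidth, 4));
      // RHS: bit 0 known one, bit 4 unknown, all other bits known
      // zero, so the shift amount is either 1 or 17 (>= BitWidth).
      KnownBits RHS(BitWidth);
      RHS.One = APInt(BitWidth, 0b00001);
      RHS.Zero = ~APInt(BitWidth, 0b10001);
      KnownBits Res = KnownBits::shl(LHS, RHS);
      // Shifting by 17 is poison, so the only relevant amount is 1
      // and the result is known to be exactly 8. Previously, the
      // out-of-range maximum made the implementation give up.
      assert(Res.isConstant() && Res.getConstant() == 8);
    }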

Additionally, there was an implementation bug in the shl/lshr
implementations: the known-zero bits introduced by each individual
shift amount (the low ShiftAmt bits for shl, the high ShiftAmt bits
for lshr) were not being set. I think that after these changes we'll
be able to drop some of the code in ValueTracking which *also*
evaluates all possible shift amounts and has been papering over this
issue.
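
Concretely, shifting left by ShiftAmt always zeros the low ShiftAmt
bits of the result, no matter what LHS was (and lshr likewise zeros
the high bits). A minimal sketch of the fixed per-shift-amount
computation for shl, where specificShl is a hypothetical standalone
helper:

    #include "llvm/Support/KnownBits.h"
    using namespace llvm;

    // Known bits of (LHS << ShiftAmt) for one specific in-range
    // shift amount.
    static KnownBits specificShl(const KnownBits &LHS,
                                 unsigned ShiftAmt) {
      KnownBits SpecificShift(LHS.getBitWidth());
      SpecificShift.Zero = LHS.Zero << ShiftAmt;
      // The fix: the low ShiftAmt bits are always zero.
      SpecificShift.Zero.setLowBits(ShiftAmt);
      SpecificShift.One = LHS.One << ShiftAmt;
      return SpecificShift;
    }

Without the setLowBits() call, the shift clears the low bits of the
Zero mask itself, so result bits that are in fact always zero were
reported as unknown.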

For the "all poison" case I've opted to return an unknown value for
now. It would be better to return zero, but this has fairly
substantial test fallout, so I figured it's best to not mix it into
this change. (The "correct" return value would be a conflict, but
since many of our APIs assert conflict-freedom, that's probably not a
value we should actually return.)
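
For example, a minimal sketch (again a hypothetical test) of the
chosen behavior: when every possible shift amount is >= the bitwidth,
the result comes back fully unknown rather than zero or a conflict:

    #include "llvm/Support/KnownBits.h"
    #include <cassert>
    using namespace llvm;

    int main() {
      unsigned BitWidth = 8;
      KnownBits LHS = KnownBits::makeConstant(APInt(BitWidth, 42));
      // The shift amount is the constant 9 (>= BitWidth), so every
      // possible shift is poison.
      KnownBits RHS = KnownBits::makeConstant(APInt(BitWidth, 9));
      KnownBits Res = KnownBits::lshr(LHS, RHS);
      assert(Res.isUnknown()); // fully unknown: not zero, no conflict
    }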

Differential Revision: https://reviews.llvm.org/D150587

Added: 
    

Modified: 
    llvm/lib/Support/KnownBits.cpp
    llvm/test/CodeGen/AMDGPU/amdgpu.private-memory.ll
    llvm/test/Transforms/InstCombine/not-add.ll
    llvm/unittests/Support/KnownBitsTest.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index ad6e1c8b7003b..3377dd346da82 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -182,24 +182,26 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) {
   // No matter the shift amount, the trailing zeros will stay zero.
   unsigned MinTrailingZeros = LHS.countMinTrailingZeros();
 
-  // Minimum shift amount low bits are known zero.
   APInt MinShiftAmount = RHS.getMinValue();
-  if (MinShiftAmount.ult(BitWidth)) {
-    MinTrailingZeros += MinShiftAmount.getZExtValue();
-    MinTrailingZeros = std::min(MinTrailingZeros, BitWidth);
-  }
+  if (MinShiftAmount.uge(BitWidth))
+    // Always poison. Return unknown because we don't like returning conflict.
+    return Known;
+
+  // Minimum shift amount low bits are known zero.
+  MinTrailingZeros += MinShiftAmount.getZExtValue();
+  MinTrailingZeros = std::min(MinTrailingZeros, BitWidth);
 
   // If the maximum shift is in range, then find the common bits from all
   // possible shifts.
   APInt MaxShiftAmount = RHS.getMaxValue();
-  if (MaxShiftAmount.ult(BitWidth) && !LHS.isUnknown()) {
+  if (!LHS.isUnknown()) {
     uint64_t ShiftAmtZeroMask = (~RHS.Zero).getZExtValue();
     uint64_t ShiftAmtOneMask = RHS.One.getZExtValue();
     assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
     Known.Zero.setAllBits();
     Known.One.setAllBits();
     for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
-                  MaxShiftAmt = MaxShiftAmount.getZExtValue();
+                  MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1);
          ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
       // Skip if the shift amount is impossible.
       if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
@@ -207,6 +209,7 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) {
         continue;
       KnownBits SpecificShift;
       SpecificShift.Zero = LHS.Zero << ShiftAmt;
+      SpecificShift.Zero.setLowBits(ShiftAmt);
       SpecificShift.One = LHS.One << ShiftAmt;
       Known = KnownBits::commonBits(Known, SpecificShift);
       if (Known.isUnknown())
@@ -237,22 +240,24 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS) {
 
   // Minimum shift amount high bits are known zero.
   APInt MinShiftAmount = RHS.getMinValue();
-  if (MinShiftAmount.ult(BitWidth)) {
-    MinLeadingZeros += MinShiftAmount.getZExtValue();
-    MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
-  }
+  if (MinShiftAmount.uge(BitWidth))
+    // Always poison. Return unknown because we don't like returning conflict.
+    return Known;
+
+  MinLeadingZeros += MinShiftAmount.getZExtValue();
+  MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
 
   // If the maximum shift is in range, then find the common bits from all
   // possible shifts.
   APInt MaxShiftAmount = RHS.getMaxValue();
-  if (MaxShiftAmount.ult(BitWidth) && !LHS.isUnknown()) {
+  if (!LHS.isUnknown()) {
     uint64_t ShiftAmtZeroMask = (~RHS.Zero).getZExtValue();
     uint64_t ShiftAmtOneMask = RHS.One.getZExtValue();
     assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
     Known.Zero.setAllBits();
     Known.One.setAllBits();
     for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
-                  MaxShiftAmt = MaxShiftAmount.getZExtValue();
+                  MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1);
          ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
       // Skip if the shift amount is impossible.
       if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
@@ -260,6 +265,7 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS) {
         continue;
       KnownBits SpecificShift = LHS;
       SpecificShift.Zero.lshrInPlace(ShiftAmt);
+      SpecificShift.Zero.setHighBits(ShiftAmt);
       SpecificShift.One.lshrInPlace(ShiftAmt);
       Known = KnownBits::commonBits(Known, SpecificShift);
       if (Known.isUnknown())
@@ -289,28 +295,30 @@ KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS) {
 
   // Minimum shift amount high bits are known sign bits.
   APInt MinShiftAmount = RHS.getMinValue();
-  if (MinShiftAmount.ult(BitWidth)) {
-    if (MinLeadingZeros) {
-      MinLeadingZeros += MinShiftAmount.getZExtValue();
-      MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
-    }
-    if (MinLeadingOnes) {
-      MinLeadingOnes += MinShiftAmount.getZExtValue();
-      MinLeadingOnes = std::min(MinLeadingOnes, BitWidth);
-    }
+  if (MinShiftAmount.uge(BitWidth))
+    // Always poison. Return unknown because we don't like returning conflict.
+    return Known;
+
+  if (MinLeadingZeros) {
+    MinLeadingZeros += MinShiftAmount.getZExtValue();
+    MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
+  }
+  if (MinLeadingOnes) {
+    MinLeadingOnes += MinShiftAmount.getZExtValue();
+    MinLeadingOnes = std::min(MinLeadingOnes, BitWidth);
   }
 
   // If the maximum shift is in range, then find the common bits from all
   // possible shifts.
   APInt MaxShiftAmount = RHS.getMaxValue();
-  if (MaxShiftAmount.ult(BitWidth) && !LHS.isUnknown()) {
+  if (!LHS.isUnknown()) {
     uint64_t ShiftAmtZeroMask = (~RHS.Zero).getZExtValue();
     uint64_t ShiftAmtOneMask = RHS.One.getZExtValue();
     assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
     Known.Zero.setAllBits();
     Known.One.setAllBits();
     for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
-                  MaxShiftAmt = MaxShiftAmount.getZExtValue();
+                  MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1);
          ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
       // Skip if the shift amount is impossible.
       if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||

diff --git a/llvm/test/CodeGen/AMDGPU/amdgpu.private-memory.ll b/llvm/test/CodeGen/AMDGPU/amdgpu.private-memory.ll
index c734639878ce3..452459a41fcff 100644
--- a/llvm/test/CodeGen/AMDGPU/amdgpu.private-memory.ll
+++ b/llvm/test/CodeGen/AMDGPU/amdgpu.private-memory.ll
@@ -221,7 +221,7 @@ for.end:
 ; SI-PROMOTE-VECT: s_load_dword [[IDX:s[0-9]+]]
 ; SI-PROMOTE-VECT: s_lshl_b32 [[SCALED_IDX:s[0-9]+]], [[IDX]], 4
 ; SI-PROMOTE-VECT: s_lshr_b32 [[SREG:s[0-9]+]], 0x10000, [[SCALED_IDX]]
-; SI-PROMOTE-VECT: s_and_b32 s{{[0-9]+}}, [[SREG]], 0xffff
+; SI-PROMOTE-VECT: s_and_b32 s{{[0-9]+}}, [[SREG]], 1
 define amdgpu_kernel void @short_array(ptr addrspace(1) %out, i32 %index) #0 {
 entry:
   %0 = alloca [2 x i16], addrspace(5)

diff --git a/llvm/test/Transforms/InstCombine/not-add.ll b/llvm/test/Transforms/InstCombine/not-add.ll
index 48cd4f537a8cf..03f4f445fa898 100644
--- a/llvm/test/Transforms/InstCombine/not-add.ll
+++ b/llvm/test/Transforms/InstCombine/not-add.ll
@@ -172,7 +172,7 @@ define void @pr50370(i32 %x) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[X:%.*]], 1
 ; CHECK-NEXT:    [[B15:%.*]] = srem i32 ashr (i32 65536, i32 or (i32 zext (i1 icmp eq (ptr @g, ptr null) to i32), i32 65537)), [[XOR]]
-; CHECK-NEXT:    [[B12:%.*]] = add nuw nsw i32 [[B15]], ashr (i32 65536, i32 or (i32 zext (i1 icmp eq (ptr @g, ptr null) to i32), i32 65537))
+; CHECK-NEXT:    [[B12:%.*]] = add nsw i32 [[B15]], ashr (i32 65536, i32 or (i32 zext (i1 icmp eq (ptr @g, ptr null) to i32), i32 65537))
 ; CHECK-NEXT:    [[B:%.*]] = xor i32 [[B12]], -1
 ; CHECK-NEXT:    store i32 [[B]], ptr undef, align 4
 ; CHECK-NEXT:    ret void

diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index e8daae3685d74..28f904e5b5e32 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -270,7 +270,6 @@ TEST(KnownBitsTest, BinaryExhaustive) {
       },
       checkCorrectnessOnlyBinary);
 
-  // TODO: Make optimal for non-constant cases.
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::shl(Known1, Known2);
@@ -279,9 +278,6 @@ TEST(KnownBitsTest, BinaryExhaustive) {
         if (N2.uge(N2.getBitWidth()))
           return std::nullopt;
         return N1.shl(N2);
-      },
-      [](const KnownBits &, const KnownBits &Known) {
-        return Known.isConstant();
       });
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
@@ -291,9 +287,6 @@ TEST(KnownBitsTest, BinaryExhaustive) {
         if (N2.uge(N2.getBitWidth()))
           return std::nullopt;
         return N1.lshr(N2);
-      },
-      [](const KnownBits &, const KnownBits &Known) {
-        return Known.isConstant();
       });
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
@@ -303,9 +296,6 @@ TEST(KnownBitsTest, BinaryExhaustive) {
         if (N2.uge(N2.getBitWidth()))
           return std::nullopt;
         return N1.ashr(N2);
-      },
-      [](const KnownBits &, const KnownBits &Known) {
-        return Known.isConstant();
       });
 
   testBinaryOpExhaustive(

