[llvm] 6f75c66 - [KnownBits] Add fast-path for shl with unknown shift amount (NFC)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Fri May 26 04:57:46 PDT 2023


Author: Nikita Popov
Date: 2023-05-26T13:57:33+02:00
New Revision: 6f75c6681d47164072daea54dde2727c51b0e739

URL: https://github.com/llvm/llvm-project/commit/6f75c6681d47164072daea54dde2727c51b0e739
DIFF: https://github.com/llvm/llvm-project/commit/6f75c6681d47164072daea54dde2727c51b0e739.diff

LOG: [KnownBits] Add fast-path for shl with unknown shift amount (NFC)

We currently don't call into KnownBits::shl() from ValueTracking
if the shift amount is unknown. If we do try to do so, we get
significant compile-time regressions, because evaluating all 64
shift amounts is quite expensive, and mostly pointless in this case.
Add a fast-path for the case where the shift amount is the full
[0, BitWidth-1] range. This primarily requires a more accurate
estimate of the max shift amount, to avoid taking the fast-path in
too many cases.

Differential Revision: https://reviews.llvm.org/D151540

Added: 
    

Modified: 
    llvm/lib/Support/KnownBits.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 8491ac5846e78..c8e4a8981666a 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -164,41 +164,33 @@ KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
   return Flip(umax(Flip(LHS), Flip(RHS)));
 }
 
+static unsigned getMaxShiftAmount(const APInt &MaxValue, unsigned BitWidth) {
+  if (isPowerOf2_32(BitWidth))
+    return MaxValue.extractBitsAsZExtValue(Log2_32(BitWidth), 0);
+  // This is only an approximate upper bound.
+  return MaxValue.getLimitedValue(BitWidth - 1);
+}
+
 KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
                          bool NSW) {
   unsigned BitWidth = LHS.getBitWidth();
-  auto ShiftByConst = [&](const KnownBits &LHS,
-                          unsigned ShiftAmt) -> std::optional<KnownBits> {
+  auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
     KnownBits Known;
     bool ShiftedOutZero, ShiftedOutOne;
     Known.Zero = LHS.Zero.ushl_ov(ShiftAmt, ShiftedOutZero);
     Known.Zero.setLowBits(ShiftAmt);
     Known.One = LHS.One.ushl_ov(ShiftAmt, ShiftedOutOne);
 
-    if (NUW) {
-      if (ShiftedOutOne)
-        // One bit has been shifted out.
-        return std::nullopt;
-      if (ShiftAmt != 0)
+    // All cases returning poison have been handled by MaxShiftAmount already.
+    if (NSW) {
+      if (NUW && ShiftAmt != 0)
         // NUW means we can assume anything shifted out was a zero.
         ShiftedOutZero = true;
-    }
 
-    if (NSW) {
-      if (ShiftedOutZero && ShiftedOutOne)
-        // Both zeros and ones have been shifted out.
-        return std::nullopt;
-      if (ShiftedOutZero) {
-        if (Known.isNegative())
-          // Zero bit has been shifted out, but result sign is negative.
-          return std::nullopt;
+      if (ShiftedOutZero)
         Known.makeNonNegative();
-      } else if (ShiftedOutOne) {
-        if (Known.isNonNegative())
-          // One bit has been shifted out, but result sign is non-negative.
-          return std::nullopt;
+      else if (ShiftedOutOne)
         Known.makeNegative();
-      }
     }
     return Known;
   };
@@ -218,8 +210,34 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
     return Known;
   }
 
+  // Determine maximum shift amount, taking NUW/NSW flags into account.
+  APInt MaxValue = RHS.getMaxValue();
+  unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
+  if (NUW && NSW)
+    MaxShiftAmount = std::min(MaxShiftAmount, LHS.countMaxLeadingZeros() - 1);
+  if (NUW)
+    MaxShiftAmount = std::min(MaxShiftAmount, LHS.countMaxLeadingZeros());
+  if (NSW)
+    MaxShiftAmount = std::min(
+        MaxShiftAmount,
+        std::max(LHS.countMaxLeadingZeros(), LHS.countMaxLeadingOnes()) - 1);
+
+  // Fast path for common case where the shift amount is unknown.
+  if (MinShiftAmount == 0 && MaxShiftAmount == BitWidth - 1 &&
+      isPowerOf2_32(BitWidth)) {
+    Known.Zero.setLowBits(LHS.countMinTrailingZeros());
+    if (LHS.isAllOnes())
+      Known.One.setSignBit();
+    if (NSW) {
+      if (LHS.isNonNegative())
+        Known.makeNonNegative();
+      if (LHS.isNegative())
+        Known.makeNegative();
+    }
+    return Known;
+  }
+
   // Find the common bits from all possible shifts.
-  unsigned MaxShiftAmount = RHS.getMaxValue().getLimitedValue(BitWidth - 1);
   unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
   unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
   Known.Zero.setAllBits();
@@ -230,11 +248,7 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
     if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
         (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
       continue;
-    auto Res = ShiftByConst(LHS, ShiftAmt);
-    if (!Res)
-      // All larger shift amounts will overflow as well.
-      break;
-    Known = Known.intersectWith(*Res);
+    Known = Known.intersectWith(ShiftByConst(LHS, ShiftAmt));
     if (Known.isUnknown())
       break;
   }


        


More information about the llvm-commits mailing list