[llvm] 2b1678c - [KnownBits] Simplify shl. NFCI.

Jay Foad via llvm-commits llvm-commits at lists.llvm.org
Thu May 25 08:05:41 PDT 2023


Author: Jay Foad
Date: 2023-05-25T16:05:32+01:00
New Revision: 2b1678cd06e6e5a770f5511abaf72a511236087e

URL: https://github.com/llvm/llvm-project/commit/2b1678cd06e6e5a770f5511abaf72a511236087e
DIFF: https://github.com/llvm/llvm-project/commit/2b1678cd06e6e5a770f5511abaf72a511236087e.diff

LOG: [KnownBits] Simplify shl. NFCI.

Differential Revision: https://reviews.llvm.org/D151421

Added: 
    

Modified: 
    llvm/lib/Support/KnownBits.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 8bb236baf4ae..8491ac5846e7 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -168,30 +168,34 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
                          bool NSW) {
   unsigned BitWidth = LHS.getBitWidth();
   auto ShiftByConst = [&](const KnownBits &LHS,
-                          uint64_t ShiftAmt) -> std::optional<KnownBits> {
+                          unsigned ShiftAmt) -> std::optional<KnownBits> {
     KnownBits Known;
-    Known.Zero = LHS.Zero << ShiftAmt;
+    bool ShiftedOutZero, ShiftedOutOne;
+    Known.Zero = LHS.Zero.ushl_ov(ShiftAmt, ShiftedOutZero);
     Known.Zero.setLowBits(ShiftAmt);
-    Known.One = LHS.One << ShiftAmt;
-    if ((!NUW && !NSW) || ShiftAmt == 0)
-      return Known;
+    Known.One = LHS.One.ushl_ov(ShiftAmt, ShiftedOutOne);
+
+    if (NUW) {
+      if (ShiftedOutOne)
+        // One bit has been shifted out.
+        return std::nullopt;
+      if (ShiftAmt != 0)
+        // NUW means we can assume anything shifted out was a zero.
+        ShiftedOutZero = true;
+    }
 
-    KnownBits ShiftedOutBits = LHS.extractBits(ShiftAmt, BitWidth - ShiftAmt);
-    if (NUW && !ShiftedOutBits.One.isZero())
-      // One bit has been shifted out.
-      return std::nullopt;
     if (NSW) {
-      if (!ShiftedOutBits.Zero.isZero() && !ShiftedOutBits.One.isZero())
+      if (ShiftedOutZero && ShiftedOutOne)
         // Both zeros and ones have been shifted out.
         return std::nullopt;
-      if (NUW || !ShiftedOutBits.Zero.isZero()) {
+      if (ShiftedOutZero) {
         if (Known.isNegative())
           // Zero bit has been shifted out, but result sign is negative.
           return std::nullopt;
         Known.makeNonNegative();
-      } else if (!ShiftedOutBits.One.isZero()) {
+      } else if (ShiftedOutOne) {
         if (Known.isNonNegative())
-          // One bit has been shifted out, but result sign is negative.
+          // One bit has been shifted out, but result sign is non-negative.
           return std::nullopt;
         Known.makeNegative();
       }
@@ -199,47 +203,31 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
     return Known;
   };
 
-  // If the shift amount is a valid constant then transform LHS directly.
-  if (RHS.isConstant() && RHS.getConstant().ult(BitWidth)) {
-    if (auto Res = ShiftByConst(LHS, RHS.getConstant().getZExtValue()))
-      return *Res;
-    KnownBits Known(BitWidth);
-    Known.setAllZero();
-    return Known;
-  }
-
+  // Fast path for a common case when LHS is completely unknown.
   KnownBits Known(BitWidth);
-  APInt MinShiftAmount = RHS.getMinValue();
-  if (MinShiftAmount.uge(BitWidth)) {
-    // Always poison. Return zero because we don't like returning conflict.
-    Known.setAllZero();
-    return Known;
-  }
-
+  unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth);
   if (LHS.isUnknown()) {
-    // No matter the shift amount, the trailing zeros will stay zero.
-    unsigned MinTrailingZeros = LHS.countMinTrailingZeros();
-    // Minimum shift amount low bits are known zero.
-    MinTrailingZeros += MinShiftAmount.getZExtValue();
-    MinTrailingZeros = std::min(MinTrailingZeros, BitWidth);
-    Known.Zero.setLowBits(MinTrailingZeros);
-    if (NUW && NSW && !MinShiftAmount.isZero())
+    if (MinShiftAmount == BitWidth) {
+      // Always poison. Return zero because we don't like returning conflict.
+      Known.setAllZero();
+      return Known;
+    }
+    Known.Zero.setLowBits(MinShiftAmount);
+    if (NUW && NSW && MinShiftAmount != 0)
       Known.makeNonNegative();
     return Known;
   }
 
   // Find the common bits from all possible shifts.
-  APInt MaxShiftAmount = RHS.getMaxValue();
-  uint64_t ShiftAmtZeroMask = (~RHS.Zero).zextOrTrunc(64).getZExtValue();
-  uint64_t ShiftAmtOneMask = RHS.One.zextOrTrunc(64).getZExtValue();
-  assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
+  unsigned MaxShiftAmount = RHS.getMaxValue().getLimitedValue(BitWidth - 1);
+  unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
+  unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
   Known.Zero.setAllBits();
   Known.One.setAllBits();
-  for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
-                MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1);
-       ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
+  for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount;
+       ++ShiftAmt) {
     // Skip if the shift amount is impossible.
-    if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
+    if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
         (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
       continue;
     auto Res = ShiftByConst(LHS, ShiftAmt);
@@ -252,10 +240,8 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
   }
 
   // All shift amounts may result in poison.
-  if (Known.hasConflict()) {
-    assert((NUW || NSW) && "Can only happen with nowrap flags");
+  if (Known.hasConflict())
     Known.setAllZero();
-  }
   return Known;
 }
 


        


More information about the llvm-commits mailing list