[llvm] 2d48a77 - [KnownBits] Use early return for unknown LHS for shifts (NFC)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Wed May 24 02:04:32 PDT 2023


Author: Nikita Popov
Date: 2023-05-24T11:02:16+02:00
New Revision: 2d48a771fc00681414c54ea3936bff91f3b253c4

URL: https://github.com/llvm/llvm-project/commit/2d48a771fc00681414c54ea3936bff91f3b253c4
DIFF: https://github.com/llvm/llvm-project/commit/2d48a771fc00681414c54ea3936bff91f3b253c4.diff

LOG: [KnownBits] Use early return for unknown LHS for shifts (NFC)

Make it clear that the leading/trailing zeros handling is only
relevant for the unknown LHS case, which is a fast path to avoid
the full shift amount loop in cases where it would not produce
better results.

Added: 
    

Modified: 
    llvm/lib/Support/KnownBits.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 13d79a389fd6..c665f8a30597 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -179,9 +179,6 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) {
     return Known;
   }
 
-  // No matter the shift amount, the trailing zeros will stay zero.
-  unsigned MinTrailingZeros = LHS.countMinTrailingZeros();
-
   APInt MinShiftAmount = RHS.getMinValue();
   if (MinShiftAmount.uge(BitWidth)) {
     // Always poison. Return zero because we don't like returning conflict.
@@ -189,37 +186,39 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) {
     return Known;
   }
 
-  // Minimum shift amount low bits are known zero.
-  MinTrailingZeros += MinShiftAmount.getZExtValue();
-  MinTrailingZeros = std::min(MinTrailingZeros, BitWidth);
+  if (LHS.isUnknown()) {
+    // No matter the shift amount, the trailing zeros will stay zero.
+    unsigned MinTrailingZeros = LHS.countMinTrailingZeros();
+    // Minimum shift amount low bits are known zero.
+    MinTrailingZeros += MinShiftAmount.getZExtValue();
+    MinTrailingZeros = std::min(MinTrailingZeros, BitWidth);
+    Known.Zero.setLowBits(MinTrailingZeros);
+    return Known;
+  }
 
-  // If the maximum shift is in range, then find the common bits from all
-  // possible shifts.
+  // Find the common bits from all possible shifts.
   APInt MaxShiftAmount = RHS.getMaxValue();
-  if (!LHS.isUnknown()) {
-    uint64_t ShiftAmtZeroMask = (~RHS.Zero).zextOrTrunc(64).getZExtValue();
-    uint64_t ShiftAmtOneMask = RHS.One.zextOrTrunc(64).getZExtValue();
-    assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
-    Known.Zero.setAllBits();
-    Known.One.setAllBits();
-    for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
-                  MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1);
-         ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
-      // Skip if the shift amount is impossible.
-      if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
-          (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
-        continue;
-      KnownBits SpecificShift;
-      SpecificShift.Zero = LHS.Zero << ShiftAmt;
-      SpecificShift.Zero.setLowBits(ShiftAmt);
-      SpecificShift.One = LHS.One << ShiftAmt;
-      Known = Known.intersectWith(SpecificShift);
-      if (Known.isUnknown())
-        break;
-    }
+  uint64_t ShiftAmtZeroMask = (~RHS.Zero).zextOrTrunc(64).getZExtValue();
+  uint64_t ShiftAmtOneMask = RHS.One.zextOrTrunc(64).getZExtValue();
+  assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
+  Known.Zero.setAllBits();
+  Known.One.setAllBits();
+  for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
+                MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1);
+       ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
+    // Skip if the shift amount is impossible.
+    if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
+        (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
+      continue;
+    KnownBits SpecificShift;
+    SpecificShift.Zero = LHS.Zero << ShiftAmt;
+    SpecificShift.Zero.setLowBits(ShiftAmt);
+    SpecificShift.One = LHS.One << ShiftAmt;
+    Known = Known.intersectWith(SpecificShift);
+    if (Known.isUnknown())
+      break;
   }
 
-  Known.Zero.setLowBits(MinTrailingZeros);
   return Known;
 }
 
@@ -237,9 +236,6 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS) {
     return Known;
   }
 
-  // No matter the shift amount, the leading zeros will stay zero.
-  unsigned MinLeadingZeros = LHS.countMinLeadingZeros();
-
   // Minimum shift amount high bits are known zero.
   APInt MinShiftAmount = RHS.getMinValue();
   if (MinShiftAmount.uge(BitWidth)) {
@@ -248,36 +244,38 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS) {
     return Known;
   }
 
-  MinLeadingZeros += MinShiftAmount.getZExtValue();
-  MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
+  if (LHS.isUnknown()) {
+    // No matter the shift amount, the leading zeros will stay zero.
+    unsigned MinLeadingZeros = LHS.countMinLeadingZeros();
+    MinLeadingZeros += MinShiftAmount.getZExtValue();
+    MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
+    Known.Zero.setHighBits(MinLeadingZeros);
+    return Known;
+  }
 
-  // If the maximum shift is in range, then find the common bits from all
-  // possible shifts.
+  // Find the common bits from all possible shifts.
   APInt MaxShiftAmount = RHS.getMaxValue();
-  if (!LHS.isUnknown()) {
-    uint64_t ShiftAmtZeroMask = (~RHS.Zero).zextOrTrunc(64).getZExtValue();
-    uint64_t ShiftAmtOneMask = RHS.One.zextOrTrunc(64).getZExtValue();
-    assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
-    Known.Zero.setAllBits();
-    Known.One.setAllBits();
-    for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
-                  MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1);
-         ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
-      // Skip if the shift amount is impossible.
-      if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
-          (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
-        continue;
-      KnownBits SpecificShift = LHS;
-      SpecificShift.Zero.lshrInPlace(ShiftAmt);
-      SpecificShift.Zero.setHighBits(ShiftAmt);
-      SpecificShift.One.lshrInPlace(ShiftAmt);
-      Known = Known.intersectWith(SpecificShift);
-      if (Known.isUnknown())
-        break;
-    }
+  uint64_t ShiftAmtZeroMask = (~RHS.Zero).zextOrTrunc(64).getZExtValue();
+  uint64_t ShiftAmtOneMask = RHS.One.zextOrTrunc(64).getZExtValue();
+  assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
+  Known.Zero.setAllBits();
+  Known.One.setAllBits();
+  for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
+                MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1);
+       ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
+    // Skip if the shift amount is impossible.
+    if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
+        (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
+      continue;
+    KnownBits SpecificShift = LHS;
+    SpecificShift.Zero.lshrInPlace(ShiftAmt);
+    SpecificShift.Zero.setHighBits(ShiftAmt);
+    SpecificShift.One.lshrInPlace(ShiftAmt);
+    Known = Known.intersectWith(SpecificShift);
+    if (Known.isUnknown())
+      break;
   }
 
-  Known.Zero.setHighBits(MinLeadingZeros);
   return Known;
 }
 
@@ -293,10 +291,6 @@ KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS) {
     return Known;
   }
 
-  // No matter the shift amount, the leading sign bits will stay.
-  unsigned MinLeadingZeros = LHS.countMinLeadingZeros();
-  unsigned MinLeadingOnes = LHS.countMinLeadingOnes();
-
   // Minimum shift amount high bits are known sign bits.
   APInt MinShiftAmount = RHS.getMinValue();
   if (MinShiftAmount.uge(BitWidth)) {
@@ -305,42 +299,45 @@ KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS) {
     return Known;
   }
 
-  if (MinLeadingZeros) {
-    MinLeadingZeros += MinShiftAmount.getZExtValue();
-    MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
-  }
-  if (MinLeadingOnes) {
-    MinLeadingOnes += MinShiftAmount.getZExtValue();
-    MinLeadingOnes = std::min(MinLeadingOnes, BitWidth);
+  if (LHS.isUnknown()) {
+    // No matter the shift amount, the leading sign bits will stay.
+    unsigned MinLeadingZeros = LHS.countMinLeadingZeros();
+    unsigned MinLeadingOnes = LHS.countMinLeadingOnes();
+    if (MinLeadingZeros) {
+      MinLeadingZeros += MinShiftAmount.getZExtValue();
+      MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
+    }
+    if (MinLeadingOnes) {
+      MinLeadingOnes += MinShiftAmount.getZExtValue();
+      MinLeadingOnes = std::min(MinLeadingOnes, BitWidth);
+    }
+    Known.Zero.setHighBits(MinLeadingZeros);
+    Known.One.setHighBits(MinLeadingOnes);
+    return Known;
   }
 
-  // If the maximum shift is in range, then find the common bits from all
-  // possible shifts.
+  // Find the common bits from all possible shifts.
   APInt MaxShiftAmount = RHS.getMaxValue();
-  if (!LHS.isUnknown()) {
-    uint64_t ShiftAmtZeroMask = (~RHS.Zero).zextOrTrunc(64).getZExtValue();
-    uint64_t ShiftAmtOneMask = RHS.One.zextOrTrunc(64).getZExtValue();
-    assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
-    Known.Zero.setAllBits();
-    Known.One.setAllBits();
-    for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
-                  MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1);
-         ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
-      // Skip if the shift amount is impossible.
-      if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
-          (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
-        continue;
-      KnownBits SpecificShift = LHS;
-      SpecificShift.Zero.ashrInPlace(ShiftAmt);
-      SpecificShift.One.ashrInPlace(ShiftAmt);
-      Known = Known.intersectWith(SpecificShift);
-      if (Known.isUnknown())
-        break;
-    }
+  uint64_t ShiftAmtZeroMask = (~RHS.Zero).zextOrTrunc(64).getZExtValue();
+  uint64_t ShiftAmtOneMask = RHS.One.zextOrTrunc(64).getZExtValue();
+  assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
+  Known.Zero.setAllBits();
+  Known.One.setAllBits();
+  for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
+                MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1);
+       ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
+    // Skip if the shift amount is impossible.
+    if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
+        (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
+      continue;
+    KnownBits SpecificShift = LHS;
+    SpecificShift.Zero.ashrInPlace(ShiftAmt);
+    SpecificShift.One.ashrInPlace(ShiftAmt);
+    Known = Known.intersectWith(SpecificShift);
+    if (Known.isUnknown())
+      break;
   }
 
-  Known.Zero.setHighBits(MinLeadingZeros);
-  Known.One.setHighBits(MinLeadingOnes);
   return Known;
 }
 


        


More information about the llvm-commits mailing list