[llvm] 2d48a77 - [KnownBits] Use early return for unknown LHS for shifts (NFC)
Nikita Popov via llvm-commits
llvm-commits at lists.llvm.org
Wed May 24 02:04:32 PDT 2023
Author: Nikita Popov
Date: 2023-05-24T11:02:16+02:00
New Revision: 2d48a771fc00681414c54ea3936bff91f3b253c4
URL: https://github.com/llvm/llvm-project/commit/2d48a771fc00681414c54ea3936bff91f3b253c4
DIFF: https://github.com/llvm/llvm-project/commit/2d48a771fc00681414c54ea3936bff91f3b253c4.diff
LOG: [KnownBits] Use early return for unknown LHS for shifts (NFC)
Make it clear that the leading/trailing zeros handling is only
relevant for the unknown-LHS case. That case is a fast path that
avoids the full shift-amount loop when the loop would not produce
a better result.
Added:
Modified:
llvm/lib/Support/KnownBits.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 13d79a389fd6..c665f8a30597 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -179,9 +179,6 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) {
return Known;
}
- // No matter the shift amount, the trailing zeros will stay zero.
- unsigned MinTrailingZeros = LHS.countMinTrailingZeros();
-
APInt MinShiftAmount = RHS.getMinValue();
if (MinShiftAmount.uge(BitWidth)) {
// Always poison. Return zero because we don't like returning conflict.
@@ -189,37 +186,39 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) {
return Known;
}
- // Minimum shift amount low bits are known zero.
- MinTrailingZeros += MinShiftAmount.getZExtValue();
- MinTrailingZeros = std::min(MinTrailingZeros, BitWidth);
+ if (LHS.isUnknown()) {
+ // No matter the shift amount, the trailing zeros will stay zero.
+ unsigned MinTrailingZeros = LHS.countMinTrailingZeros();
+ // Minimum shift amount low bits are known zero.
+ MinTrailingZeros += MinShiftAmount.getZExtValue();
+ MinTrailingZeros = std::min(MinTrailingZeros, BitWidth);
+ Known.Zero.setLowBits(MinTrailingZeros);
+ return Known;
+ }
- // If the maximum shift is in range, then find the common bits from all
- // possible shifts.
+ // Find the common bits from all possible shifts.
APInt MaxShiftAmount = RHS.getMaxValue();
- if (!LHS.isUnknown()) {
- uint64_t ShiftAmtZeroMask = (~RHS.Zero).zextOrTrunc(64).getZExtValue();
- uint64_t ShiftAmtOneMask = RHS.One.zextOrTrunc(64).getZExtValue();
- assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
- Known.Zero.setAllBits();
- Known.One.setAllBits();
- for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
- MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1);
- ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
- // Skip if the shift amount is impossible.
- if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
- (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
- continue;
- KnownBits SpecificShift;
- SpecificShift.Zero = LHS.Zero << ShiftAmt;
- SpecificShift.Zero.setLowBits(ShiftAmt);
- SpecificShift.One = LHS.One << ShiftAmt;
- Known = Known.intersectWith(SpecificShift);
- if (Known.isUnknown())
- break;
- }
+ uint64_t ShiftAmtZeroMask = (~RHS.Zero).zextOrTrunc(64).getZExtValue();
+ uint64_t ShiftAmtOneMask = RHS.One.zextOrTrunc(64).getZExtValue();
+ assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
+ Known.Zero.setAllBits();
+ Known.One.setAllBits();
+ for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
+ MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1);
+ ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
+ // Skip if the shift amount is impossible.
+ if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
+ (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
+ continue;
+ KnownBits SpecificShift;
+ SpecificShift.Zero = LHS.Zero << ShiftAmt;
+ SpecificShift.Zero.setLowBits(ShiftAmt);
+ SpecificShift.One = LHS.One << ShiftAmt;
+ Known = Known.intersectWith(SpecificShift);
+ if (Known.isUnknown())
+ break;
}
- Known.Zero.setLowBits(MinTrailingZeros);
return Known;
}
@@ -237,9 +236,6 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS) {
return Known;
}
- // No matter the shift amount, the leading zeros will stay zero.
- unsigned MinLeadingZeros = LHS.countMinLeadingZeros();
-
// Minimum shift amount high bits are known zero.
APInt MinShiftAmount = RHS.getMinValue();
if (MinShiftAmount.uge(BitWidth)) {
@@ -248,36 +244,38 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS) {
return Known;
}
- MinLeadingZeros += MinShiftAmount.getZExtValue();
- MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
+ if (LHS.isUnknown()) {
+ // No matter the shift amount, the leading zeros will stay zero.
+ unsigned MinLeadingZeros = LHS.countMinLeadingZeros();
+ MinLeadingZeros += MinShiftAmount.getZExtValue();
+ MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
+ Known.Zero.setHighBits(MinLeadingZeros);
+ return Known;
+ }
- // If the maximum shift is in range, then find the common bits from all
- // possible shifts.
+ // Find the common bits from all possible shifts.
APInt MaxShiftAmount = RHS.getMaxValue();
- if (!LHS.isUnknown()) {
- uint64_t ShiftAmtZeroMask = (~RHS.Zero).zextOrTrunc(64).getZExtValue();
- uint64_t ShiftAmtOneMask = RHS.One.zextOrTrunc(64).getZExtValue();
- assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
- Known.Zero.setAllBits();
- Known.One.setAllBits();
- for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
- MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1);
- ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
- // Skip if the shift amount is impossible.
- if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
- (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
- continue;
- KnownBits SpecificShift = LHS;
- SpecificShift.Zero.lshrInPlace(ShiftAmt);
- SpecificShift.Zero.setHighBits(ShiftAmt);
- SpecificShift.One.lshrInPlace(ShiftAmt);
- Known = Known.intersectWith(SpecificShift);
- if (Known.isUnknown())
- break;
- }
+ uint64_t ShiftAmtZeroMask = (~RHS.Zero).zextOrTrunc(64).getZExtValue();
+ uint64_t ShiftAmtOneMask = RHS.One.zextOrTrunc(64).getZExtValue();
+ assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
+ Known.Zero.setAllBits();
+ Known.One.setAllBits();
+ for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
+ MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1);
+ ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
+ // Skip if the shift amount is impossible.
+ if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
+ (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
+ continue;
+ KnownBits SpecificShift = LHS;
+ SpecificShift.Zero.lshrInPlace(ShiftAmt);
+ SpecificShift.Zero.setHighBits(ShiftAmt);
+ SpecificShift.One.lshrInPlace(ShiftAmt);
+ Known = Known.intersectWith(SpecificShift);
+ if (Known.isUnknown())
+ break;
}
- Known.Zero.setHighBits(MinLeadingZeros);
return Known;
}
@@ -293,10 +291,6 @@ KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS) {
return Known;
}
- // No matter the shift amount, the leading sign bits will stay.
- unsigned MinLeadingZeros = LHS.countMinLeadingZeros();
- unsigned MinLeadingOnes = LHS.countMinLeadingOnes();
-
// Minimum shift amount high bits are known sign bits.
APInt MinShiftAmount = RHS.getMinValue();
if (MinShiftAmount.uge(BitWidth)) {
@@ -305,42 +299,45 @@ KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS) {
return Known;
}
- if (MinLeadingZeros) {
- MinLeadingZeros += MinShiftAmount.getZExtValue();
- MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
- }
- if (MinLeadingOnes) {
- MinLeadingOnes += MinShiftAmount.getZExtValue();
- MinLeadingOnes = std::min(MinLeadingOnes, BitWidth);
+ if (LHS.isUnknown()) {
+ // No matter the shift amount, the leading sign bits will stay.
+ unsigned MinLeadingZeros = LHS.countMinLeadingZeros();
+ unsigned MinLeadingOnes = LHS.countMinLeadingOnes();
+ if (MinLeadingZeros) {
+ MinLeadingZeros += MinShiftAmount.getZExtValue();
+ MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
+ }
+ if (MinLeadingOnes) {
+ MinLeadingOnes += MinShiftAmount.getZExtValue();
+ MinLeadingOnes = std::min(MinLeadingOnes, BitWidth);
+ }
+ Known.Zero.setHighBits(MinLeadingZeros);
+ Known.One.setHighBits(MinLeadingOnes);
+ return Known;
}
- // If the maximum shift is in range, then find the common bits from all
- // possible shifts.
+ // Find the common bits from all possible shifts.
APInt MaxShiftAmount = RHS.getMaxValue();
- if (!LHS.isUnknown()) {
- uint64_t ShiftAmtZeroMask = (~RHS.Zero).zextOrTrunc(64).getZExtValue();
- uint64_t ShiftAmtOneMask = RHS.One.zextOrTrunc(64).getZExtValue();
- assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
- Known.Zero.setAllBits();
- Known.One.setAllBits();
- for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
- MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1);
- ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
- // Skip if the shift amount is impossible.
- if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
- (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
- continue;
- KnownBits SpecificShift = LHS;
- SpecificShift.Zero.ashrInPlace(ShiftAmt);
- SpecificShift.One.ashrInPlace(ShiftAmt);
- Known = Known.intersectWith(SpecificShift);
- if (Known.isUnknown())
- break;
- }
+ uint64_t ShiftAmtZeroMask = (~RHS.Zero).zextOrTrunc(64).getZExtValue();
+ uint64_t ShiftAmtOneMask = RHS.One.zextOrTrunc(64).getZExtValue();
+ assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
+ Known.Zero.setAllBits();
+ Known.One.setAllBits();
+ for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
+ MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1);
+ ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
+ // Skip if the shift amount is impossible.
+ if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
+ (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
+ continue;
+ KnownBits SpecificShift = LHS;
+ SpecificShift.Zero.ashrInPlace(ShiftAmt);
+ SpecificShift.One.ashrInPlace(ShiftAmt);
+ Known = Known.intersectWith(SpecificShift);
+ if (Known.isUnknown())
+ break;
}
- Known.Zero.setHighBits(MinLeadingZeros);
- Known.One.setHighBits(MinLeadingOnes);
return Known;
}
More information about the llvm-commits
mailing list