[clang] [llvm] [ConstantRange] Estimate tighter lower (upper) bounds for masked binary and (or) (PR #120352)
Stephen Senran Zhang via llvm-commits
llvm-commits at lists.llvm.org
Thu Dec 19 00:39:45 PST 2024
================
@@ -1520,15 +1520,102 @@ ConstantRange ConstantRange::binaryNot() const {
   return ConstantRange(APInt::getAllOnes(getBitWidth())).sub(*this);
 }
+/// Estimate the 'bit-masked AND' operation's lower bound.
+///
+/// E.g., given two ranges as follows (single quotes are separators and
+/// have no meaning here),
+///
+///   LHS = [10'001'010,  ; LLo
+///          10'100'000]  ; LHi
+///   RHS = [10'111'010,  ; RLo
+///          10'111'100]  ; RHi
+///
+/// we know that the higher 2 bits of the result are always '10'. Note also
+/// that at least one bit in LHS[3:6] is 1 (since the range is contiguous),
+/// and all bits in RHS[3:6] are 1, so the lower bound of the result is
+/// 10'001'000.
+///
+/// The algorithm is as follows:
+/// 1. calculate a mask that masks out the higher common bits:
+///      Mask = (LLo ^ LHi) | (RLo ^ RHi) | (LLo ^ RLo);
+///      Mask = set all bits from the leading one of Mask downwards to 1;
+/// 2. after applying the mask, find the bit field in which LHS always has at
+///    least one 1 bit (i.e., bits 3:6 in the example):
+///      StartBit = BitWidth - (LLo & Mask).clz() - 1;
+///      EndBit = BitWidth - (LHi & Mask).clz();
+/// 3. check that all bits of RHS in [StartBit:EndBit] are 1 and that RLo and
+///    RHi agree on all bits in [StartBit:BitWidth]; if so, the lower bound
+///    can be updated to
+///      LowerBound = LLo & Keep;
+///    where Keep is a mask that clears the trailing bits (the lower 3 bits in
+///    the example);
+/// 4. repeat steps 2 and 3 with LHS and RHS swapped, and update the lower
+///    bound with the smaller one.
+static APInt estimateBitMaskedAndLowerBound(const ConstantRange &LHS,
+                                            const ConstantRange &RHS) {
+  auto BitWidth = LHS.getBitWidth();
+  // If either range is the full set or wraps (unsigned), it must contain 0,
+  // so the lower bound is 0.
+  if ((LHS.isFullSet() || RHS.isFullSet()) ||
+      (LHS.isWrappedSet() || RHS.isWrappedSet()))
+    return APInt::getZero(BitWidth);
+
+  auto LLo = LHS.getLower();
+  auto LHi = LHS.getUpper() - 1;
+  auto RLo = RHS.getLower();
+  auto RHi = RHS.getUpper() - 1;
+
+  // Calculate the mask that masks out the higher common bits.
+  auto Mask = (LLo ^ LHi) | (RLo ^ RHi) | (LLo ^ RLo);
+  unsigned LeadingZeros = Mask.countLeadingZeros();
+  Mask.setLowBits(BitWidth - LeadingZeros);
+
+  auto estimateBound =
+      [BitWidth, &Mask](const APInt &ALo, const APInt &AHi, const APInt &BLo,
+                        const APInt &BHi) -> std::optional<APInt> {
+    unsigned LeadingZeros = (ALo & Mask).countLeadingZeros();
+    if (LeadingZeros == BitWidth)
+      return std::nullopt;
+
+    unsigned StartBit = BitWidth - LeadingZeros - 1;
+
+    if (BLo.extractBits(BitWidth - StartBit, StartBit) !=
+        BHi.extractBits(BitWidth - StartBit, StartBit))
+      return std::nullopt;
+
+    unsigned EndBit = BitWidth - (AHi & Mask).countLeadingZeros();
+    if (!(BLo.extractBits(EndBit - StartBit, StartBit) &
+          BHi.extractBits(EndBit - StartBit, StartBit))
+             .isAllOnes())
+      return std::nullopt;
+
+    APInt Keep(BitWidth, 0);
+    Keep.setBits(StartBit, BitWidth);
+    return Keep & ALo;
+  };
+
+  auto LowerBoundByLHS = estimateBound(LLo, LHi, RLo, RHi);
+  auto LowerBoundByRHS = estimateBound(RLo, RHi, LLo, LHi);
+
+  if (LowerBoundByLHS && LowerBoundByRHS)
+    return LowerBoundByLHS->ult(*LowerBoundByRHS) ? *LowerBoundByLHS
----------------
zsrkmyn wrote:
Great catch!
https://github.com/llvm/llvm-project/pull/120352
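To try the lower-bound estimate from the doc comment outside LLVM, here is a minimal C++ sketch of its steps 1 to 3, under stated assumptions: plain `uint32_t` values of at most 31 bits stand in for `APInt`, and a hypothetical `Range` struct with inclusive, non-wrapped bounds stands in for `ConstantRange`. `estimateFrom` and `bruteMinAnd` are illustrative names, not LLVM APIs. The first applies the steps for one operand order; the second enumerates the exact minimum of `a & b` so the estimate can be checked. Combining the two operand orders (step 4) is left to the caller.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for a non-full, non-wrapped ConstantRange: the
// inclusive bounds [Lo, Hi] of Width-bit unsigned values, Width <= 31.
struct Range {
  uint32_t Lo, Hi;
};

// Number of significant bits, i.e. Width - clz(X); 0 when X == 0.
static unsigned activeBits(uint32_t X) {
  unsigned N = 0;
  for (; X != 0; X >>= 1)
    ++N;
  return N;
}

// Bits [Lo, Hi) of X, shifted down to bit 0.
static uint32_t bitsIn(uint32_t X, unsigned Lo, unsigned Hi) {
  return (X >> Lo) & ((1u << (Hi - Lo)) - 1);
}

// Steps 1 to 3 of the doc comment for one operand order: a lower bound on
// a & b for every a in A and b in B, or std::nullopt if step 3 fails.
static std::optional<uint32_t> estimateFrom(Range A, Range B, unsigned Width) {
  assert(Width <= 31 && A.Lo <= A.Hi && B.Lo <= B.Hi &&
         A.Hi < (1u << Width) && B.Hi < (1u << Width) && "non-wrapped only");

  // Step 1: Mask keeps the bits from the highest differing bit downwards;
  // all four endpoints agree on the bits above it.
  uint32_t Diff = (A.Lo ^ A.Hi) | (B.Lo ^ B.Hi) | (A.Lo ^ B.Lo);
  uint32_t Mask = (1u << activeBits(Diff)) - 1;

  // Step 2: every a in A has at least one set bit in [StartBit, EndBit).
  unsigned LoBits = activeBits(A.Lo & Mask);
  if (LoBits == 0)
    return std::nullopt;
  unsigned StartBit = LoBits - 1;
  unsigned EndBit = activeBits(A.Hi & Mask);

  // Step 3: B must be constant on [StartBit, Width) ...
  if (bitsIn(B.Lo, StartBit, Width) != bitsIn(B.Hi, StartBit, Width))
    return std::nullopt;
  // ... and all ones on [StartBit, EndBit) (B.Lo and B.Hi agree there now).
  if (bitsIn(B.Lo, StartBit, EndBit) != (1u << (EndBit - StartBit)) - 1)
    return std::nullopt;

  // LowerBound = A.Lo with the bits below StartBit cleared.
  return A.Lo & ~((1u << StartBit) - 1);
}

// Exact minimum of a & b over both ranges, by enumeration (for checking).
static uint32_t bruteMinAnd(Range A, Range B) {
  uint32_t Min = UINT32_MAX;
  for (uint32_t X = A.Lo; X <= A.Hi; ++X)
    for (uint32_t Y = B.Lo; Y <= B.Hi; ++Y)
      if ((X & Y) < Min)
        Min = X & Y;
  return Min;
}
```

On the example from the doc comment, `estimateFrom({0b10001010, 0b10100000}, {0b10111010, 0b10111100}, 8)` yields `0b10001000` (136), which equals the exact minimum from `bruteMinAnd`. With the operands swapped it yields `std::nullopt`, because `LLo` and `LHi` differ on bits [5, 8) (`100` vs. `101`), so the first check of step 3 fails.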