[llvm] [KnownBits] Make nuw and nsw support in computeForAddSub optimal (PR #83382)
Jay Foad via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 4 05:48:04 PST 2024
================
@@ -54,32 +54,183 @@ KnownBits KnownBits::computeForAddCarry(
LHS, RHS, Carry.Zero.getBoolValue(), Carry.One.getBoolValue());
}
-KnownBits KnownBits::computeForAddSub(bool Add, bool NSW,
- const KnownBits &LHS, KnownBits RHS) {
+KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, bool NUW,
+ const KnownBits &LHS,
+ const KnownBits &RHS) {
+ // This can be a relatively expensive helper, so optimistically save some
+ // work.
+ if (LHS.isUnknown() && RHS.isUnknown())
+ return LHS;
KnownBits KnownOut;
if (Add) {
// Sum = LHS + RHS + 0
- KnownOut = ::computeForAddCarry(
- LHS, RHS, /*CarryZero*/true, /*CarryOne*/false);
+ KnownOut =
+ ::computeForAddCarry(LHS, RHS, /*CarryZero*/ true, /*CarryOne*/ false);
} else {
// Sum = LHS + ~RHS + 1
- std::swap(RHS.Zero, RHS.One);
- KnownOut = ::computeForAddCarry(
- LHS, RHS, /*CarryZero*/false, /*CarryOne*/true);
+ KnownBits NotRHS = RHS;
+ std::swap(NotRHS.Zero, NotRHS.One);
+ KnownOut = ::computeForAddCarry(LHS, NotRHS, /*CarryZero*/ false,
+ /*CarryOne*/ true);
}
+ if (!NSW && !NUW)
+ return KnownOut;
- // Are we still trying to solve for the sign bit?
- if (!KnownOut.isNegative() && !KnownOut.isNonNegative()) {
- if (NSW) {
- // Adding two non-negative numbers, or subtracting a negative number from
- // a non-negative one, can't wrap into negative.
- if (LHS.isNonNegative() && RHS.isNonNegative())
+ auto GetMinMaxVal = [Add](bool ForNSW, bool ForMax, const KnownBits &L,
+ const KnownBits &R, bool &OV) {
+ APInt LVal = ForMax ? L.getMaxValue() : L.getMinValue();
+ APInt RVal = Add == ForMax ? R.getMaxValue() : R.getMinValue();
+
+ if (ForNSW) {
+ LVal.clearSignBit();
+ RVal.clearSignBit();
+ }
+ APInt Res = Add ? LVal.uadd_ov(RVal, OV) : LVal.usub_ov(RVal, OV);
+ if (ForNSW) {
+ OV = Res.isSignBitSet();
+ Res.clearSignBit();
+ if (Res.getBitWidth() > 1 && Res[Res.getBitWidth() - 2])
+ Res.setSignBit();
+ }
+ return Res;
+ };
+
+ auto GetMaxVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
+ const KnownBits &R, bool &OV) {
+ return GetMinMaxVal(ForNSW, /*ForMax=*/true, L, R, OV);
+ };
+
+ auto GetMinVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
+ const KnownBits &R, bool &OV) {
+ return GetMinMaxVal(ForNSW, /*ForMax=*/false, L, R, OV);
+ };
+
+ auto ForceNegative = [](KnownBits &Known) {
+ Known.Zero.clearSignBit();
+ Known.One.setSignBit();
+ };
+
+ auto ForcePositive = [](KnownBits &Known) {
+ Known.One.clearSignBit();
+ Known.Zero.setSignBit();
+ };
+
+ // Handle add/sub given nsw and/or nuw.
+ //
+ // Possible TODO: Add/Sub implementations mirror one another in many ways.
+ // They could probably be compressed into a single implementation of roughly
+ // half the total LOC. Leaving separate for now to increase clarity.
+ // NB: We handle NSW by essentially treating as nuw of bitwidth - 1 then
+ // deducing bits based on the known sign result.
+ if (Add) {
+ if (NUW || (LHS.isNonNegative() && RHS.isNonNegative())) {
+ bool OverflowMin;
+ APInt MinVal;
+ if (NSW) {
+ MinVal = GetMinVal(/*ForNSW=*/true, LHS, RHS, OverflowMin);
+ // (add nsw nuw) or (add nsw PosX, PosY)
+
+ // None of the adds can end up overflowing, so min consecutive
+ // highbits in minimum possible of X + Y must all remain set.
+ KnownOut.One.setHighBits(MinVal.countLeadingOnes());
+
+ // NSW and Positive arguments leads to positive result.
+ if (LHS.isNonNegative() && RHS.isNonNegative())
+ ForcePositive(KnownOut);
+ }
+ if (NUW) {
+ KnownOut.One.clearSignBit();
+ // (add nuw X, Y)
+ MinVal = GetMinVal(/*ForNSW=*/false, LHS, RHS, OverflowMin);
+ // Same as (add nsw PosX, PosY), basically since we can't overflow,
+ // the high bits of minimum possible X + Y must remain set.
+ KnownOut.One.setHighBits(MinVal.countLeadingOnes());
+ }
+ } else if (LHS.isNegative() && RHS.isNegative()) {
+ bool OverflowMax;
+ APInt MaxVal = GetMaxVal(/*ForNSW=*/true, LHS, RHS, OverflowMax);
+ // (add nsw NegX, NegY)
+
+ // We need to re-overflow the signbit, so we are looking for sequence
+ // of 0s from consecutive overflows.
+ KnownOut.Zero.setHighBits(MaxVal.countLeadingZeros());
+ ForceNegative(KnownOut);
+ } else if (!KnownOut.isSignUnknown()) {
+ // Pass, avoid extra work if we already know the sign bit.
+ } else if (LHS.isNonNegative() || RHS.isNonNegative()) {
+ bool OverflowMin;
+ (void)GetMinVal(/*ForNSW=*/true, LHS, RHS, OverflowMin);
+ // (add nsw PosX, ?Y)
+
+ // If the minimal possible of X + Y overflows the signbit, then Y must
+ // have been signed (which will cause unsigned overflow otherwise nsw
+ // will be violated) leading to unsigned result.
+ if (OverflowMin)
KnownOut.makeNonNegative();
- // Adding two negative numbers, or subtracting a non-negative number from
- // a negative one, can't wrap into non-negative.
- else if (LHS.isNegative() && RHS.isNegative())
+ } else if (LHS.isNegative() || RHS.isNegative()) {
+ bool OverflowMax;
+ (void)GetMaxVal(/*ForNSW=*/true, LHS, RHS, OverflowMax);
+ // (add nsw NegX, ?Y)
+
+ // If the maximum possible of X + Y doesn't overflow the signbit,
+ // then Y must have been unsigned (otherwise nsw violated) so NegX +
+ // PosY w.o overflowing the signbit results in Negative.
+ if (!OverflowMax)
KnownOut.makeNegative();
}
+ } else {
+ if (NUW || (LHS.isNegative() && RHS.isNonNegative())) {
+ bool OverflowMax;
+ APInt MaxVal;
+ if (NSW) {
+ MaxVal = GetMaxVal(/*ForNSW=*/true, LHS, RHS, OverflowMax);
+ // (sub nsw nuw) or (sub nsw NegX, PosY)
+
+ // None of the subs can overflow at any point, so any common high bits
+ // will subtract away and result in zeros.
+ KnownOut.Zero.setHighBits(MaxVal.countLeadingZeros());
+ if (LHS.isNegative() && RHS.isNonNegative())
+ ForceNegative(KnownOut);
+ }
+ if (NUW) {
+ KnownOut.Zero.clearSignBit();
+ // (sub nuw X, Y)
+ MaxVal = GetMaxVal(/*ForNSW=*/false, LHS, RHS, OverflowMax);
+
+ // Basically all common high bits between X/Y will cancel out as
+ // leading zeros.
+ KnownOut.Zero.setHighBits(MaxVal.countLeadingZeros());
+ }
+ } else if (LHS.isNonNegative() && RHS.isNegative()) {
+ bool OverflowMin;
+ APInt MinVal = GetMinVal(/*ForNSW=*/true, LHS, RHS, OverflowMin);
+ // (sub nsw PosX, NegY)
+
+ // Opposite case of above, we must "re-overflow" the signbit, so
+ // minimal set of high bits will be fixed.
+ KnownOut.One.setHighBits(MinVal.countLeadingOnes());
+ ForcePositive(KnownOut);
+ } else if (!KnownOut.isSignUnknown()) {
+ // Pass, avoid extra work if we already know the sign bit.
+ } else if (LHS.isNegative() || RHS.isNonNegative()) {
+ bool OverflowMax;
+ (void)GetMaxVal(/*ForNSW=*/true, LHS, RHS, OverflowMax);
+ // (sub nsw NegX/?X, ?Y/PosY)
+ if (OverflowMax)
+ KnownOut.makeNegative();
+ } else if (LHS.isNonNegative() || RHS.isNegative()) {
+ bool OverflowMin;
+ (void)GetMinVal(/*ForNSW=*/true, LHS, RHS, OverflowMin);
+ // (sub nsw PosX/?X, ?Y/NegY)
+ if (!OverflowMin)
+ KnownOut.makeNonNegative();
+ }
+ }
+
+ // Just return 0 if the nsw/nuw is violated and we have poison.
+ if (KnownOut.hasConflict()) {
+ KnownOut.setAllZero();
+ return KnownOut;
----------------
jayfoad wrote:
Nit: don't need this return.
https://github.com/llvm/llvm-project/pull/83382
More information about the llvm-commits
mailing list