[llvm] [KnownBits] Make nuw and nsw support in computeForAddSub optimal (PR #83382)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Wed Feb 28 23:10:38 PST 2024
================
@@ -63,23 +63,173 @@ KnownBits KnownBits::computeForAddSub(bool Add, bool NSW,
LHS, RHS, /*CarryZero*/true, /*CarryOne*/false);
} else {
// Sum = LHS + ~RHS + 1
- std::swap(RHS.Zero, RHS.One);
- KnownOut = ::computeForAddCarry(
- LHS, RHS, /*CarryZero*/false, /*CarryOne*/true);
+ KnownBits NotRHS = RHS;
+ std::swap(NotRHS.Zero, NotRHS.One);
+ KnownOut = ::computeForAddCarry(LHS, NotRHS, /*CarryZero*/ false,
+ /*CarryOne*/ true);
+ }
+ if (!NSW && !NUW)
+ return KnownOut;
+
+ // We truncate out the signbit during nsw handling so just handle this special
+ // case to avoid dealing with it later.
+ if (LHS.getBitWidth() == 1) {
+ return LHS | RHS;
}
- // Are we still trying to solve for the sign bit?
- if (!KnownOut.isNegative() && !KnownOut.isNonNegative()) {
+ auto GetMinMaxVal = [Add](bool ForNSW, bool ForMax, const KnownBits &L,
+ const KnownBits &R, bool &OV) {
+ APInt LVal = ForMax ? L.getMaxValue() : L.getMinValue();
+ APInt RVal = Add == ForMax ? R.getMaxValue() : R.getMinValue();
+
+ if (ForNSW) {
+ LVal = LVal.trunc(LVal.getBitWidth() - 1);
+ RVal = RVal.trunc(RVal.getBitWidth() - 1);
+ }
+ APInt Res = Add ? LVal.uadd_ov(RVal, OV) : LVal.usub_ov(RVal, OV);
+ if (ForNSW)
+ Res = Res.sext(Res.getBitWidth() + 1);
+ return Res;
+ };
+
+ auto GetMaxVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
+ const KnownBits &R, bool &OV) {
+ return GetMinMaxVal(ForNSW, /*ForMax*/ true, L, R, OV);
+ };
+
+ auto GetMinVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
+ const KnownBits &R, bool &OV) {
+ return GetMinMaxVal(ForNSW, /*ForMax*/ false, L, R, OV);
+ };
+
+ std::optional<bool> Negative;
+ bool Poison = false;
+ // Handle add/sub given nsw and/or nuw.
+ //
+ // Possible TODO: Add/Sub implementations mirror one another in many ways.
+ // They could probably be compressed into a single implementation of roughly
+ // half the total LOC. Leaving separate for now to increase clarity.
+ // NB: We handle NSW by truncating sign bits then deducing bits based on
+ // the known sign result.
+ if (Add) {
if (NSW) {
- // Adding two non-negative numbers, or subtracting a negative number from
- // a non-negative one, can't wrap into negative.
- if (LHS.isNonNegative() && RHS.isNonNegative())
- KnownOut.makeNonNegative();
- // Adding two negative numbers, or subtracting a non-negative number from
- // a negative one, can't wrap into non-negative.
- else if (LHS.isNegative() && RHS.isNegative())
- KnownOut.makeNegative();
+ bool OverflowMax, OverflowMin;
+ APInt MaxVal = GetMaxVal(/*ForNSW*/ true, LHS, RHS, OverflowMax);
+ APInt MinVal = GetMinVal(/*ForNSW*/ true, LHS, RHS, OverflowMin);
+
+ if (NUW || (LHS.isNonNegative() && RHS.isNonNegative())) {
+ // (add nuw) or (add nsw PosX, PosY)
+
+ // None of the adds can end up overflowing, so min consecutive highbits
+ // in minimum possible of X + Y must all remain set.
+ KnownOut.One.setHighBits(MinVal.countLeadingOnes());
+
+ // NSW and Positive arguments leads to positive result.
+ if (LHS.isNonNegative() && RHS.isNonNegative())
+ Negative = false;
+ else
+ KnownOut.One.clearSignBit();
+
+ Poison = OverflowMin;
+ } else if (LHS.isNegative() && RHS.isNegative()) {
+ // (add nsw NegX, NegY)
+
+ // We need to re-overflow the signbit, so we are looking for sequence of
+ // 0s from consecutive overflows.
+ KnownOut.Zero.setHighBits(MaxVal.countLeadingZeros());
+ Negative = true;
+ Poison = !OverflowMax;
+ } else if (LHS.isNonNegative() || RHS.isNonNegative()) {
+ // (add nsw PosX, ?Y)
+
+ // If the minimal possible of X + Y overflows the signbit, then Y must
+ // have been signed (which will cause unsigned overflow otherwise nsw
+ // will be violated) leading to unsigned result.
+ if (OverflowMin)
+ Negative = false;
+ } else if (LHS.isNegative() || RHS.isNegative()) {
+ // (add nsw NegX, ?Y)
+
+ // If the maximum possible of X + Y doesn't overflow the signbit, then
+ // Y must have been unsigned (otherwise nsw violated) so NegX + PosY w.o
+ // overflowing the signbit results in Negative.
+ if (!OverflowMax)
+ Negative = true;
+ }
}
+ if (NUW) {
+ // (add nuw X, Y)
+ bool OverflowMax, OverflowMin;
+ APInt MaxVal = GetMaxVal(/*ForNSW*/ false, LHS, RHS, OverflowMax);
+ APInt MinVal = GetMinVal(/*ForNSW*/ false, LHS, RHS, OverflowMin);
----------------
arsenm wrote:
```suggestion
APInt MaxVal = GetMaxVal(/*ForNSW=*/ false, LHS, RHS, OverflowMax);
APInt MinVal = GetMinVal(/*ForNSW=*/ false, LHS, RHS, OverflowMin);
```
https://github.com/llvm/llvm-project/pull/83382
More information about the llvm-commits
mailing list