[llvm] [KnownBits] Make nuw and nsw support in computeForAddSub optimal (PR #83382)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 28 23:10:38 PST 2024


================
@@ -63,23 +63,173 @@ KnownBits KnownBits::computeForAddSub(bool Add, bool NSW,
         LHS, RHS, /*CarryZero*/true, /*CarryOne*/false);
   } else {
     // Sum = LHS + ~RHS + 1
-    std::swap(RHS.Zero, RHS.One);
-    KnownOut = ::computeForAddCarry(
-        LHS, RHS, /*CarryZero*/false, /*CarryOne*/true);
+    KnownBits NotRHS = RHS;
+    std::swap(NotRHS.Zero, NotRHS.One);
+    KnownOut = ::computeForAddCarry(LHS, NotRHS, /*CarryZero*/ false,
+                                    /*CarryOne*/ true);
+  }
+  if (!NSW && !NUW)
+    return KnownOut;
+
+  // We truncate out the signbit during nsw handling so just handle this special
+  // case to avoid dealing with it later.
+  if (LHS.getBitWidth() == 1) {
+    return LHS | RHS;
   }
 
-  // Are we still trying to solve for the sign bit?
-  if (!KnownOut.isNegative() && !KnownOut.isNonNegative()) {
+  auto GetMinMaxVal = [Add](bool ForNSW, bool ForMax, const KnownBits &L,
+                            const KnownBits &R, bool &OV) {
+    APInt LVal = ForMax ? L.getMaxValue() : L.getMinValue();
+    APInt RVal = Add == ForMax ? R.getMaxValue() : R.getMinValue();
+
+    if (ForNSW) {
+      LVal = LVal.trunc(LVal.getBitWidth() - 1);
+      RVal = RVal.trunc(RVal.getBitWidth() - 1);
+    }
+    APInt Res = Add ? LVal.uadd_ov(RVal, OV) : LVal.usub_ov(RVal, OV);
+    if (ForNSW)
+      Res = Res.sext(Res.getBitWidth() + 1);
+    return Res;
+  };
+
+  auto GetMaxVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
+                                   const KnownBits &R, bool &OV) {
+    return GetMinMaxVal(ForNSW, /*ForMax*/ true, L, R, OV);
+  };
+
+  auto GetMinVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
+                                   const KnownBits &R, bool &OV) {
+    return GetMinMaxVal(ForNSW, /*ForMax*/ false, L, R, OV);
+  };
+
+  std::optional<bool> Negative;
+  bool Poison = false;
+  // Handle add/sub given nsw and/or nuw.
+  //
+  // Possible TODO: Add/Sub implementations mirror one another in many ways.
+  // They could probably be compressed into a single implementation of roughly
+  // half the total LOC. Leaving separate for now to increase clarity.
+  // NB: We handle NSW by truncating sign bits then deducing bits based on
+  // the known sign result.
+  if (Add) {
     if (NSW) {
-      // Adding two non-negative numbers, or subtracting a negative number from
-      // a non-negative one, can't wrap into negative.
-      if (LHS.isNonNegative() && RHS.isNonNegative())
-        KnownOut.makeNonNegative();
-      // Adding two negative numbers, or subtracting a non-negative number from
-      // a negative one, can't wrap into non-negative.
-      else if (LHS.isNegative() && RHS.isNegative())
-        KnownOut.makeNegative();
+      bool OverflowMax, OverflowMin;
+      APInt MaxVal = GetMaxVal(/*ForNSW*/ true, LHS, RHS, OverflowMax);
+      APInt MinVal = GetMinVal(/*ForNSW*/ true, LHS, RHS, OverflowMin);
+
+      if (NUW || (LHS.isNonNegative() && RHS.isNonNegative())) {
+        // (add nuw) or (add nsw PosX, PosY)
+
+        // None of the adds can end up overflowing, so min consecutive highbits
+        // in minimum possible of X + Y must all remain set.
+        KnownOut.One.setHighBits(MinVal.countLeadingOnes());
+
+        // NSW and Positive arguments leads to positive result.
+        if (LHS.isNonNegative() && RHS.isNonNegative())
+          Negative = false;
+        else
+          KnownOut.One.clearSignBit();
+
+        Poison = OverflowMin;
+      } else if (LHS.isNegative() && RHS.isNegative()) {
+        // (add nsw NegX, NegY)
+
+        // We need to re-overflow the signbit, so we are looking for sequence of
+        // 0s from consecutive overflows.
+        KnownOut.Zero.setHighBits(MaxVal.countLeadingZeros());
+        Negative = true;
+        Poison = !OverflowMax;
+      } else if (LHS.isNonNegative() || RHS.isNonNegative()) {
+        // (add nsw PosX, ?Y)
+
+        // If the minimal possible of X + Y overflows the signbit, then Y must
+        // have been signed (which will cause unsigned overflow otherwise nsw
+        // will be violated) leading to unsigned result.
+        if (OverflowMin)
+          Negative = false;
+      } else if (LHS.isNegative() || RHS.isNegative()) {
+        // (add nsw NegX, ?Y)
+
+        // If the maximum possible of X + Y doesn't overflow the signbit, then
+        // Y must have been unsigned (otherwise nsw violated) so NegX + PosY w.o
+        // overflowing the signbit results in Negative.
+        if (!OverflowMax)
+          Negative = true;
+      }
     }
+    if (NUW) {
+        // (add nuw X, Y)
+      bool OverflowMax, OverflowMin;
+      APInt MaxVal = GetMaxVal(/*ForNSW*/ false, LHS, RHS, OverflowMax);
+      APInt MinVal = GetMinVal(/*ForNSW*/ false, LHS, RHS, OverflowMin);
+      // Same as (add nsw PosX, PosY), basically since we can't overflow, the
+      // high bits of minimum possible X + Y must remain set.
+      KnownOut.One.setHighBits(MinVal.countLeadingOnes());
+      Poison = OverflowMin;
+    }
+  } else {
+    if (NSW) {
+      bool OverflowMax, OverflowMin;
+      APInt MaxVal = GetMaxVal(/*ForNSW*/ true, LHS, RHS, OverflowMax);
+      APInt MinVal = GetMinVal(/*ForNSW*/ true, LHS, RHS, OverflowMin);
+      if (NUW || (LHS.isNegative() && RHS.isNonNegative())) {
+        // (sub nuw) or (sub nsw NegX, PosY)
+
+        // None of the subs can overflow at any point, so any common high bits
+        // will subtract away and result in zeros.
+        KnownOut.Zero.setHighBits(MaxVal.countLeadingZeros());
+        if (LHS.isNegative() && RHS.isNonNegative())
+          Negative = true;
+        else
+          KnownOut.Zero.clearSignBit();
+
+        Poison = OverflowMax;
+      } else if (LHS.isNonNegative() && RHS.isNegative()) {
+        // (sub nsw PosX, NegY)
+        Negative = false;
+
+        // Opposite case of above, we must "re-overflow" the signbit, so minimal
+        // set of high bits will be fixed.
+        KnownOut.One.setHighBits(MinVal.countLeadingOnes());
+        Poison = !OverflowMin;
+      } else if (LHS.isNegative() || RHS.isNonNegative()) {
+        // (sub nsw NegX/?X, ?Y/PosY)
+        if (OverflowMax)
+          Negative = true;
+      } else if (LHS.isNonNegative() || RHS.isNegative()) {
+        // (sub nsw PosX/?X, ?Y/NegY)
+        if (!OverflowMin)
+          Negative = false;
+      }
+    }
+    if (NUW) {
+      // (sub nuw X, Y)
+      bool OverflowMax, OverflowMin;
+      APInt MaxVal = GetMaxVal(/*ForNSW*/ false, LHS, RHS, OverflowMax);
+      APInt MinVal = GetMinVal(/*ForNSW*/ false, LHS, RHS, OverflowMin);
----------------
arsenm wrote:

```suggestion
      APInt MaxVal = GetMaxVal(/*ForNSW=*/ false, LHS, RHS, OverflowMax);
      APInt MinVal = GetMinVal(/*ForNSW=*/ false, LHS, RHS, OverflowMin);
```

https://github.com/llvm/llvm-project/pull/83382


More information about the llvm-commits mailing list