[llvm] [KnownBits] Make nuw and nsw support in computeForAddSub optimal (PR #83382)

Jay Foad via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 4 05:48:04 PST 2024


================
@@ -54,32 +54,183 @@ KnownBits KnownBits::computeForAddCarry(
       LHS, RHS, Carry.Zero.getBoolValue(), Carry.One.getBoolValue());
 }
 
-KnownBits KnownBits::computeForAddSub(bool Add, bool NSW,
-                                      const KnownBits &LHS, KnownBits RHS) {
+KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, bool NUW,
+                                      const KnownBits &LHS,
+                                      const KnownBits &RHS) {
+  // This can be a relatively expensive helper, so optimistically save some
+  // work.
+  if (LHS.isUnknown() && RHS.isUnknown())
+    return LHS;
   KnownBits KnownOut;
   if (Add) {
     // Sum = LHS + RHS + 0
-    KnownOut = ::computeForAddCarry(
-        LHS, RHS, /*CarryZero*/true, /*CarryOne*/false);
+    KnownOut =
+        ::computeForAddCarry(LHS, RHS, /*CarryZero*/ true, /*CarryOne*/ false);
   } else {
     // Sum = LHS + ~RHS + 1
-    std::swap(RHS.Zero, RHS.One);
-    KnownOut = ::computeForAddCarry(
-        LHS, RHS, /*CarryZero*/false, /*CarryOne*/true);
+    KnownBits NotRHS = RHS;
+    std::swap(NotRHS.Zero, NotRHS.One);
+    KnownOut = ::computeForAddCarry(LHS, NotRHS, /*CarryZero*/ false,
+                                    /*CarryOne*/ true);
   }
+  if (!NSW && !NUW)
+    return KnownOut;
 
-  // Are we still trying to solve for the sign bit?
-  if (!KnownOut.isNegative() && !KnownOut.isNonNegative()) {
-    if (NSW) {
-      // Adding two non-negative numbers, or subtracting a negative number from
-      // a non-negative one, can't wrap into negative.
-      if (LHS.isNonNegative() && RHS.isNonNegative())
+  auto GetMinMaxVal = [Add](bool ForNSW, bool ForMax, const KnownBits &L,
+                            const KnownBits &R, bool &OV) {
+    APInt LVal = ForMax ? L.getMaxValue() : L.getMinValue();
+    APInt RVal = Add == ForMax ? R.getMaxValue() : R.getMinValue();
+
+    if (ForNSW) {
+      LVal.clearSignBit();
+      RVal.clearSignBit();
+    }
+    APInt Res = Add ? LVal.uadd_ov(RVal, OV) : LVal.usub_ov(RVal, OV);
+    if (ForNSW) {
+      OV = Res.isSignBitSet();
+      Res.clearSignBit();
+      if (Res.getBitWidth() > 1 && Res[Res.getBitWidth() - 2])
+        Res.setSignBit();
+    }
+    return Res;
+  };
+
+  auto GetMaxVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
+                                   const KnownBits &R, bool &OV) {
+    return GetMinMaxVal(ForNSW, /*ForMax=*/true, L, R, OV);
+  };
+
+  auto GetMinVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
+                                   const KnownBits &R, bool &OV) {
+    return GetMinMaxVal(ForNSW, /*ForMax=*/false, L, R, OV);
+  };
+
+  auto ForceNegative = [](KnownBits &Known) {
+    Known.Zero.clearSignBit();
+    Known.One.setSignBit();
+  };
+
+  auto ForcePositive = [](KnownBits &Known) {
+    Known.One.clearSignBit();
+    Known.Zero.setSignBit();
+  };
+
+  // Handle add/sub given nsw and/or nuw.
+  //
+  // Possible TODO: Add/Sub implementations mirror one another in many ways.
+  // They could probably be compressed into a single implementation of roughly
+  // half the total LOC. Leaving seperate for now to increase clarity.
+  // NB: We handle NSW by essentially treating as nuw of bitwidth - 1 then
+  // deducing bits based on the known sign result.
+  if (Add) {
+    if (NUW || (LHS.isNonNegative() && RHS.isNonNegative())) {
+      bool OverflowMin;
+      APInt MinVal;
+      if (NSW) {
+        MinVal = GetMinVal(/*ForNSW=*/true, LHS, RHS, OverflowMin);
+        // (add nsw nuw) or (add nsw PosX, PosY)
+
+        // None of the adds can end up overflowing, so min consecutive
+        // highbits in minimum possible of X + Y must all remain set.
+        KnownOut.One.setHighBits(MinVal.countLeadingOnes());
+
+        // NSW and Positive arguments leads to positive result.
+        if (LHS.isNonNegative() && RHS.isNonNegative())
+          ForcePositive(KnownOut);
+      }
+      if (NUW) {
+        KnownOut.One.clearSignBit();
+        // (add nuw X, Y)
+        MinVal = GetMinVal(/*ForNSW=*/false, LHS, RHS, OverflowMin);
+        // Same as (add nsw PosX, PosY), basically since we can't overflow,
+        // the high bits of minimum possible X + Y must remain set.
+        KnownOut.One.setHighBits(MinVal.countLeadingOnes());
+      }
+    } else if (LHS.isNegative() && RHS.isNegative()) {
+      bool OverflowMax;
+      APInt MaxVal = GetMaxVal(/*ForNSW=*/true, LHS, RHS, OverflowMax);
+      // (add nsw NegX, NegY)
+
+      // We need to re-overflow the signbit, so we are looking for sequence
+      // of 0s from consecutive overflows.
+      KnownOut.Zero.setHighBits(MaxVal.countLeadingZeros());
+      ForceNegative(KnownOut);
+    } else if (!KnownOut.isSignUnknown()) {
+      // Pass, avoid extra work if we already know the sign bit.
+    } else if (LHS.isNonNegative() || RHS.isNonNegative()) {
+      bool OverflowMin;
+      (void)GetMinVal(/*ForNSW=*/true, LHS, RHS, OverflowMin);
+      // (add nsw PosX, ?Y)
+
+      // If the minimal possible of X + Y overflows the signbit, then Y must
+      // have been signed (which will cause unsigned overflow otherwise nsw
+      // will be violated) leading to unsigned result.
+      if (OverflowMin)
         KnownOut.makeNonNegative();
-      // Adding two negative numbers, or subtracting a non-negative number from
-      // a negative one, can't wrap into non-negative.
-      else if (LHS.isNegative() && RHS.isNegative())
+    } else if (LHS.isNegative() || RHS.isNegative()) {
+      bool OverflowMax;
+      (void)GetMaxVal(/*ForNSW=*/true, LHS, RHS, OverflowMax);
+      // (add nsw NegX, ?Y)
+
+      // If the maximum possible of X + Y doesn't overflows the signbit,
+      // then Y must have been unsigned (otherwise nsw violated) so NegX +
+      // PosY w.o overflowing the signbit results in Negative.
+      if (!OverflowMax)
         KnownOut.makeNegative();
     }
+  } else {
+    if (NUW || (LHS.isNegative() && RHS.isNonNegative())) {
+      bool OverflowMax;
+      APInt MaxVal;
+      if (NSW) {
+        MaxVal = GetMaxVal(/*ForNSW=*/true, LHS, RHS, OverflowMax);
+        // (sub nsw nuw) or (sub nsw NegX, PosY)
+
+        // None of the subs can overflow at any point, so any common high bits
+        // will subtract away and result in zeros.
+        KnownOut.Zero.setHighBits(MaxVal.countLeadingZeros());
+        if (LHS.isNegative() && RHS.isNonNegative())
+          ForceNegative(KnownOut);
+      }
+      if (NUW) {
+        KnownOut.Zero.clearSignBit();
+        // (sub nuw X, Y)
+        MaxVal = GetMaxVal(/*ForNSW=*/false, LHS, RHS, OverflowMax);
+
+        // Basically all common high bits between X/Y will cancel out as
+        // leading zeros.
+        KnownOut.Zero.setHighBits(MaxVal.countLeadingZeros());
+      }
+    } else if (LHS.isNonNegative() && RHS.isNegative()) {
+      bool OverflowMin;
+      APInt MinVal = GetMinVal(/*ForNSW=*/true, LHS, RHS, OverflowMin);
+      // (sub nsw PosX, NegY)
+
+      // Opposite case of above, we must "re-overflow" the signbit, so
+      // minimal set of high bits will be fixed.
+      KnownOut.One.setHighBits(MinVal.countLeadingOnes());
+      ForcePositive(KnownOut);
+    } else if (!KnownOut.isSignUnknown()) {
+      // Pass, avoid extra work if we already know the sign bit.
+    } else if (LHS.isNegative() || RHS.isNonNegative()) {
+      bool OverflowMax;
+      (void)GetMaxVal(/*ForNSW=*/true, LHS, RHS, OverflowMax);
+      // (sub nsw NegX/?X, ?Y/PosY)
+      if (OverflowMax)
+        KnownOut.makeNegative();
+    } else if (LHS.isNonNegative() || RHS.isNegative()) {
+      bool OverflowMin;
+      (void)GetMinVal(/*ForNSW=*/true, LHS, RHS, OverflowMin);
+      // (sub nsw PosX/?X, ?Y/NegY)
+      if (!OverflowMin)
+        KnownOut.makeNonNegative();
+    }
+  }
+
+  // Just return 0 if the nsw/nuw is violated and we have poison.
+  if (KnownOut.hasConflict()) {
+    KnownOut.setAllZero();
+    return KnownOut;
----------------
jayfoad wrote:

Nit: don't need this return.

https://github.com/llvm/llvm-project/pull/83382


More information about the llvm-commits mailing list