[llvm] [KnownBits] Make nuw and nsw support in computeForAddSub optimal (PR #83382)

Jay Foad via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 5 00:24:11 PST 2024


================
@@ -54,34 +54,89 @@ KnownBits KnownBits::computeForAddCarry(
       LHS, RHS, Carry.Zero.getBoolValue(), Carry.One.getBoolValue());
 }
 
-KnownBits KnownBits::computeForAddSub(bool Add, bool NSW,
-                                      const KnownBits &LHS, KnownBits RHS) {
-  KnownBits KnownOut;
-  if (Add) {
-    // Sum = LHS + RHS + 0
-    KnownOut = ::computeForAddCarry(
-        LHS, RHS, /*CarryZero*/true, /*CarryOne*/false);
-  } else {
-    // Sum = LHS + ~RHS + 1
-    std::swap(RHS.Zero, RHS.One);
-    KnownOut = ::computeForAddCarry(
-        LHS, RHS, /*CarryZero*/false, /*CarryOne*/true);
+KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, bool NUW,
+                                      const KnownBits &LHS,
+                                      const KnownBits &RHS) {
+  unsigned BitWidth = LHS.getBitWidth();
+  KnownBits KnownOut(BitWidth);
+  // This can be a relatively expensive helper, so optimistically save some
+  // work.
+  if (LHS.isUnknown() && RHS.isUnknown())
+    return KnownOut;
+
+  if (!LHS.isUnknown() && !RHS.isUnknown()) {
+    if (Add) {
+      // Sum = LHS + RHS + 0
+      KnownOut = ::computeForAddCarry(LHS, RHS, /*CarryZero=*/true,
+                                      /*CarryOne=*/false);
+    } else {
+      // Sum = LHS + ~RHS + 1
+      KnownBits NotRHS = RHS;
+      std::swap(NotRHS.Zero, NotRHS.One);
+      KnownOut = ::computeForAddCarry(LHS, NotRHS, /*CarryZero=*/false,
+                                      /*CarryOne=*/true);
+    }
   }
 
-  // Are we still trying to solve for the sign bit?
-  if (!KnownOut.isNegative() && !KnownOut.isNonNegative()) {
-    if (NSW) {
-      // Adding two non-negative numbers, or subtracting a negative number from
-      // a non-negative one, can't wrap into negative.
-      if (LHS.isNonNegative() && RHS.isNonNegative())
-        KnownOut.makeNonNegative();
-      // Adding two negative numbers, or subtracting a non-negative number from
-      // a negative one, can't wrap into non-negative.
-      else if (LHS.isNegative() && RHS.isNegative())
-        KnownOut.makeNegative();
+  // Handle add/sub given nsw and/or nuw.
+  if (NUW) {
+    if (Add) {
+      // (add nuw X, Y)
+      APInt MinVal = LHS.getMinValue().uadd_sat(RHS.getMinValue());
+      // None of the adds can end up overflowing, so min consecutive highbits
+      // in minimum possible of X + Y must all remain set.
+      if (NSW) {
+        unsigned NumBits = MinVal.trunc(BitWidth - 1).countl_one();
+        // If we have NSW as well, we also know we can't overflow the signbit,
+        // so we can start counting from 1 bit back.
+        KnownOut.One.setBits(BitWidth - 1 - NumBits, BitWidth - 1);
----------------
jayfoad wrote:

This was the part of my refactoring that I was least happy about: for the nsw+nuw case we run this line AND line 94 AND all the nsw logic below, and they all seem to be necessary. I'm sure there's a simpler way of handling the nsw+nuw case but I couldn't find it.

https://github.com/llvm/llvm-project/pull/83382


More information about the llvm-commits mailing list