[llvm] goldsteinn/avg add sub opt knownbits (PR #110329)

via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 27 14:19:25 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-support

Author: None (goldsteinn)

<details>
<summary>Changes</summary>

- **[KnownBits] Mark `avg{Ceil,Floor}U` as optimal in exhaustive test; NFC**
- **[KnownBits] Make `avg{Ceil,Floor}S` optimal**
- **[KnownBits] Make `{s,u}{add,sub}_sat` optimal**


---
Full diff: https://github.com/llvm/llvm-project/pull/110329.diff


3 Files Affected:

- (modified) llvm/lib/Support/KnownBits.cpp (+113-72) 
- (modified) llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll (+1-8) 
- (modified) llvm/unittests/Support/KnownBitsTest.cpp (+12-18) 


``````````diff
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 6863c5c0af5dca..89e4b108b83e55 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -610,28 +610,78 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
                                      const KnownBits &RHS) {
   // We don't see NSW even for sadd/ssub as we want to check if the result has
   // signed overflow.
-  KnownBits Res =
-      KnownBits::computeForAddSub(Add, /*NSW=*/false, /*NUW=*/false, LHS, RHS);
-  unsigned BitWidth = Res.getBitWidth();
-  auto SignBitKnown = [&](const KnownBits &K) {
-    return K.Zero[BitWidth - 1] || K.One[BitWidth - 1];
-  };
-  std::optional<bool> Overflow;
+  unsigned BitWidth = LHS.getBitWidth();
 
+  std::optional<bool> Overflow;
+  // Even if we can't entirely rule out overflow, we may be able to rule out
+  // overflow in one direction. This allows us to potentially keep some of the
+  // add/sub bits. I.e., if we can't overflow in the positive direction we won't
+  // clamp to INT_MAX, so we can keep low 0s from the add/sub result.
+  bool MayNegClamp = true;
+  bool MayPosClamp = true;
   if (Signed) {
-    // If we can actually detect overflow do so. Otherwise leave Overflow as
-    // nullopt (we assume it may have happened).
-    if (SignBitKnown(LHS) && SignBitKnown(RHS) && SignBitKnown(Res)) {
+    // Easy cases we can rule out any overflow.
+    if (Add && ((LHS.isNegative() && RHS.isNonNegative()) ||
+                (LHS.isNonNegative() && RHS.isNegative())))
+      Overflow = false;
+    else if (!Add && (((LHS.isNegative() && RHS.isNegative()) ||
+                       (LHS.isNonNegative() && RHS.isNonNegative()))))
+      Overflow = false;
+    else {
+      // Check if we may overflow. If we can't rule out overflow then check if
+      // we can rule out a direction at least.
+      KnownBits UnsignedLHS = LHS;
+      KnownBits UnsignedRHS = RHS;
+      UnsignedLHS.One.clearSignBit();
+      UnsignedLHS.Zero.setSignBit();
+      UnsignedRHS.One.clearSignBit();
+      UnsignedRHS.Zero.setSignBit();
+      KnownBits Res =
+          KnownBits::computeForAddSub(Add, /*NSW=*/false,
+                                      /*NUW=*/false, UnsignedLHS, UnsignedRHS);
       if (Add) {
-        // sadd.sat
-        Overflow = (LHS.isNonNegative() == RHS.isNonNegative() &&
-                    Res.isNonNegative() != LHS.isNonNegative());
+        if (Res.isNegative()) {
+          // Only overflow scenario is Pos + Pos.
+          MayNegClamp = false;
+          // Pos + Pos will overflow with extra signbit.
+          if (LHS.isNonNegative() && RHS.isNonNegative())
+            Overflow = true;
+        } else if (Res.isNonNegative()) {
+          // Only overflow scenario is Neg + Neg
+          MayPosClamp = false;
+          // Neg + Neg will overflow without extra signbit.
+          if (LHS.isNegative() && RHS.isNegative())
+            Overflow = true;
+        }
+        // We will never clamp to the opposite sign of N-bit result.
+        if (LHS.isNegative() || RHS.isNegative())
+          MayPosClamp = false;
+        if (LHS.isNonNegative() || RHS.isNonNegative())
+          MayNegClamp = false;
       } else {
-        // ssub.sat
-        Overflow = (LHS.isNonNegative() != RHS.isNonNegative() &&
-                    Res.isNonNegative() != LHS.isNonNegative());
+        if (Res.isNegative()) {
+          // Only overflow scenario is Neg - Pos.
+          MayPosClamp = false;
+          // Neg - Pos will overflow with extra signbit.
+          if (LHS.isNegative() && RHS.isNonNegative())
+            Overflow = true;
+        } else if (Res.isNonNegative()) {
+          // Only overflow scenario is Pos - Neg.
+          MayNegClamp = false;
+          // Pos - Neg will overflow without extra signbit.
+          if (LHS.isNonNegative() && RHS.isNegative())
+            Overflow = true;
+        }
+        // We will never clamp to the opposite sign of N-bit result.
+        if (LHS.isNegative() || RHS.isNonNegative())
+          MayPosClamp = false;
+        if (LHS.isNonNegative() || RHS.isNegative())
+          MayNegClamp = false;
       }
     }
+    // If we have ruled out all clamping, we will never overflow.
+    if (!MayNegClamp && !MayPosClamp)
+      Overflow = false;
   } else if (Add) {
     // uadd.sat
     bool Of;
@@ -656,52 +706,8 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
     }
   }
 
-  if (Signed) {
-    if (Add) {
-      if (LHS.isNonNegative() && RHS.isNonNegative()) {
-        // Pos + Pos -> Pos
-        Res.One.clearSignBit();
-        Res.Zero.setSignBit();
-      }
-      if (LHS.isNegative() && RHS.isNegative()) {
-        // Neg + Neg -> Neg
-        Res.One.setSignBit();
-        Res.Zero.clearSignBit();
-      }
-    } else {
-      if (LHS.isNegative() && RHS.isNonNegative()) {
-        // Neg - Pos -> Neg
-        Res.One.setSignBit();
-        Res.Zero.clearSignBit();
-      } else if (LHS.isNonNegative() && RHS.isNegative()) {
-        // Pos - Neg -> Pos
-        Res.One.clearSignBit();
-        Res.Zero.setSignBit();
-      }
-    }
-  } else {
-    // Add: Leading ones of either operand are preserved.
-    // Sub: Leading zeros of LHS and leading ones of RHS are preserved
-    // as leading zeros in the result.
-    unsigned LeadingKnown;
-    if (Add)
-      LeadingKnown =
-          std::max(LHS.countMinLeadingOnes(), RHS.countMinLeadingOnes());
-    else
-      LeadingKnown =
-          std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingOnes());
-
-    // We select between the operation result and all-ones/zero
-    // respectively, so we can preserve known ones/zeros.
-    APInt Mask = APInt::getHighBitsSet(BitWidth, LeadingKnown);
-    if (Add) {
-      Res.One |= Mask;
-      Res.Zero &= ~Mask;
-    } else {
-      Res.Zero |= Mask;
-      Res.One &= ~Mask;
-    }
-  }
+  KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW=*/Signed,
+                                              /*NUW*/ !Signed, LHS, RHS);
 
   if (Overflow) {
     // We know whether or not we overflowed.
@@ -714,8 +720,9 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
     APInt C;
     if (Signed) {
       // sadd.sat / ssub.sat
-      assert(SignBitKnown(LHS) &&
-             "We somehow know overflow without knowing input sign");
+      assert(LHS.isNegative() ||
+             LHS.isNonNegative() &&
+                 "We somehow know overflow without knowing input sign");
       C = LHS.isNegative() ? APInt::getSignedMinValue(BitWidth)
                            : APInt::getSignedMaxValue(BitWidth);
     } else if (Add) {
@@ -735,8 +742,10 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
   if (Signed) {
     // sadd.sat/ssub.sat
     // We can keep our information about the sign bits.
-    Res.Zero.clearLowBits(BitWidth - 1);
-    Res.One.clearLowBits(BitWidth - 1);
+    if (MayPosClamp)
+      Res.Zero.clearLowBits(BitWidth - 1);
+    if (MayNegClamp)
+      Res.One.clearLowBits(BitWidth - 1);
   } else if (Add) {
     // uadd.sat
     // We need to clear all the known zeros as we can only use the leading ones.
@@ -766,12 +775,44 @@ KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) {
 static KnownBits avgCompute(KnownBits LHS, KnownBits RHS, bool IsCeil,
                             bool IsSigned) {
   unsigned BitWidth = LHS.getBitWidth();
-  LHS = IsSigned ? LHS.sext(BitWidth + 1) : LHS.zext(BitWidth + 1);
-  RHS = IsSigned ? RHS.sext(BitWidth + 1) : RHS.zext(BitWidth + 1);
-  LHS =
-      computeForAddCarry(LHS, RHS, /*CarryZero*/ !IsCeil, /*CarryOne*/ IsCeil);
-  LHS = LHS.extractBits(BitWidth, 1);
-  return LHS;
+  KnownBits ExtLHS = IsSigned ? LHS.sext(BitWidth + 1) : LHS.zext(BitWidth + 1);
+  KnownBits ExtRHS = IsSigned ? RHS.sext(BitWidth + 1) : RHS.zext(BitWidth + 1);
+  KnownBits Res = computeForAddCarry(ExtLHS, ExtRHS, /*CarryZero*/ !IsCeil,
+                                     /*CarryOne*/ IsCeil);
+  Res = Res.extractBits(BitWidth, 1);
+
+  auto SignBitKnown = [BitWidth](KnownBits KB) {
+    return KB.One[BitWidth - 1] || KB.Zero[BitWidth - 1];
+  };
+  // If we have only 1 known signbit between LHS/RHS we can try to figure
+  // out result signbit.
+  // NB: If we know both signbits `computeForAddCarry` gets the optimal result
+  // already.
+  if (IsSigned && !SignBitKnown(Res) &&
+      SignBitKnown(LHS) != SignBitKnown(RHS)) {
+    if (SignBitKnown(RHS))
+      std::swap(LHS, RHS);
+    KnownBits UnsignedLHS = LHS;
+    KnownBits UnsignedRHS = RHS;
+    UnsignedLHS.One.clearSignBit();
+    UnsignedLHS.Zero.setSignBit();
+    UnsignedRHS.One.clearSignBit();
+    UnsignedRHS.Zero.setSignBit();
+    KnownBits ResOf =
+        computeForAddCarry(UnsignedLHS, UnsignedRHS, /*CarryZero*/ !IsCeil,
+                           /*CarryOne*/ IsCeil);
+    // Assuming no overflow (which is the case since we extend the addition when
+    // taking the average):
+    // Neg + Neg -> Neg
+    // Neg + Pos -> Neg if the signbit doesn't overflow
+    if (LHS.isNegative() && ResOf.isNonNegative())
+      Res.makeNegative();
+    // Pos + Pos -> Pos
+    // Pos + Neg -> Pos if the signbit does overflow
+    else if (LHS.isNonNegative() && ResOf.isNegative())
+      Res.makeNonNegative();
+  }
+  return Res;
 }
 
 KnownBits KnownBits::avgFloorS(const KnownBits &LHS, const KnownBits &RHS) {
diff --git a/llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll b/llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll
index c2926eaffa58c5..f9618e1ddbc022 100644
--- a/llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll
+++ b/llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll
@@ -142,14 +142,7 @@ define i1 @ssub_sat_low_bits(i8 %x, i8 %y) {
 
 define i1 @ssub_sat_fail_may_overflow(i8 %x, i8 %y) {
 ; CHECK-LABEL: @ssub_sat_fail_may_overflow(
-; CHECK-NEXT:    [[XX:%.*]] = and i8 [[X:%.*]], 15
-; CHECK-NEXT:    [[YY:%.*]] = and i8 [[Y:%.*]], 15
-; CHECK-NEXT:    [[LHS:%.*]] = or i8 [[XX]], 1
-; CHECK-NEXT:    [[RHS:%.*]] = and i8 [[YY]], -2
-; CHECK-NEXT:    [[EXP:%.*]] = call i8 @llvm.ssub.sat.i8(i8 [[LHS]], i8 [[RHS]])
-; CHECK-NEXT:    [[AND:%.*]] = and i8 [[EXP]], 1
-; CHECK-NEXT:    [[R:%.*]] = icmp eq i8 [[AND]], 0
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    ret i1 false
 ;
   %xx = and i8 %x, 15
   %yy = and i8 %y, 15
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index b6e16f809ea779..e8be41519c5b95 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -306,14 +306,14 @@ TEST(KnownBitsTest, BinaryExhaustive) {
         return KnownBits::add(Known1, Known2);
       },
       [](const APInt &N1, const APInt &N2) { return N1 + N2; },
-      /*CheckOptimality=*/false);
+      /*CheckOptimality=*/true);
   testBinaryOpExhaustive(
       "sub",
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::sub(Known1, Known2);
       },
       [](const APInt &N1, const APInt &N2) { return N1 - N2; },
-      /*CheckOptimality=*/false);
+      /*CheckOptimality=*/true);
   testBinaryOpExhaustive("umax", KnownBits::umax, APIntOps::umax);
   testBinaryOpExhaustive("umin", KnownBits::umin, APIntOps::umin);
   testBinaryOpExhaustive("smax", KnownBits::smax, APIntOps::smax);
@@ -385,26 +385,22 @@ TEST(KnownBitsTest, BinaryExhaustive) {
       "sadd_sat", KnownBits::sadd_sat,
       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
         return N1.sadd_sat(N2);
-      },
-      /*CheckOptimality=*/false);
+      });
   testBinaryOpExhaustive(
       "uadd_sat", KnownBits::uadd_sat,
       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
         return N1.uadd_sat(N2);
-      },
-      /*CheckOptimality=*/false);
+      });
   testBinaryOpExhaustive(
       "ssub_sat", KnownBits::ssub_sat,
       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
         return N1.ssub_sat(N2);
-      },
-      /*CheckOptimality=*/false);
+      });
   testBinaryOpExhaustive(
       "usub_sat", KnownBits::usub_sat,
       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
         return N1.usub_sat(N2);
-      },
-      /*CheckOptimality=*/false);
+      });
   testBinaryOpExhaustive(
       "shl",
       [](const KnownBits &Known1, const KnownBits &Known2) {
@@ -523,17 +519,15 @@ TEST(KnownBitsTest, BinaryExhaustive) {
       [](const APInt &N1, const APInt &N2) { return APIntOps::mulhu(N1, N2); },
       /*CheckOptimality=*/false);
 
-  testBinaryOpExhaustive("avgFloorS", KnownBits::avgFloorS, APIntOps::avgFloorS,
-                         false);
+  testBinaryOpExhaustive("avgFloorS", KnownBits::avgFloorS,
+                         APIntOps::avgFloorS);
 
-  testBinaryOpExhaustive("avgFloorU", KnownBits::avgFloorU, APIntOps::avgFloorU,
-                         false);
+  testBinaryOpExhaustive("avgFloorU", KnownBits::avgFloorU,
+                         APIntOps::avgFloorU);
 
-  testBinaryOpExhaustive("avgCeilU", KnownBits::avgCeilU, APIntOps::avgCeilU,
-                         false);
+  testBinaryOpExhaustive("avgCeilU", KnownBits::avgCeilU, APIntOps::avgCeilU);
 
-  testBinaryOpExhaustive("avgCeilS", KnownBits::avgCeilS, APIntOps::avgCeilS,
-                         false);
+  testBinaryOpExhaustive("avgCeilS", KnownBits::avgCeilS, APIntOps::avgCeilS);
 }
 
 TEST(KnownBitsTest, UnaryExhaustive) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/110329


More information about the llvm-commits mailing list