[llvm] 5f50b18 - [KnownBits] Add implementations for saturating add/sub functions

Noah Goldstein via llvm-commits llvm-commits at lists.llvm.org
Tue May 23 11:57:03 PDT 2023


Author: Noah Goldstein
Date: 2023-05-23T13:52:40-05:00
New Revision: 5f50b180c50e5108b8b18d167147bef8c00fe532

URL: https://github.com/llvm/llvm-project/commit/5f50b180c50e5108b8b18d167147bef8c00fe532
DIFF: https://github.com/llvm/llvm-project/commit/5f50b180c50e5108b8b18d167147bef8c00fe532.diff

LOG: [KnownBits] Add implementations for saturating add/sub functions

These were previously missing. Even in the case where overflow is
indeterminate, we can still deduce some of the low/high bits.

Reviewed By: RKSimon

Differential Revision: https://reviews.llvm.org/D150102

Added: 
    

Modified: 
    llvm/include/llvm/Support/KnownBits.h
    llvm/lib/Support/KnownBits.cpp
    llvm/unittests/Support/KnownBitsTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index a997d8d49a915..6ba4dd4f82540 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -332,6 +332,18 @@ struct KnownBits {
   static KnownBits computeForAddSub(bool Add, bool NSW, const KnownBits &LHS,
                                     KnownBits RHS);
 
+  /// Compute known bits resulting from llvm.sadd.sat(LHS, RHS).
+  static KnownBits sadd_sat(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Compute known bits resulting from llvm.uadd.sat(LHS, RHS).
+  static KnownBits uadd_sat(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Compute known bits resulting from llvm.ssub.sat(LHS, RHS).
+  static KnownBits ssub_sat(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Compute known bits resulting from llvm.usub.sat(LHS, RHS).
+  static KnownBits usub_sat(const KnownBits &LHS, const KnownBits &RHS);
+
   /// Compute known bits resulting from multiplying LHS and RHS.
   static KnownBits mul(const KnownBits &LHS, const KnownBits &RHS,
                        bool NoUndefSelfMultiply = false);

diff  --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 9d6238fc4b0b8..1a5f1ad4cf03e 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -465,6 +465,171 @@ KnownBits KnownBits::abs(bool IntMinIsPoison) const {
   return KnownAbs;
 }
 
+/// Shared implementation of KnownBits::{s,u}{add,sub}_sat. Computes the known
+/// bits of the wrapping add/sub, determines whether the operation saturates
+/// when that can be proven, and otherwise keeps only the bits that hold for
+/// both the wrapped result and the saturation constant.
+static KnownBits computeForSatAddSub(bool Add, bool Signed,
+                                     const KnownBits &LHS,
+                                     const KnownBits &RHS) {
+  assert(!LHS.hasConflict() && !RHS.hasConflict() && "Bad inputs");
+  // We don't set NSW even for sadd/ssub as we want to check if the result has
+  // signed overflow.
+  KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW=*/false, LHS, RHS);
+  unsigned BitWidth = Res.getBitWidth();
+  auto SignBitKnown = [&](const KnownBits &K) {
+    return K.Zero[BitWidth - 1] || K.One[BitWidth - 1];
+  };
+  std::optional<bool> Overflow;
+
+  if (Signed) {
+    // sadd.sat can only overflow when both operands have the same sign and
+    // ssub.sat only when they have opposite signs; otherwise the exact result
+    // always fits. When overflow is possible, it happened iff the wrapped
+    // result's sign differs from LHS's sign. If we can't tell, leave Overflow
+    // as nullopt (we assume it may have happened).
+    if (SignBitKnown(LHS) && SignBitKnown(RHS)) {
+      bool SameSign = LHS.isNonNegative() == RHS.isNonNegative();
+      bool MayOverflow = Add ? SameSign : !SameSign;
+      if (!MayOverflow)
+        Overflow = false;
+      else if (SignBitKnown(Res))
+        Overflow = Res.isNonNegative() != LHS.isNonNegative();
+    }
+  } else if (Add) {
+    // uadd.sat: the sum wrapped iff it is smaller than either operand.
+    Overflow = KnownBits::ult(Res, RHS);
+    if (!Overflow)
+      Overflow = KnownBits::ult(Res, LHS);
+    if (!Overflow) {
+      bool Of;
+      (void)LHS.getMinValue().uadd_ov(RHS.getMinValue(), Of);
+      if (Of)
+        Overflow = true;
+      (void)LHS.getMaxValue().uadd_ov(RHS.getMaxValue(), Of);
+      if (!Of)
+        Overflow = false;
+    }
+  } else {
+    // usub.sat: the difference wrapped iff it is larger than LHS.
+    Overflow = KnownBits::ugt(Res, LHS);
+    if (!Overflow) {
+      bool Of;
+      (void)LHS.getMaxValue().usub_ov(RHS.getMinValue(), Of);
+      if (Of)
+        Overflow = true;
+      (void)LHS.getMinValue().usub_ov(RHS.getMaxValue(), Of);
+      if (!Of)
+        Overflow = false;
+    }
+  }
+
+  if (Signed) {
+    if (Add) {
+      if (LHS.isNonNegative() && RHS.isNonNegative()) {
+        // Pos + Pos -> Pos
+        Res.One.clearSignBit();
+        Res.Zero.setSignBit();
+      }
+      if (LHS.isNegative() && RHS.isNegative()) {
+        // Neg + Neg -> Neg
+        Res.One.setSignBit();
+        Res.Zero.clearSignBit();
+      }
+    } else {
+      if (LHS.isNegative() && RHS.isNonNegative()) {
+        // Neg - Pos -> Neg
+        Res.One.setSignBit();
+        Res.Zero.clearSignBit();
+      } else if (LHS.isNonNegative() && RHS.isNegative()) {
+        // Pos - Neg -> Pos
+        Res.One.clearSignBit();
+        Res.Zero.setSignBit();
+      }
+    }
+  } else {
+    // Add: Leading ones of either operand are preserved.
+    // Sub: Leading zeros of LHS and leading ones of RHS are preserved
+    // as leading zeros in the result.
+    unsigned LeadingKnown =
+        Add ? std::max(LHS.countMinLeadingOnes(), RHS.countMinLeadingOnes())
+            : std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingOnes());
+
+    // We select between the operation result and all-ones/zero
+    // respectively, so we can preserve known ones/zeros.
+    APInt Mask = APInt::getHighBitsSet(BitWidth, LeadingKnown);
+    if (Add) {
+      Res.One |= Mask;
+      Res.Zero &= ~Mask;
+    } else {
+      Res.Zero |= Mask;
+      Res.One &= ~Mask;
+    }
+  }
+
+  if (Overflow) {
+    // We know whether or not we overflowed.
+    if (!(*Overflow)) {
+      // No overflow.
+      assert(!Res.hasConflict() && "Bad Output");
+      return Res;
+    }
+
+    // We overflowed
+    APInt C;
+    if (Signed) {
+      // sadd.sat / ssub.sat
+      assert(SignBitKnown(LHS) &&
+             "We somehow know overflow without knowing input sign");
+      C = LHS.isNegative() ? APInt::getSignedMinValue(BitWidth)
+                           : APInt::getSignedMaxValue(BitWidth);
+    } else if (Add) {
+      // uadd.sat
+      C = APInt::getMaxValue(BitWidth);
+    } else {
+      // usub.sat
+      C = APInt::getMinValue(BitWidth);
+    }
+
+    Res.One = C;
+    Res.Zero = ~C;
+    assert(!Res.hasConflict() && "Bad Output");
+    return Res;
+  }
+
+  // We don't know if we overflowed.
+  if (Signed) {
+    // sadd.sat/ssub.sat: any sign bit known here was derived from the operand
+    // signs and also holds for the saturation constant, so keep only it.
+    Res.Zero.clearLowBits(BitWidth - 1);
+    Res.One.clearLowBits(BitWidth - 1);
+  } else if (Add) {
+    // uadd.sat: the result is Res or all-ones, so known ones stay valid but
+    // known zeros must be dropped.
+    Res.Zero.clearAllBits();
+  } else {
+    // usub.sat: the result is Res or zero, so known zeros stay valid but
+    // known ones must be dropped.
+    Res.One.clearAllBits();
+  }
+
+  assert(!Res.hasConflict() && "Bad Output");
+  return Res;
+}
+
+KnownBits KnownBits::sadd_sat(const KnownBits &LHS, const KnownBits &RHS) {
+  return computeForSatAddSub(/*Add=*/true, /*Signed=*/true, LHS, RHS);
+}
+KnownBits KnownBits::uadd_sat(const KnownBits &LHS, const KnownBits &RHS) {
+  return computeForSatAddSub(/*Add=*/true, /*Signed=*/false, LHS, RHS);
+}
+KnownBits KnownBits::ssub_sat(const KnownBits &LHS, const KnownBits &RHS) {
+  return computeForSatAddSub(/*Add=*/false, /*Signed=*/true, LHS, RHS);
+}
+KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) {
+  return computeForSatAddSub(/*Add=*/false, /*Signed=*/false, LHS, RHS);
+}
+
 KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
                          bool NoUndefSelfMultiply) {
   unsigned BitWidth = LHS.getBitWidth();

diff  --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 14c582506e6dc..457b7c8a5af14 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -300,7 +300,38 @@ TEST(KnownBitsTest, BinaryExhaustive) {
         return N1.srem(N2);
       },
       checkCorrectnessOnlyBinary);
-
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::sadd_sat(Known1, Known2);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        return N1.sadd_sat(N2);
+      },
+      checkCorrectnessOnlyBinary);
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::uadd_sat(Known1, Known2);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        return N1.uadd_sat(N2);
+      },
+      checkCorrectnessOnlyBinary);
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::ssub_sat(Known1, Known2);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        return N1.ssub_sat(N2);
+      },
+      checkCorrectnessOnlyBinary);
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::usub_sat(Known1, Known2);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        return N1.usub_sat(N2);
+      },
+      checkCorrectnessOnlyBinary);
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::shl(Known1, Known2);


        


More information about the llvm-commits mailing list