[llvm] 5f50b18 - [KnownBits] Add implementations for saturating add/sub functions
Noah Goldstein via llvm-commits
llvm-commits at lists.llvm.org
Tue May 23 11:57:03 PDT 2023
Author: Noah Goldstein
Date: 2023-05-23T13:52:40-05:00
New Revision: 5f50b180c50e5108b8b18d167147bef8c00fe532
URL: https://github.com/llvm/llvm-project/commit/5f50b180c50e5108b8b18d167147bef8c00fe532
DIFF: https://github.com/llvm/llvm-project/commit/5f50b180c50e5108b8b18d167147bef8c00fe532.diff
LOG: [KnownBits] Add implementations for saturating add/sub functions
These were previously missing. Even in the case where overflow is
indeterminate we can still deduce some of the low/high bits.
Reviewed By: RKSimon
Differential Revision: https://reviews.llvm.org/D150102
Added:
Modified:
llvm/include/llvm/Support/KnownBits.h
llvm/lib/Support/KnownBits.cpp
llvm/unittests/Support/KnownBitsTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index a997d8d49a915..6ba4dd4f82540 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -332,6 +332,18 @@ struct KnownBits {
static KnownBits computeForAddSub(bool Add, bool NSW, const KnownBits &LHS,
KnownBits RHS);
+ /// Compute known bits resulting from llvm.sadd.sat(LHS, RHS).
+ static KnownBits sadd_sat(const KnownBits &LHS, const KnownBits &RHS);
+
+ /// Compute known bits resulting from llvm.uadd.sat(LHS, RHS).
+ static KnownBits uadd_sat(const KnownBits &LHS, const KnownBits &RHS);
+
+ /// Compute known bits resulting from llvm.ssub.sat(LHS, RHS).
+ static KnownBits ssub_sat(const KnownBits &LHS, const KnownBits &RHS);
+
+ /// Compute known bits resulting from llvm.usub.sat(LHS, RHS).
+ static KnownBits usub_sat(const KnownBits &LHS, const KnownBits &RHS);
+
/// Compute known bits resulting from multiplying LHS and RHS.
static KnownBits mul(const KnownBits &LHS, const KnownBits &RHS,
bool NoUndefSelfMultiply = false);
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 9d6238fc4b0b8..1a5f1ad4cf03e 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -465,6 +465,171 @@ KnownBits KnownBits::abs(bool IntMinIsPoison) const {
return KnownAbs;
}
+static KnownBits computeForSatAddSub(bool Add, bool Signed,
+ const KnownBits &LHS,
+ const KnownBits &RHS) {
+ assert(!LHS.hasConflict() && !RHS.hasConflict() && "Bad inputs");
+ // Compute the plain add/sub result without NSW, even for sadd/ssub, because
+ // we inspect its sign bit below to detect signed overflow.
+ KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW=*/ false, LHS, RHS);
+ unsigned BitWidth = Res.getBitWidth();
+ auto SignBitKnown = [&](const KnownBits &K) {
+ return K.Zero[BitWidth - 1] || K.One[BitWidth - 1];
+ };
+ std::optional<bool> Overflow;
+
+ if (Signed) {
+ // If we can actually detect overflow do so. Otherwise leave Overflow as
+ // nullopt (we assume it may have happened).
+ if (SignBitKnown(LHS) && SignBitKnown(RHS) && SignBitKnown(Res)) {
+ if (Add) {
+ // sadd.sat: overflow iff both operands share a sign that Res lacks.
+ Overflow = (LHS.isNonNegative() == RHS.isNonNegative() &&
+ Res.isNonNegative() != LHS.isNonNegative());
+ } else {
+ // ssub.sat: overflow iff operand signs differ and Res's sign is not LHS's.
+ Overflow = (LHS.isNonNegative() != RHS.isNonNegative() &&
+ Res.isNonNegative() != LHS.isNonNegative());
+ }
+ }
+ } else if (Add) {
+ // uadd.sat: overflow iff the computed sum is below one of the operands.
+ Overflow = KnownBits::ult(Res, RHS);
+ if (!Overflow)
+ Overflow = KnownBits::ult(Res, LHS);
+ if (!Overflow) {
+ bool Of;
+ (void)LHS.getMinValue().uadd_ov(RHS.getMinValue(), Of);
+ if (Of)
+ Overflow = true;
+ (void)LHS.getMaxValue().uadd_ov(RHS.getMaxValue(), Of);
+ if (!Of)
+ Overflow = false;
+ }
+ } else {
+ // usub.sat: overflow iff the computed difference is above LHS.
+ Overflow = KnownBits::ugt(Res, LHS);
+ if (!Overflow) {
+ bool Of;
+ (void)LHS.getMaxValue().usub_ov(RHS.getMinValue(), Of);
+ if (Of)
+ Overflow = Of;
+ (void)LHS.getMinValue().usub_ov(RHS.getMaxValue(), Of);
+ if (!Of)
+ Overflow = Of;
+ }
+ }
+
+ if (Signed) {
+ if (Add) {
+ if (LHS.isNonNegative() && RHS.isNonNegative()) {
+ // Pos + Pos -> Pos
+ Res.One.clearSignBit();
+ Res.Zero.setSignBit();
+ }
+ if (LHS.isNegative() && RHS.isNegative()) {
+ // Neg + Neg -> Neg
+ Res.One.setSignBit();
+ Res.Zero.clearSignBit();
+ }
+ } else {
+ if (LHS.isNegative() && RHS.isNonNegative()) {
+ // Neg - Pos -> Neg
+ Res.One.setSignBit();
+ Res.Zero.clearSignBit();
+ } else if (LHS.isNonNegative() && RHS.isNegative()) {
+ // Pos - Neg -> Pos
+ Res.One.clearSignBit();
+ Res.Zero.setSignBit();
+ }
+ }
+ } else {
+ // Add: Leading ones of either operand are preserved.
+ // Sub: Leading zeros of LHS and leading ones of RHS are preserved
+ // as leading zeros in the result.
+ unsigned LeadingKnown;
+ if (Add)
+ LeadingKnown =
+ std::max(LHS.countMinLeadingOnes(), RHS.countMinLeadingOnes());
+ else
+ LeadingKnown =
+ std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingOnes());
+
+ // We select between the operation result and all-ones/zero
+ // respectively, so we can preserve known ones/zeros.
+ APInt Mask = APInt::getHighBitsSet(BitWidth, LeadingKnown);
+ if (Add) {
+ Res.One |= Mask;
+ Res.Zero &= ~Mask;
+ } else {
+ Res.Zero |= Mask;
+ Res.One &= ~Mask;
+ }
+ }
+
+ if (Overflow) {
+ // We know whether or not we overflowed.
+ if (!(*Overflow)) {
+ // No overflow.
+ assert(!Res.hasConflict() && "Bad Output");
+ return Res;
+ }
+
+ // We overflowed, so the result is exactly the saturation constant C.
+ APInt C;
+ if (Signed) {
+ // sadd.sat / ssub.sat: saturate to signed min/max based on LHS's sign.
+ assert(SignBitKnown(LHS) &&
+ "We somehow know overflow without knowing input sign");
+ C = LHS.isNegative() ? APInt::getSignedMinValue(BitWidth)
+ : APInt::getSignedMaxValue(BitWidth);
+ } else if (Add) {
+ // uadd.sat
+ C = APInt::getMaxValue(BitWidth);
+ } else {
+ // usub.sat
+ C = APInt::getMinValue(BitWidth);
+ }
+
+ Res.One = C;
+ Res.Zero = ~C;
+ assert(!Res.hasConflict() && "Bad Output");
+ return Res;
+ }
+
+ // We don't know if we overflowed.
+ if (Signed) {
+ // sadd.sat/ssub.sat
+ // Only the sign bit information is kept; drop all lower known bits.
+ Res.Zero.clearLowBits(BitWidth - 1);
+ Res.One.clearLowBits(BitWidth - 1);
+ } else if (Add) {
+ // uadd.sat
+ // We need to clear all the known zeros as we can only use the leading ones.
+ Res.Zero.clearAllBits();
+ } else {
+ // usub.sat
+ // We need to clear all the known ones as we can only use the leading zeros.
+ Res.One.clearAllBits();
+ }
+
+ assert(!Res.hasConflict() && "Bad Output");
+ return Res;
+}
+
+KnownBits KnownBits::sadd_sat(const KnownBits &LHS, const KnownBits &RHS) {
+ return computeForSatAddSub(/*Add=*/ true, /*Signed=*/ true, LHS, RHS);
+}
+KnownBits KnownBits::ssub_sat(const KnownBits &LHS, const KnownBits &RHS) {
+ return computeForSatAddSub(/*Add=*/ false, /*Signed=*/ true, LHS, RHS);
+}
+KnownBits KnownBits::uadd_sat(const KnownBits &LHS, const KnownBits &RHS) {
+ return computeForSatAddSub(/*Add=*/ true, /*Signed=*/ false, LHS, RHS);
+}
+KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) {
+ return computeForSatAddSub(/*Add=*/ false, /*Signed=*/ false, LHS, RHS);
+}
+
KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
bool NoUndefSelfMultiply) {
unsigned BitWidth = LHS.getBitWidth();
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 14c582506e6dc..457b7c8a5af14 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -300,7 +300,38 @@ TEST(KnownBitsTest, BinaryExhaustive) {
return N1.srem(N2);
},
checkCorrectnessOnlyBinary);
-
+ testBinaryOpExhaustive(
+ [](const KnownBits &Known1, const KnownBits &Known2) {
+ return KnownBits::sadd_sat(Known1, Known2);
+ },
+ [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+ return N1.sadd_sat(N2);
+ },
+ checkCorrectnessOnlyBinary);
+ testBinaryOpExhaustive(
+ [](const KnownBits &Known1, const KnownBits &Known2) {
+ return KnownBits::uadd_sat(Known1, Known2);
+ },
+ [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+ return N1.uadd_sat(N2);
+ },
+ checkCorrectnessOnlyBinary);
+ testBinaryOpExhaustive(
+ [](const KnownBits &Known1, const KnownBits &Known2) {
+ return KnownBits::ssub_sat(Known1, Known2);
+ },
+ [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+ return N1.ssub_sat(N2);
+ },
+ checkCorrectnessOnlyBinary);
+ testBinaryOpExhaustive(
+ [](const KnownBits &Known1, const KnownBits &Known2) {
+ return KnownBits::usub_sat(Known1, Known2);
+ },
+ [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+ return N1.usub_sat(N2);
+ },
+ checkCorrectnessOnlyBinary);
testBinaryOpExhaustive(
[](const KnownBits &Known1, const KnownBits &Known2) {
return KnownBits::shl(Known1, Known2);
More information about the llvm-commits
mailing list