[llvm] [KnownBits] Make `{s,u}{add,sub}_sat` optimal (PR #113096)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Oct 30 10:50:09 PDT 2024
https://github.com/goldsteinn updated https://github.com/llvm/llvm-project/pull/113096
>From a8614215d66b1f5bb2af2364c14005f391b4d4fd Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Fri, 27 Sep 2024 15:38:54 -0500
Subject: [PATCH] [KnownBits] Make `{s,u}{add,sub}_sat` optimal
Changes are:
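1) Make signed-overflow detection optimal.
2) For signed overflow, try to rule out one direction of overflow (clamping
to INT_MIN or to INT_MAX) even if we can't rule out overflow entirely.
3) When overflow is possible, intersect the possible clamping values with the
add/sub result computed under the no-overflow assumption (nsw/nuw), rather
than with the add/sub result computed without that assumption.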
---
llvm/lib/Support/KnownBits.cpp | 138 +++++++++---------
.../ValueTracking/knownbits-sat-addsub.ll | 9 +-
llvm/unittests/Support/KnownBitsTest.cpp | 12 +-
3 files changed, 78 insertions(+), 81 deletions(-)
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 89668af378070b..22a1628b0fa23a 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -610,28 +610,78 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
const KnownBits &RHS) {
// We don't see NSW even for sadd/ssub as we want to check if the result has
// signed overflow.
- KnownBits Res =
- KnownBits::computeForAddSub(Add, /*NSW=*/false, /*NUW=*/false, LHS, RHS);
- unsigned BitWidth = Res.getBitWidth();
- auto SignBitKnown = [&](const KnownBits &K) {
- return K.Zero[BitWidth - 1] || K.One[BitWidth - 1];
- };
- std::optional<bool> Overflow;
+ unsigned BitWidth = LHS.getBitWidth();
+ std::optional<bool> Overflow;
+ // Even if we can't entirely rule out overflow, we may be able to rule out
+ // overflow in one direction. This allows us to potentially keep some of the
+ // add/sub bits. I.e if we can't overflow in the positive direction we won't
+ // clamp to INT_MAX so we can keep low 0s from the add/sub result.
+ bool MayNegClamp = true;
+ bool MayPosClamp = true;
if (Signed) {
- // If we can actually detect overflow do so. Otherwise leave Overflow as
- // nullopt (we assume it may have happened).
- if (SignBitKnown(LHS) && SignBitKnown(RHS) && SignBitKnown(Res)) {
+ // Easy cases we can rule out any overflow.
+ if (Add && ((LHS.isNegative() && RHS.isNonNegative()) ||
+ (LHS.isNonNegative() && RHS.isNegative())))
+ Overflow = false;
+ else if (!Add && (((LHS.isNegative() && RHS.isNegative()) ||
+ (LHS.isNonNegative() && RHS.isNonNegative()))))
+ Overflow = false;
+ else {
+ // Check if we may overflow. If we can't rule out overflow then check if
+ // we can rule out a direction at least.
+ KnownBits UnsignedLHS = LHS;
+ KnownBits UnsignedRHS = RHS;
+ UnsignedLHS.One.clearSignBit();
+ UnsignedLHS.Zero.setSignBit();
+ UnsignedRHS.One.clearSignBit();
+ UnsignedRHS.Zero.setSignBit();
+ KnownBits Res =
+ KnownBits::computeForAddSub(Add, /*NSW=*/false,
+ /*NUW=*/false, UnsignedLHS, UnsignedRHS);
if (Add) {
- // sadd.sat
- Overflow = (LHS.isNonNegative() == RHS.isNonNegative() &&
- Res.isNonNegative() != LHS.isNonNegative());
+ if (Res.isNegative()) {
+ // Only overflow scenario is Pos + Pos.
+ MayNegClamp = false;
+ // Pos + Pos will overflow with extra signbit.
+ if (LHS.isNonNegative() && RHS.isNonNegative())
+ Overflow = true;
+ } else if (Res.isNonNegative()) {
+ // Only overflow scenario is Neg + Neg
+ MayPosClamp = false;
+ // Neg + Neg will overflow without extra signbit.
+ if (LHS.isNegative() && RHS.isNegative())
+ Overflow = true;
+ }
+ // We will never clamp to the opposite sign of N-bit result.
+ if (LHS.isNegative() || RHS.isNegative())
+ MayPosClamp = false;
+ if (LHS.isNonNegative() || RHS.isNonNegative())
+ MayNegClamp = false;
} else {
- // ssub.sat
- Overflow = (LHS.isNonNegative() != RHS.isNonNegative() &&
- Res.isNonNegative() != LHS.isNonNegative());
+ if (Res.isNegative()) {
+ // Only overflow scenario is Neg - Pos.
+ MayPosClamp = false;
+ // Neg - Pos will overflow with extra signbit.
+ if (LHS.isNegative() && RHS.isNonNegative())
+ Overflow = true;
+ } else if (Res.isNonNegative()) {
+ // Only overflow scenario is Pos - Neg.
+ MayNegClamp = false;
+ // Pos - Neg will overflow without extra signbit.
+ if (LHS.isNonNegative() && RHS.isNegative())
+ Overflow = true;
+ }
+ // We will never clamp to the opposite sign of N-bit result.
+ if (LHS.isNegative() || RHS.isNonNegative())
+ MayPosClamp = false;
+ if (LHS.isNonNegative() || RHS.isNegative())
+ MayNegClamp = false;
}
}
+ // If we have ruled out all clamping, we will never overflow.
+ if (!MayNegClamp && !MayPosClamp)
+ Overflow = false;
} else if (Add) {
// uadd.sat
bool Of;
@@ -656,52 +706,8 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
}
}
- if (Signed) {
- if (Add) {
- if (LHS.isNonNegative() && RHS.isNonNegative()) {
- // Pos + Pos -> Pos
- Res.One.clearSignBit();
- Res.Zero.setSignBit();
- }
- if (LHS.isNegative() && RHS.isNegative()) {
- // Neg + Neg -> Neg
- Res.One.setSignBit();
- Res.Zero.clearSignBit();
- }
- } else {
- if (LHS.isNegative() && RHS.isNonNegative()) {
- // Neg - Pos -> Neg
- Res.One.setSignBit();
- Res.Zero.clearSignBit();
- } else if (LHS.isNonNegative() && RHS.isNegative()) {
- // Pos - Neg -> Pos
- Res.One.clearSignBit();
- Res.Zero.setSignBit();
- }
- }
- } else {
- // Add: Leading ones of either operand are preserved.
- // Sub: Leading zeros of LHS and leading ones of RHS are preserved
- // as leading zeros in the result.
- unsigned LeadingKnown;
- if (Add)
- LeadingKnown =
- std::max(LHS.countMinLeadingOnes(), RHS.countMinLeadingOnes());
- else
- LeadingKnown =
- std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingOnes());
-
- // We select between the operation result and all-ones/zero
- // respectively, so we can preserve known ones/zeros.
- APInt Mask = APInt::getHighBitsSet(BitWidth, LeadingKnown);
- if (Add) {
- Res.One |= Mask;
- Res.Zero &= ~Mask;
- } else {
- Res.Zero |= Mask;
- Res.One &= ~Mask;
- }
- }
+ KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW=*/Signed,
+ /*NUW=*/!Signed, LHS, RHS);
if (Overflow) {
// We know whether or not we overflowed.
@@ -714,7 +720,7 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
APInt C;
if (Signed) {
// sadd.sat / ssub.sat
- assert(SignBitKnown(LHS) &&
+ assert(!LHS.isSignUnknown() &&
"We somehow know overflow without knowing input sign");
C = LHS.isNegative() ? APInt::getSignedMinValue(BitWidth)
: APInt::getSignedMaxValue(BitWidth);
@@ -735,8 +741,10 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
if (Signed) {
// sadd.sat/ssub.sat
// We can keep our information about the sign bits.
- Res.Zero.clearLowBits(BitWidth - 1);
- Res.One.clearLowBits(BitWidth - 1);
+ if (MayPosClamp)
+ Res.Zero.clearLowBits(BitWidth - 1);
+ if (MayNegClamp)
+ Res.One.clearLowBits(BitWidth - 1);
} else if (Add) {
// uadd.sat
// We need to clear all the known zeros as we can only use the leading ones.
diff --git a/llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll b/llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll
index c2926eaffa58c5..f9618e1ddbc022 100644
--- a/llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll
+++ b/llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll
@@ -142,14 +142,7 @@ define i1 @ssub_sat_low_bits(i8 %x, i8 %y) {
define i1 @ssub_sat_fail_may_overflow(i8 %x, i8 %y) {
; CHECK-LABEL: @ssub_sat_fail_may_overflow(
-; CHECK-NEXT: [[XX:%.*]] = and i8 [[X:%.*]], 15
-; CHECK-NEXT: [[YY:%.*]] = and i8 [[Y:%.*]], 15
-; CHECK-NEXT: [[LHS:%.*]] = or i8 [[XX]], 1
-; CHECK-NEXT: [[RHS:%.*]] = and i8 [[YY]], -2
-; CHECK-NEXT: [[EXP:%.*]] = call i8 @llvm.ssub.sat.i8(i8 [[LHS]], i8 [[RHS]])
-; CHECK-NEXT: [[AND:%.*]] = and i8 [[EXP]], 1
-; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[AND]], 0
-; CHECK-NEXT: ret i1 [[R]]
+; CHECK-NEXT: ret i1 false
;
%xx = and i8 %x, 15
%yy = and i8 %y, 15
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index b16368de176481..ce0bf86e39dd7b 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -383,26 +383,22 @@ TEST(KnownBitsTest, BinaryExhaustive) {
"sadd_sat", KnownBits::sadd_sat,
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
return N1.sadd_sat(N2);
- },
- /*CheckOptimality=*/false);
+ });
testBinaryOpExhaustive(
"uadd_sat", KnownBits::uadd_sat,
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
return N1.uadd_sat(N2);
- },
- /*CheckOptimality=*/false);
+ });
testBinaryOpExhaustive(
"ssub_sat", KnownBits::ssub_sat,
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
return N1.ssub_sat(N2);
- },
- /*CheckOptimality=*/false);
+ });
testBinaryOpExhaustive(
"usub_sat", KnownBits::usub_sat,
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
return N1.usub_sat(N2);
- },
- /*CheckOptimality=*/false);
+ });
testBinaryOpExhaustive(
"shl",
[](const KnownBits &Known1, const KnownBits &Known2) {
More information about the llvm-commits
mailing list