[llvm] [KnownBits] Make `avg{Ceil,Floor}S` and `{s,u}{add,sub}_sat` optimal (PR #110329)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 30 11:45:14 PDT 2024
https://github.com/goldsteinn updated https://github.com/llvm/llvm-project/pull/110329
>From d037382b86f0fd05d165657857cc76a3a4a28615 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Fri, 27 Sep 2024 15:39:03 -0500
Subject: [PATCH] [KnownBits] Make `avg{Ceil,Floor}S` optimal
All we were missing was the result's sign bit in the case where the
incoming sign bit of exactly one of LHS or RHS is known.
Since the base addition in the average is performed with one extra bit
of width, it cannot overflow, so we can determine the result's sign from
the magnitudes of the inputs: if the negative operand has the larger
magnitude the result is negative, and vice versa for the positive case.
---
llvm/lib/Support/KnownBits.cpp | 57 ++++++++++++++++++------
llvm/unittests/Support/KnownBitsTest.cpp | 7 ++-
2 files changed, 46 insertions(+), 18 deletions(-)
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 6863c5c0af5dca..5b5e6df53d6a16 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -766,32 +766,61 @@ KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) {
static KnownBits avgCompute(KnownBits LHS, KnownBits RHS, bool IsCeil,
bool IsSigned) {
unsigned BitWidth = LHS.getBitWidth();
- LHS = IsSigned ? LHS.sext(BitWidth + 1) : LHS.zext(BitWidth + 1);
- RHS = IsSigned ? RHS.sext(BitWidth + 1) : RHS.zext(BitWidth + 1);
- LHS =
- computeForAddCarry(LHS, RHS, /*CarryZero*/ !IsCeil, /*CarryOne*/ IsCeil);
- LHS = LHS.extractBits(BitWidth, 1);
- return LHS;
+ // Widen both operands by one bit so the addition (including the rounding
+ // carry-in used by the ceil variants) cannot overflow, then take bits
+ // [1, BitWidth] of the sum, i.e. halve it.
+ KnownBits ExtLHS = IsSigned ? LHS.sext(BitWidth + 1) : LHS.zext(BitWidth + 1);
+ KnownBits ExtRHS = IsSigned ? RHS.sext(BitWidth + 1) : RHS.zext(BitWidth + 1);
+ KnownBits Res = computeForAddCarry(ExtLHS, ExtRHS, /*CarryZero=*/!IsCeil,
+ /*CarryOne=*/IsCeil);
+ Res = Res.extractBits(BitWidth, 1);
+
+ // If the sign of exactly one of LHS/RHS is known, we may still be able to
+ // determine the sign of the result.
+ // NB: If both signs are known, `computeForAddCarry` already produces the
+ // optimal result.
+ if (IsSigned && Res.isSignUnknown() &&
+ LHS.isSignUnknown() != RHS.isSignUnknown()) {
+ // Canonicalize so that LHS is the operand whose sign is known.
+ if (LHS.isSignUnknown())
+ std::swap(LHS, RHS);
+ // Force the sign bit of both operands to known-zero so that ResOf sums
+ // only their low (non-sign) bits plus the rounding carry-in.
+ KnownBits UnsignedLHS = LHS;
+ KnownBits UnsignedRHS = RHS;
+ UnsignedLHS.One.clearSignBit();
+ UnsignedLHS.Zero.setSignBit();
+ UnsignedRHS.One.clearSignBit();
+ UnsignedRHS.Zero.setSignBit();
+ KnownBits ResOf =
+ computeForAddCarry(UnsignedLHS, UnsignedRHS, /*CarryZero=*/!IsCeil,
+ /*CarryOne=*/IsCeil);
+ // The widened addition above cannot overflow. Let S be ResOf's sign bit,
+ // i.e. the carry from the low bits into the sign position:
+ // Neg + Neg -> Neg
+ // Neg + Pos -> Neg iff S == 0
+ // So a negative LHS with S known to be 0 yields a negative result
+ // regardless of RHS's sign.
+ if (LHS.isNegative() && ResOf.isNonNegative())
+ Res.makeNegative();
+ // Pos + Pos -> Pos
+ // Pos + Neg -> Pos iff S == 1
+ // So a non-negative LHS with S known to be 1 yields a non-negative result
+ // regardless of RHS's sign.
+ else if (LHS.isNonNegative() && ResOf.isNegative())
+ Res.makeNonNegative();
+ }
+ return Res;
}
KnownBits KnownBits::avgFloorS(const KnownBits &LHS, const KnownBits &RHS) {
- return avgCompute(LHS, RHS, /* IsCeil */ false,
- /* IsSigned */ true);
+ return avgCompute(LHS, RHS, /*IsCeil=*/false,
+ /*IsSigned=*/true);
}
KnownBits KnownBits::avgFloorU(const KnownBits &LHS, const KnownBits &RHS) {
- return avgCompute(LHS, RHS, /* IsCeil */ false,
- /* IsSigned */ false);
+ return avgCompute(LHS, RHS, /*IsCeil=*/false,
+ /*IsSigned=*/false);
}
KnownBits KnownBits::avgCeilS(const KnownBits &LHS, const KnownBits &RHS) {
- return avgCompute(LHS, RHS, /* IsCeil */ true,
- /* IsSigned */ true);
+ return avgCompute(LHS, RHS, /*IsCeil=*/true,
+ /*IsSigned=*/true);
}
KnownBits KnownBits::avgCeilU(const KnownBits &LHS, const KnownBits &RHS) {
- return avgCompute(LHS, RHS, /* IsCeil */ true,
- /* IsSigned */ false);
+ return avgCompute(LHS, RHS, /*IsCeil=*/true,
+ /*IsSigned=*/false);
}
KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index b701757aed5eb9..551c1a8107494b 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -521,16 +521,15 @@ TEST(KnownBitsTest, BinaryExhaustive) {
[](const APInt &N1, const APInt &N2) { return APIntOps::mulhu(N1, N2); },
/*CheckOptimality=*/false);
- testBinaryOpExhaustive("avgFloorS", KnownBits::avgFloorS, APIntOps::avgFloorS,
- /*CheckOptimality=*/false);
+ testBinaryOpExhaustive("avgFloorS", KnownBits::avgFloorS,
+ APIntOps::avgFloorS);
testBinaryOpExhaustive("avgFloorU", KnownBits::avgFloorU,
APIntOps::avgFloorU);
testBinaryOpExhaustive("avgCeilU", KnownBits::avgCeilU, APIntOps::avgCeilU);
- testBinaryOpExhaustive("avgCeilS", KnownBits::avgCeilS, APIntOps::avgCeilS,
- /*CheckOptimality=*/false);
+ testBinaryOpExhaustive("avgCeilS", KnownBits::avgCeilS, APIntOps::avgCeilS);
}
TEST(KnownBitsTest, UnaryExhaustive) {
More information about the llvm-commits
mailing list