[llvm] [KnownBits] Make `avg{Ceil,Floor}S` optimal (PR #110688)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 1 08:28:41 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-support
Author: Jay Foad (jayfoad)
<details>
<summary>Changes</summary>
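Rewrite the signed functions in terms of the unsigned ones, which are
already optimal. Flipping the sign bit maps each n-bit signed value s to
the unsigned value s + 2^(n-1); shifting both operands by the same offset
shifts their floor/ceil average by that same offset, so flipping the sign
bit of the unsigned result gives the signed average. Because the flip only
swaps the known-zero/known-one state of the sign bit, optimality carries over.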
---
Full diff: https://github.com/llvm/llvm-project/pull/110688.diff
3 Files Affected:
- (modified) llvm/include/llvm/Support/KnownBits.h (+3)
- (modified) llvm/lib/Support/KnownBits.cpp (+17-22)
- (modified) llvm/unittests/Support/KnownBitsTest.cpp (+3-4)
``````````diff
diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index e4ec202f36aae0..a4b554fa2a0b72 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -29,6 +29,9 @@ struct KnownBits {
KnownBits(APInt Zero, APInt One)
: Zero(std::move(Zero)), One(std::move(One)) {}
+ // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF]
+ static KnownBits flipSignBit(const KnownBits &Val);
+
public:
// Default construct Zero and One.
KnownBits() = default;
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 6863c5c0af5dca..a7801aa950cad3 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -18,6 +18,15 @@
using namespace llvm;
+KnownBits KnownBits::flipSignBit(const KnownBits &Val) {
+ unsigned SignBitPosition = Val.getBitWidth() - 1;
+ APInt Zero = Val.Zero;
+ APInt One = Val.One;
+ Zero.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
+ One.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
+ return KnownBits(Zero, One);
+}
+
static KnownBits computeForAddCarry(const KnownBits &LHS, const KnownBits &RHS,
bool CarryZero, bool CarryOne) {
@@ -200,16 +209,7 @@ KnownBits KnownBits::umin(const KnownBits &LHS, const KnownBits &RHS) {
}
KnownBits KnownBits::smax(const KnownBits &LHS, const KnownBits &RHS) {
- // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF]
- auto Flip = [](const KnownBits &Val) {
- unsigned SignBitPosition = Val.getBitWidth() - 1;
- APInt Zero = Val.Zero;
- APInt One = Val.One;
- Zero.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
- One.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
- return KnownBits(Zero, One);
- };
- return Flip(umax(Flip(LHS), Flip(RHS)));
+ return flipSignBit(umax(flipSignBit(LHS), flipSignBit(RHS)));
}
KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
@@ -763,11 +763,10 @@ KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) {
return computeForSatAddSub(/*Add*/ false, /*Signed*/ false, LHS, RHS);
}
-static KnownBits avgCompute(KnownBits LHS, KnownBits RHS, bool IsCeil,
- bool IsSigned) {
+static KnownBits avgComputeU(KnownBits LHS, KnownBits RHS, bool IsCeil) {
unsigned BitWidth = LHS.getBitWidth();
- LHS = IsSigned ? LHS.sext(BitWidth + 1) : LHS.zext(BitWidth + 1);
- RHS = IsSigned ? RHS.sext(BitWidth + 1) : RHS.zext(BitWidth + 1);
+ LHS = LHS.zext(BitWidth + 1);
+ RHS = RHS.zext(BitWidth + 1);
LHS =
computeForAddCarry(LHS, RHS, /*CarryZero*/ !IsCeil, /*CarryOne*/ IsCeil);
LHS = LHS.extractBits(BitWidth, 1);
@@ -775,23 +774,19 @@ static KnownBits avgCompute(KnownBits LHS, KnownBits RHS, bool IsCeil,
}
KnownBits KnownBits::avgFloorS(const KnownBits &LHS, const KnownBits &RHS) {
- return avgCompute(LHS, RHS, /* IsCeil */ false,
- /* IsSigned */ true);
+ return flipSignBit(avgFloorU(flipSignBit(LHS), flipSignBit(RHS)));
}
KnownBits KnownBits::avgFloorU(const KnownBits &LHS, const KnownBits &RHS) {
- return avgCompute(LHS, RHS, /* IsCeil */ false,
- /* IsSigned */ false);
+ return avgComputeU(LHS, RHS, /* IsCeil */ false);
}
KnownBits KnownBits::avgCeilS(const KnownBits &LHS, const KnownBits &RHS) {
- return avgCompute(LHS, RHS, /* IsCeil */ true,
- /* IsSigned */ true);
+ return flipSignBit(avgCeilU(flipSignBit(LHS), flipSignBit(RHS)));
}
KnownBits KnownBits::avgCeilU(const KnownBits &LHS, const KnownBits &RHS) {
- return avgCompute(LHS, RHS, /* IsCeil */ true,
- /* IsSigned */ false);
+ return avgComputeU(LHS, RHS, /* IsCeil */ true);
}
KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index b701757aed5eb9..551c1a8107494b 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -521,16 +521,15 @@ TEST(KnownBitsTest, BinaryExhaustive) {
[](const APInt &N1, const APInt &N2) { return APIntOps::mulhu(N1, N2); },
/*CheckOptimality=*/false);
- testBinaryOpExhaustive("avgFloorS", KnownBits::avgFloorS, APIntOps::avgFloorS,
- /*CheckOptimality=*/false);
+ testBinaryOpExhaustive("avgFloorS", KnownBits::avgFloorS,
+ APIntOps::avgFloorS);
testBinaryOpExhaustive("avgFloorU", KnownBits::avgFloorU,
APIntOps::avgFloorU);
testBinaryOpExhaustive("avgCeilU", KnownBits::avgCeilU, APIntOps::avgCeilU);
- testBinaryOpExhaustive("avgCeilS", KnownBits::avgCeilS, APIntOps::avgCeilS,
- /*CheckOptimality=*/false);
+ testBinaryOpExhaustive("avgCeilS", KnownBits::avgCeilS, APIntOps::avgCeilS);
}
TEST(KnownBitsTest, UnaryExhaustive) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/110688
More information about the llvm-commits mailing list