[llvm] 5cabf15 - [KnownBits] Make `avg{Ceil,Floor}S` optimal (#110688)

via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 1 11:34:53 PDT 2024


Author: Jay Foad
Date: 2024-10-01T19:34:50+01:00
New Revision: 5cabf1505d2180083ae3cf34abd79924cbcdbdbf

URL: https://github.com/llvm/llvm-project/commit/5cabf1505d2180083ae3cf34abd79924cbcdbdbf
DIFF: https://github.com/llvm/llvm-project/commit/5cabf1505d2180083ae3cf34abd79924cbcdbdbf.diff

LOG: [KnownBits] Make `avg{Ceil,Floor}S` optimal (#110688)

Rewrite the signed functions in terms of the unsigned ones, which are
already optimal, by flipping the sign bit of the operands and of the
result.

Added: 
    

Modified: 
    llvm/include/llvm/Support/KnownBits.h
    llvm/lib/Support/KnownBits.cpp
    llvm/unittests/Support/KnownBitsTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index e4ec202f36aae0..a4b554fa2a0b72 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -29,6 +29,9 @@ struct KnownBits {
   KnownBits(APInt Zero, APInt One)
       : Zero(std::move(Zero)), One(std::move(One)) {}
 
+  // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF]
+  static KnownBits flipSignBit(const KnownBits &Val);
+
 public:
   // Default construct Zero and One.
   KnownBits() = default;

diff  --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 6863c5c0af5dca..89668af378070b 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -18,6 +18,15 @@
 
 using namespace llvm;
 
+KnownBits KnownBits::flipSignBit(const KnownBits &Val) {
+  unsigned SignBitPosition = Val.getBitWidth() - 1;
+  APInt Zero = Val.Zero;
+  APInt One = Val.One;
+  Zero.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
+  One.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
+  return KnownBits(Zero, One);
+}
+
 static KnownBits computeForAddCarry(const KnownBits &LHS, const KnownBits &RHS,
                                     bool CarryZero, bool CarryOne) {
 
@@ -200,16 +209,7 @@ KnownBits KnownBits::umin(const KnownBits &LHS, const KnownBits &RHS) {
 }
 
 KnownBits KnownBits::smax(const KnownBits &LHS, const KnownBits &RHS) {
-  // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF]
-  auto Flip = [](const KnownBits &Val) {
-    unsigned SignBitPosition = Val.getBitWidth() - 1;
-    APInt Zero = Val.Zero;
-    APInt One = Val.One;
-    Zero.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
-    One.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
-    return KnownBits(Zero, One);
-  };
-  return Flip(umax(Flip(LHS), Flip(RHS)));
+  return flipSignBit(umax(flipSignBit(LHS), flipSignBit(RHS)));
 }
 
 KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
@@ -763,11 +763,10 @@ KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) {
   return computeForSatAddSub(/*Add*/ false, /*Signed*/ false, LHS, RHS);
 }
 
-static KnownBits avgCompute(KnownBits LHS, KnownBits RHS, bool IsCeil,
-                            bool IsSigned) {
+static KnownBits avgComputeU(KnownBits LHS, KnownBits RHS, bool IsCeil) {
   unsigned BitWidth = LHS.getBitWidth();
-  LHS = IsSigned ? LHS.sext(BitWidth + 1) : LHS.zext(BitWidth + 1);
-  RHS = IsSigned ? RHS.sext(BitWidth + 1) : RHS.zext(BitWidth + 1);
+  LHS = LHS.zext(BitWidth + 1);
+  RHS = RHS.zext(BitWidth + 1);
   LHS =
       computeForAddCarry(LHS, RHS, /*CarryZero*/ !IsCeil, /*CarryOne*/ IsCeil);
   LHS = LHS.extractBits(BitWidth, 1);
@@ -775,23 +774,19 @@ static KnownBits avgCompute(KnownBits LHS, KnownBits RHS, bool IsCeil,
 }
 
 KnownBits KnownBits::avgFloorS(const KnownBits &LHS, const KnownBits &RHS) {
-  return avgCompute(LHS, RHS, /* IsCeil */ false,
-                    /* IsSigned */ true);
+  return flipSignBit(avgFloorU(flipSignBit(LHS), flipSignBit(RHS)));
 }
 
 KnownBits KnownBits::avgFloorU(const KnownBits &LHS, const KnownBits &RHS) {
-  return avgCompute(LHS, RHS, /* IsCeil */ false,
-                    /* IsSigned */ false);
+  return avgComputeU(LHS, RHS, /*IsCeil=*/false);
 }
 
 KnownBits KnownBits::avgCeilS(const KnownBits &LHS, const KnownBits &RHS) {
-  return avgCompute(LHS, RHS, /* IsCeil */ true,
-                    /* IsSigned */ true);
+  return flipSignBit(avgCeilU(flipSignBit(LHS), flipSignBit(RHS)));
 }
 
 KnownBits KnownBits::avgCeilU(const KnownBits &LHS, const KnownBits &RHS) {
-  return avgCompute(LHS, RHS, /* IsCeil */ true,
-                    /* IsSigned */ false);
+  return avgComputeU(LHS, RHS, /*IsCeil=*/true);
 }
 
 KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,

diff  --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index b701757aed5eb9..551c1a8107494b 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -521,16 +521,15 @@ TEST(KnownBitsTest, BinaryExhaustive) {
       [](const APInt &N1, const APInt &N2) { return APIntOps::mulhu(N1, N2); },
       /*CheckOptimality=*/false);
 
-  testBinaryOpExhaustive("avgFloorS", KnownBits::avgFloorS, APIntOps::avgFloorS,
-                         /*CheckOptimality=*/false);
+  testBinaryOpExhaustive("avgFloorS", KnownBits::avgFloorS,
+                         APIntOps::avgFloorS);
 
   testBinaryOpExhaustive("avgFloorU", KnownBits::avgFloorU,
                          APIntOps::avgFloorU);
 
   testBinaryOpExhaustive("avgCeilU", KnownBits::avgCeilU, APIntOps::avgCeilU);
 
-  testBinaryOpExhaustive("avgCeilS", KnownBits::avgCeilS, APIntOps::avgCeilS,
-                         /*CheckOptimality=*/false);
+  testBinaryOpExhaustive("avgCeilS", KnownBits::avgCeilS, APIntOps::avgCeilS);
 }
 
 TEST(KnownBitsTest, UnaryExhaustive) {


        


More information about the llvm-commits mailing list