[llvm] [KnownBits] Make `avg{Ceil,Floor}S` optimal (PR #110688)

via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 1 08:28:41 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-support

Author: Jay Foad (jayfoad)

<details>
<summary>Changes</summary>

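Rewrite the signed functions in terms of the unsigned ones, which are
already optimal. Flipping the sign bit maps the signed range onto the
unsigned range while preserving order, so avg{Floor,Ceil}S(A, B) equals
flipSignBit(avg{Floor,Ceil}U(flipSignBit(A), flipSignBit(B))). Because this
transform is a bijection on KnownBits, the signed versions inherit the
optimality of the unsigned ones. The sign-flip lambda previously local to
smax is hoisted into a shared KnownBits::flipSignBit helper.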


---
Full diff: https://github.com/llvm/llvm-project/pull/110688.diff


3 Files Affected:

- (modified) llvm/include/llvm/Support/KnownBits.h (+3) 
- (modified) llvm/lib/Support/KnownBits.cpp (+17-22) 
- (modified) llvm/unittests/Support/KnownBitsTest.cpp (+3-4) 


``````````diff
diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index e4ec202f36aae0..a4b554fa2a0b72 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -29,6 +29,9 @@ struct KnownBits {
   KnownBits(APInt Zero, APInt One)
       : Zero(std::move(Zero)), One(std::move(One)) {}
 
+  // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF]
+  static KnownBits flipSignBit(const KnownBits &Val);
+
 public:
   // Default construct Zero and One.
   KnownBits() = default;
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 6863c5c0af5dca..a7801aa950cad3 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -18,6 +18,15 @@
 
 using namespace llvm;
 
+KnownBits KnownBits::flipSignBit(const KnownBits &Val) {
+  unsigned SignBitPosition = Val.getBitWidth() - 1;
+  APInt Zero = Val.Zero;
+  APInt One = Val.One;
+  Zero.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
+  One.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
+  return KnownBits(Zero, One);
+}
+
 static KnownBits computeForAddCarry(const KnownBits &LHS, const KnownBits &RHS,
                                     bool CarryZero, bool CarryOne) {
 
@@ -200,16 +209,7 @@ KnownBits KnownBits::umin(const KnownBits &LHS, const KnownBits &RHS) {
 }
 
 KnownBits KnownBits::smax(const KnownBits &LHS, const KnownBits &RHS) {
-  // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF]
-  auto Flip = [](const KnownBits &Val) {
-    unsigned SignBitPosition = Val.getBitWidth() - 1;
-    APInt Zero = Val.Zero;
-    APInt One = Val.One;
-    Zero.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
-    One.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
-    return KnownBits(Zero, One);
-  };
-  return Flip(umax(Flip(LHS), Flip(RHS)));
+  return flipSignBit(umax(flipSignBit(LHS), flipSignBit(RHS)));
 }
 
 KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
@@ -763,11 +763,10 @@ KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) {
   return computeForSatAddSub(/*Add*/ false, /*Signed*/ false, LHS, RHS);
 }
 
-static KnownBits avgCompute(KnownBits LHS, KnownBits RHS, bool IsCeil,
-                            bool IsSigned) {
+static KnownBits avgComputeU(KnownBits LHS, KnownBits RHS, bool IsCeil) {
   unsigned BitWidth = LHS.getBitWidth();
-  LHS = IsSigned ? LHS.sext(BitWidth + 1) : LHS.zext(BitWidth + 1);
-  RHS = IsSigned ? RHS.sext(BitWidth + 1) : RHS.zext(BitWidth + 1);
+  LHS = LHS.zext(BitWidth + 1);
+  RHS = RHS.zext(BitWidth + 1);
   LHS =
       computeForAddCarry(LHS, RHS, /*CarryZero*/ !IsCeil, /*CarryOne*/ IsCeil);
   LHS = LHS.extractBits(BitWidth, 1);
@@ -775,23 +774,19 @@ static KnownBits avgCompute(KnownBits LHS, KnownBits RHS, bool IsCeil,
 }
 
 KnownBits KnownBits::avgFloorS(const KnownBits &LHS, const KnownBits &RHS) {
-  return avgCompute(LHS, RHS, /* IsCeil */ false,
-                    /* IsSigned */ true);
+  return flipSignBit(avgFloorU(flipSignBit(LHS), flipSignBit(RHS)));
 }
 
 KnownBits KnownBits::avgFloorU(const KnownBits &LHS, const KnownBits &RHS) {
-  return avgCompute(LHS, RHS, /* IsCeil */ false,
-                    /* IsSigned */ false);
+  return avgComputeU(LHS, RHS, /* IsCeil */ false);
 }
 
 KnownBits KnownBits::avgCeilS(const KnownBits &LHS, const KnownBits &RHS) {
-  return avgCompute(LHS, RHS, /* IsCeil */ true,
-                    /* IsSigned */ true);
+  return flipSignBit(avgCeilU(flipSignBit(LHS), flipSignBit(RHS)));
 }
 
 KnownBits KnownBits::avgCeilU(const KnownBits &LHS, const KnownBits &RHS) {
-  return avgCompute(LHS, RHS, /* IsCeil */ true,
-                    /* IsSigned */ false);
+  return avgComputeU(LHS, RHS, /* IsCeil */ true);
 }
 
 KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index b701757aed5eb9..551c1a8107494b 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -521,16 +521,15 @@ TEST(KnownBitsTest, BinaryExhaustive) {
       [](const APInt &N1, const APInt &N2) { return APIntOps::mulhu(N1, N2); },
       /*CheckOptimality=*/false);
 
-  testBinaryOpExhaustive("avgFloorS", KnownBits::avgFloorS, APIntOps::avgFloorS,
-                         /*CheckOptimality=*/false);
+  testBinaryOpExhaustive("avgFloorS", KnownBits::avgFloorS,
+                         APIntOps::avgFloorS);
 
   testBinaryOpExhaustive("avgFloorU", KnownBits::avgFloorU,
                          APIntOps::avgFloorU);
 
   testBinaryOpExhaustive("avgCeilU", KnownBits::avgCeilU, APIntOps::avgCeilU);
 
-  testBinaryOpExhaustive("avgCeilS", KnownBits::avgCeilS, APIntOps::avgCeilS,
-                         /*CheckOptimality=*/false);
+  testBinaryOpExhaustive("avgCeilS", KnownBits::avgCeilS, APIntOps::avgCeilS);
 }
 
 TEST(KnownBitsTest, UnaryExhaustive) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/110688


More information about the llvm-commits mailing list