[llvm] aff0570 - [ADT] Add implementations for avgFloor and avgCeil to APInt (#84431)

via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 14 03:00:13 PDT 2024


Author: Atousa Duprat
Date: 2024-03-14T10:00:08Z
New Revision: aff05708916107eec73ea5db363f625926f60730

URL: https://github.com/llvm/llvm-project/commit/aff05708916107eec73ea5db363f625926f60730
DIFF: https://github.com/llvm/llvm-project/commit/aff05708916107eec73ea5db363f625926f60730.diff

LOG: [ADT] Add implementations for avgFloor and avgCeil to APInt (#84431)

Supports both signed and unsigned expansions.

SelectionDAG now calls the APInt implementation of these functions.

Fixes #84211.

Added: 
    

Modified: 
    llvm/include/llvm/ADT/APInt.h
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/lib/Support/APInt.cpp
    llvm/unittests/ADT/APIntTest.cpp

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index bea3e28adf308f..b10e2107b794a3 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -2198,6 +2198,18 @@ inline const APInt abdu(const APInt &A, const APInt &B) {
   return A.uge(B) ? (A - B) : (B - A);
 }
 
+/// Signed floor((C1 + C2) / 2) without overflow; bit widths must match.
+APInt avgFloorS(const APInt &C1, const APInt &C2);
+
+/// Unsigned floor((C1 + C2) / 2) without overflow; bit widths must match.
+APInt avgFloorU(const APInt &C1, const APInt &C2);
+
+/// Signed ceil((C1 + C2) / 2) without overflow; bit widths must match.
+APInt avgCeilS(const APInt &C1, const APInt &C2);
+
+/// Unsigned ceil((C1 + C2) / 2) without overflow; bit widths must match.
+APInt avgCeilU(const APInt &C1, const APInt &C2);
+
 /// Compute GCD of two unsigned APInt values.
 ///
 /// This function returns the greatest common divisor of the two APInt values

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index b8c7d08da3e2da..19f9354c23da42 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6039,30 +6039,14 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
     APInt C2Ext = C2.zext(FullWidth);
     return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
   }
-  case ISD::AVGFLOORS: {
-    unsigned FullWidth = C1.getBitWidth() + 1;
-    APInt C1Ext = C1.sext(FullWidth);
-    APInt C2Ext = C2.sext(FullWidth);
-    return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1);
-  }
-  case ISD::AVGFLOORU: {
-    unsigned FullWidth = C1.getBitWidth() + 1;
-    APInt C1Ext = C1.zext(FullWidth);
-    APInt C2Ext = C2.zext(FullWidth);
-    return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1);
-  }
-  case ISD::AVGCEILS: {
-    unsigned FullWidth = C1.getBitWidth() + 1;
-    APInt C1Ext = C1.sext(FullWidth);
-    APInt C2Ext = C2.sext(FullWidth);
-    return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1);
-  }
-  case ISD::AVGCEILU: {
-    unsigned FullWidth = C1.getBitWidth() + 1;
-    APInt C1Ext = C1.zext(FullWidth);
-    APInt C2Ext = C2.zext(FullWidth);
-    return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1);
-  }
+  case ISD::AVGFLOORS:
+    return APIntOps::avgFloorS(C1, C2);
+  case ISD::AVGFLOORU:
+    return APIntOps::avgFloorU(C1, C2);
+  case ISD::AVGCEILS:
+    return APIntOps::avgCeilS(C1, C2);
+  case ISD::AVGCEILU:
+    return APIntOps::avgCeilU(C1, C2);
   case ISD::ABDS:
     return APIntOps::abds(C1, C2);
   case ISD::ABDU:

diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index e686b976523302..7053f3b87682f8 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -3094,3 +3094,39 @@ void llvm::LoadIntFromMemory(APInt &IntVal, const uint8_t *Src,
     memcpy(Dst + sizeof(uint64_t) - LoadBytes, Src, LoadBytes);
   }
 }
+
+APInt APIntOps::avgFloorS(const APInt &C1, const APInt &C2) {
+  // Signed floor((C1 + C2) / 2); the sum is formed in one extra bit.
+  assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
+  unsigned FullWidth = C1.getBitWidth() + 1; // room for the carry
+  APInt C1Ext = C1.sext(FullWidth); // sign-extend: operands are signed
+  APInt C2Ext = C2.sext(FullWidth);
+  return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1); // skip bit 0: /2
+}
+
+APInt APIntOps::avgFloorU(const APInt &C1, const APInt &C2) {
+  // Unsigned floor((C1 + C2) / 2); the sum is formed in one extra bit.
+  assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
+  unsigned FullWidth = C1.getBitWidth() + 1; // room for the carry
+  APInt C1Ext = C1.zext(FullWidth); // zero-extend: operands are unsigned
+  APInt C2Ext = C2.zext(FullWidth);
+  return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1); // skip bit 0: /2
+}
+
+APInt APIntOps::avgCeilS(const APInt &C1, const APInt &C2) {
+  // Signed ceil((C1 + C2) / 2); the sum is formed in one extra bit.
+  assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
+  unsigned FullWidth = C1.getBitWidth() + 1; // room for the carry
+  APInt C1Ext = C1.sext(FullWidth); // sign-extend: operands are signed
+  APInt C2Ext = C2.sext(FullWidth);
+  return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1); // +1 rounds up
+}
+
+APInt APIntOps::avgCeilU(const APInt &C1, const APInt &C2) {
+  // Unsigned ceil((C1 + C2) / 2); the sum is formed in one extra bit.
+  assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
+  unsigned FullWidth = C1.getBitWidth() + 1; // room for the carry
+  APInt C1Ext = C1.zext(FullWidth); // zero-extend: operands are unsigned
+  APInt C2Ext = C2.zext(FullWidth);
+  return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1); // +1 rounds up
+}

diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 11237d2e1602dc..400313a6234677 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -14,6 +14,7 @@
 #include "llvm/Support/Alignment.h"
 #include "gtest/gtest.h"
 #include <array>
+#include <climits>
 #include <optional>
 
 using namespace llvm;
@@ -2911,6 +2912,91 @@ TEST(APIntTest, RoundingSDiv) {
   }
 }
 
+TEST(APIntTest, Average) { // avgFloor{U,S} and avgCeil{U,S} on 32-bit values
+  APInt A0(32, 0);
+  APInt A2(32, 2); // divisor for the RoundingUDiv/RoundingSDiv references
+  APInt A100(32, 100);
+  APInt A101(32, 101);
+  APInt A200(32, 200, false);
+  APInt ApUMax(32, UINT_MAX, false); // upper bound for the unsigned cases
+
+  EXPECT_EQ(APInt(32, 150), APIntOps::avgFloorU(A100, A200)); // exact midpoint
+  EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A200, A2, APInt::Rounding::DOWN),
+            APIntOps::avgFloorU(A100, A200));
+  EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A200, A2, APInt::Rounding::UP),
+            APIntOps::avgCeilU(A100, A200));
+  EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A101, A2, APInt::Rounding::DOWN),
+            APIntOps::avgFloorU(A100, A101));
+  EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A101, A2, APInt::Rounding::UP),
+            APIntOps::avgCeilU(A100, A101));
+  EXPECT_EQ(A0, APIntOps::avgFloorU(A0, A0));
+  EXPECT_EQ(A0, APIntOps::avgCeilU(A0, A0));
+  EXPECT_EQ(ApUMax, APIntOps::avgFloorU(ApUMax, ApUMax)); // no wraparound
+  EXPECT_EQ(ApUMax, APIntOps::avgCeilU(ApUMax, ApUMax));
+  EXPECT_EQ(APIntOps::RoundingUDiv(ApUMax, A2, APInt::Rounding::DOWN),
+            APIntOps::avgFloorU(A0, ApUMax));
+  EXPECT_EQ(APIntOps::RoundingUDiv(ApUMax, A2, APInt::Rounding::UP),
+            APIntOps::avgCeilU(A0, ApUMax));
+
+  APInt Ap100(32, +100); // operands for the signed cases start here
+  APInt Ap101(32, +101);
+  APInt Ap200(32, +200);
+  APInt Am1(32, -1);
+  APInt Am100(32, -100);
+  APInt Am101(32, -101);
+  APInt Am200(32, -200);
+  APInt AmSMin(32, INT_MIN); // lower bound for the signed cases
+  APInt ApSMax(32, INT_MAX); // upper bound for the signed cases
+
+  EXPECT_EQ(APInt(32, +150), APIntOps::avgFloorS(Ap100, Ap200));
+  EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap200, A2, APInt::Rounding::DOWN),
+            APIntOps::avgFloorS(Ap100, Ap200));
+  EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap200, A2, APInt::Rounding::UP),
+            APIntOps::avgCeilS(Ap100, Ap200));
+
+  EXPECT_EQ(APInt(32, -150), APIntOps::avgFloorS(Am100, Am200));
+  EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am200, A2, APInt::Rounding::DOWN),
+            APIntOps::avgFloorS(Am100, Am200));
+  EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am200, A2, APInt::Rounding::UP),
+            APIntOps::avgCeilS(Am100, Am200));
+
+  EXPECT_EQ(APInt(32, +100), APIntOps::avgFloorS(Ap100, Ap101)); // 100.5 down
+  EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap101, A2, APInt::Rounding::DOWN),
+            APIntOps::avgFloorS(Ap100, Ap101));
+  EXPECT_EQ(APInt(32, +101), APIntOps::avgCeilS(Ap100, Ap101)); // 100.5 up
+  EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap101, A2, APInt::Rounding::UP),
+            APIntOps::avgCeilS(Ap100, Ap101));
+
+  EXPECT_EQ(APInt(32, -101), APIntOps::avgFloorS(Am100, Am101)); // toward -inf
+  EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am101, A2, APInt::Rounding::DOWN),
+            APIntOps::avgFloorS(Am100, Am101));
+  EXPECT_EQ(APInt(32, -100), APIntOps::avgCeilS(Am100, Am101)); // toward +inf
+  EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am101, A2, APInt::Rounding::UP),
+            APIntOps::avgCeilS(Am100, Am101));
+
+  EXPECT_EQ(AmSMin, APIntOps::avgFloorS(AmSMin, AmSMin)); // no wraparound
+  EXPECT_EQ(AmSMin, APIntOps::avgCeilS(AmSMin, AmSMin));
+
+  EXPECT_EQ(APIntOps::RoundingSDiv(AmSMin, A2, APInt::Rounding::DOWN),
+            APIntOps::avgFloorS(A0, AmSMin));
+  EXPECT_EQ(APIntOps::RoundingSDiv(AmSMin, A2, APInt::Rounding::UP),
+            APIntOps::avgCeilS(A0, AmSMin));
+
+  EXPECT_EQ(A0, APIntOps::avgFloorS(A0, A0));
+  EXPECT_EQ(A0, APIntOps::avgCeilS(A0, A0));
+
+  EXPECT_EQ(Am1, APIntOps::avgFloorS(AmSMin, ApSMax)); // floor(-0.5) == -1
+  EXPECT_EQ(A0, APIntOps::avgCeilS(AmSMin, ApSMax)); // ceil(-0.5) == 0
+
+  EXPECT_EQ(APIntOps::RoundingSDiv(ApSMax, A2, APInt::Rounding::DOWN),
+            APIntOps::avgFloorS(A0, ApSMax));
+  EXPECT_EQ(APIntOps::RoundingSDiv(ApSMax, A2, APInt::Rounding::UP),
+            APIntOps::avgCeilS(A0, ApSMax));
+
+  EXPECT_EQ(ApSMax, APIntOps::avgFloorS(ApSMax, ApSMax)); // no wraparound
+  EXPECT_EQ(ApSMax, APIntOps::avgCeilS(ApSMax, ApSMax));
+}
+
 TEST(APIntTest, umul_ov) {
   const std::pair<uint64_t, uint64_t> Overflows[] = {
       {0x8000000000000000, 2},


        


More information about the llvm-commits mailing list