[llvm] 8ae1cb2 - add power function to APInt (#122788)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 17 06:40:34 PST 2025


Author: Iman Hosseini
Date: 2025-01-17T14:40:31Z
New Revision: 8ae1cb2bcb55293cce31bb75c38d6b4e8a13cc23

URL: https://github.com/llvm/llvm-project/commit/8ae1cb2bcb55293cce31bb75c38d6b4e8a13cc23
DIFF: https://github.com/llvm/llvm-project/commit/8ae1cb2bcb55293cce31bb75c38d6b4e8a13cc23.diff

LOG: add power function to APInt (#122788)

I am trying to compute the power function for APFloat and APInt in order to
constant-fold vector reductions: https://github.com/llvm/llvm-project/pull/122450
I need this utility to fold N `mul`s into a single power computation.

---------

Co-authored-by: ImanHosseini <imanhosseini.17 at gmail.com>
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>

Added: 
    

Modified: 
    llvm/include/llvm/ADT/APInt.h
    llvm/lib/Support/APInt.cpp
    llvm/unittests/ADT/APIntTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index 225390f1af60b3..02d58d8c3d31ce 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -2263,6 +2263,10 @@ APInt mulhs(const APInt &C1, const APInt &C2);
 /// Returns the high N bits of the multiplication result.
 APInt mulhu(const APInt &C1, const APInt &C2);
 
+/// Compute X^N for N>=0. The result has X's bit width and wraps on overflow.
+/// 0^0 is supported and returns 1.
+APInt pow(const APInt &X, int64_t N);
+
 /// Compute GCD of two unsigned APInt values.
 ///
 /// This function returns the greatest common divisor of the two APInt values

diff  --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index ea8295f95c751a..38cf485733a937 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -3108,3 +3108,21 @@ APInt APIntOps::mulhu(const APInt &C1, const APInt &C2) {
   APInt C2Ext = C2.zext(FullWidth);
   return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
 }
+
+APInt APIntOps::pow(const APInt &X, int64_t N) {
+  assert(N >= 0 && "negative exponents not supported.");
+  APInt Acc = APInt(X.getBitWidth(), 1); // X^0 == 1, including 0^0.
+  if (N == 0)
+    return Acc;
+  APInt Base = X;
+  int64_t RemainingExponent = N;
+  while (RemainingExponent > 0) { // Exponentiation by squaring.
+    while (RemainingExponent % 2 == 0) { // Even exponent: square the base.
+      Base *= Base;
+      RemainingExponent /= 2;
+    }
+    --RemainingExponent; // Odd exponent: peel off one factor of Base.
+    Acc *= Base;
+  }
+  return Acc;
+}

diff  --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 4d5553fcbd1e3f..b14366eac2185a 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -29,6 +29,73 @@ TEST(APIntTest, ValueInit) {
   EXPECT_TRUE(!Zero.sext(64));
 }
 
+// Zero raised to a positive power must stay zero (0^5 == 0).
+TEST(APIntTest, PowZeroTo5) {
+  const APInt Base = APInt::getZero(32);
+  EXPECT_TRUE(!Base);
+  const APInt Result = APIntOps::pow(Base, 5);
+  EXPECT_TRUE(!Result);
+}
+
+// One raised to a power-of-two exponent stays one (1^16 == 1).
+TEST(APIntTest, PowOneTo16) {
+  const APInt Unit(32, 1);
+  const APInt Result = APIntOps::pow(Unit, 16);
+  EXPECT_EQ(Unit, Result);
+}
+
+// Test that 2^10 == 1024
+TEST(APIntTest, PowerTwoTo10) {
+  APInt Two(32, 2);
+  APInt TwoTo10 = APIntOps::pow(Two, 10);
+  APInt V_1024(32, 1024);
+  EXPECT_EQ(TwoTo10, V_1024);
+}
+
+// Small odd-exponent case without overflow: 3^3 == 27.
+TEST(APIntTest, PowerThreeTo3) {
+  const APInt Base(32, 3);
+  const APInt Expected(32, 27);
+  const APInt Cubed = APIntOps::pow(Base, 3);
+  EXPECT_EQ(Cubed, Expected);
+}
+
+// (2^31 - 1)^2 wraps to 1 in 32 bits, so SignedMaxValue^3 == SignedMaxValue.
+TEST(APIntTest, PowerSignedMaxValue) {
+  const APInt Base = APInt::getSignedMaxValue(32);
+  const APInt Cubed = APIntOps::pow(Base, 3);
+  EXPECT_EQ(Cubed, Base);
+}
+
+// MaxValue is all ones (-1 mod 2^32), so an odd power leaves it unchanged.
+TEST(APIntTest, PowerMaxValue) {
+  const APInt AllOnes = APInt::getMaxValue(32);
+  const APInt Cubed = APIntOps::pow(AllOnes, 3);
+  EXPECT_EQ(AllOnes, Cubed);
+}
+
+// (2^31)^3 == 2^93 has no set bits below bit 32, so the result wraps to 0.
+TEST(APIntTest, PowerSignedMinValueTo3) {
+  const APInt Base = APInt::getSignedMinValue(32);
+  const APInt Cubed = APIntOps::pow(Base, 3);
+  EXPECT_TRUE(Cubed.isZero());
+}
+
+// An exponent of one must hand back the base unchanged.
+TEST(APIntTest, PowerSignedMinValueTo1) {
+  const APInt Base = APInt::getSignedMinValue(32);
+  const APInt Result = APIntOps::pow(Base, 1);
+  EXPECT_EQ(Base, Result);
+}
+
+// Test that 0^0 == 1
+TEST(APIntTest, ZeroToZero) {
+  APInt Zero = APInt::getZero(32);
+  APInt One(32, 1);
+  APInt ZeroToZero = APIntOps::pow(Zero, 0);
+  EXPECT_EQ(ZeroToZero, One);
+}
+
 // Test that APInt shift left works when bitwidth > 64 and shiftamt == 0
 TEST(APIntTest, ShiftLeftByZero) {
   APInt One = APInt::getZero(65) + 1;


        


More information about the llvm-commits mailing list