[llvm] [ADT] Add implementations for mulhs and mulhu to APInt (PR #84609)

via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 8 22:22:24 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-adt

Author: Shourya Goel (Sh0g0-1758)

<details>
<summary>Changes</summary>

Fixes: #<!-- -->84207

---
Full diff: https://github.com/llvm/llvm-project/pull/84609.diff


5 Files Affected:

- (modified) llvm/include/llvm/ADT/APInt.h (+6) 
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+4-10) 
- (modified) llvm/lib/Support/APInt.cpp (+16) 
- (modified) llvm/unittests/ADT/APIntTest.cpp (+20) 
- (modified) llvm/unittests/Support/DivisionByConstantTest.cpp (+4-16) 


``````````diff
diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index 1fc3c7b2236a17..711eb3a9c3fc6d 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -2193,6 +2193,12 @@ inline const APInt absdiff(const APInt &A, const APInt &B) {
   return A.uge(B) ? (A - B) : (B - A);
 }
 
+/// Compute the upper half of the double-width unsigned product of two APInts
+APInt mulhu(const APInt &C1, const APInt &C2);
+
+/// Compute the upper half of the double-width signed product of two APInts
+APInt mulhs(const APInt &C1, const APInt &C2);
+
 /// Compute GCD of two unsigned APInt values.
 ///
 /// This function returns the greatest common divisor of the two APInt values
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 50f53bbb04b62d..4f533b7d055129 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6015,17 +6015,11 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
     if (!C2.getBoolValue())
       break;
     return C1.srem(C2);
-  case ISD::MULHS: {
-    unsigned FullWidth = C1.getBitWidth() * 2;
-    APInt C1Ext = C1.sext(FullWidth);
-    APInt C2Ext = C2.sext(FullWidth);
-    return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
-  }
   case ISD::MULHU: {
-    unsigned FullWidth = C1.getBitWidth() * 2;
-    APInt C1Ext = C1.zext(FullWidth);
-    APInt C2Ext = C2.zext(FullWidth);
-    return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
+    return APIntOps::mulhu(C1, C2);
+  }
+  case ISD::MULHS: {
+    return APIntOps::mulhs(C1, C2);
   }
   case ISD::AVGFLOORS: {
     unsigned FullWidth = C1.getBitWidth() + 1;
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index e686b976523302..9ce10ada67e9e1 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -3067,6 +3067,22 @@ void llvm::StoreIntToMemory(const APInt &IntVal, uint8_t *Dst,
   }
 }
 
+APInt APIntOps::mulhu(const APInt &C1, const APInt &C2) {
+  // Zero-extend both to twice the width of C1, multiply, return the upper half.
+  unsigned FullWidth = C1.getBitWidth() * 2;
+  APInt C1Ext = C1.zext(FullWidth);
+  APInt C2Ext = C2.zext(FullWidth);
+  return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
+}
+
+APInt APIntOps::mulhs(const APInt &C1, const APInt &C2) {
+  // Sign-extend both to twice the width of C1, multiply, return the upper half.
+  unsigned FullWidth = C1.getBitWidth() * 2;
+  APInt C1Ext = C1.sext(FullWidth);
+  APInt C2Ext = C2.sext(FullWidth);
+  return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
+}
+
 /// LoadIntFromMemory - Loads the integer stored in the LoadBytes bytes starting
 /// from Src into IntVal, which is assumed to be wide enough and to hold zero.
 void llvm::LoadIntFromMemory(APInt &IntVal, const uint8_t *Src,
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 24324822356bf6..9995aeaff92871 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -2805,6 +2805,26 @@ TEST(APIntTest, multiply) {
   EXPECT_EQ(64U, i96.countr_zero());
 }
 
+TEST(APIntTest, Hmultiply) {
+  APInt i1048576(32, 1048576);
+  // 2^20 * 2^20 = 2^40, so the upper 32 bits of the product are 2^8 = 256.
+  EXPECT_EQ(APInt(32, 256), APIntOps::mulhu(i1048576, i1048576));
+
+  APInt i16777216(32, 16777216);
+  APInt i32768(32, 32768);
+  // 2^24 * 2^15 = 2^39, so the upper 32 bits are 2^7 = 128 in either order.
+  EXPECT_EQ(APInt(32, 128), APIntOps::mulhu(i16777216, i32768));
+  EXPECT_EQ(APInt(32, 128), APIntOps::mulhu(i32768, i16777216));
+
+  APInt i2097152(32, -2097152); // Despite its name, holds -2^21 (negative).
+  APInt i524288(32, 524288);
+  // (-2^21) * (-2^21) = 2^42, so the upper 32 bits are 2^10 = 1024.
+  EXPECT_EQ(APInt(32, 1024), APIntOps::mulhs(i2097152, i2097152));
+  // (-2^21) * 2^19 = -2^40, so the upper 32 bits are -2^8 = -256 either way.
+  EXPECT_EQ(APInt(32, -256), APIntOps::mulhs(i2097152, i524288));
+  EXPECT_EQ(APInt(32, -256), APIntOps::mulhs(i524288, i2097152));
+}
+
 TEST(APIntTest, RoundingUDiv) {
   for (uint64_t Ai = 1; Ai <= 255; Ai++) {
     APInt A(8, Ai);
diff --git a/llvm/unittests/Support/DivisionByConstantTest.cpp b/llvm/unittests/Support/DivisionByConstantTest.cpp
index 2b17f98bb75b2f..8e0c78fe85654a 100644
--- a/llvm/unittests/Support/DivisionByConstantTest.cpp
+++ b/llvm/unittests/Support/DivisionByConstantTest.cpp
@@ -21,12 +21,6 @@ template <typename Fn> static void EnumerateAPInts(unsigned Bits, Fn TestFn) {
   } while (++N != 0);
 }
 
-APInt MULHS(APInt X, APInt Y) {
-  unsigned Bits = X.getBitWidth();
-  unsigned WideBits = 2 * Bits;
-  return (X.sext(WideBits) * Y.sext(WideBits)).lshr(Bits).trunc(Bits);
-}
-
 APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor,
                              SignedDivisionByConstantInfo Magics) {
   unsigned Bits = Numerator.getBitWidth();
@@ -48,7 +42,7 @@ APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor,
   }
 
   // Multiply the numerator by the magic value.
-  APInt Q = MULHS(Numerator, Magics.Magic);
+  APInt Q = APIntOps::mulhs(Numerator, Magics.Magic);
 
   // (Optionally) Add/subtract the numerator using Factor.
   Factor = Numerator * Factor;
@@ -89,12 +83,6 @@ TEST(SignedDivisionByConstantTest, Test) {
   }
 }
 
-APInt MULHU(APInt X, APInt Y) {
-  unsigned Bits = X.getBitWidth();
-  unsigned WideBits = 2 * Bits;
-  return (X.zext(WideBits) * Y.zext(WideBits)).lshr(Bits).trunc(Bits);
-}
-
 APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor,
                                bool LZOptimization,
                                bool AllowEvenDivisorOptimization, bool ForceNPQ,
@@ -129,16 +117,16 @@ APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor,
   APInt Q = Numerator.lshr(PreShift);
 
   // Multiply the numerator by the magic value.
-  Q = MULHU(Q, Magics.Magic);
+  Q = APIntOps::mulhu(Q, Magics.Magic);
 
   if (UseNPQ || ForceNPQ) {
     APInt NPQ = Numerator - Q;
 
     // For vectors we might have a mix of non-NPQ/NPQ paths, so use
-    // MULHU to act as a SRL-by-1 for NPQ, else multiply by zero.
+    // mulhu to act as a SRL-by-1 for NPQ, else multiply by zero.
     APInt NPQ_Scalar = NPQ.lshr(1);
     (void)NPQ_Scalar;
-    NPQ = MULHU(NPQ, NPQFactor);
+    NPQ = APIntOps::mulhu(NPQ, NPQFactor);
     assert(!UseNPQ || NPQ == NPQ_Scalar);
 
     Q = NPQ + Q;

``````````

</details>


https://github.com/llvm/llvm-project/pull/84609


More information about the llvm-commits mailing list