[llvm] [ADT] Add signed and unsigned mulHi and mulLo to APInt (PR #84719)
Atousa Duprat via llvm-commits
llvm-commits at lists.llvm.org
Tue Mar 12 00:52:05 PDT 2024
https://github.com/Atousa updated https://github.com/llvm/llvm-project/pull/84719
>From ea764b7e90d2bf854f87ee922bc0306aabeebcc0 Mon Sep 17 00:00:00 2001
From: Atousa Duprat <atousa.p at gmail.com>
Date: Sun, 10 Mar 2024 21:53:04 -0700
Subject: [PATCH] [ADT] Add signed and unsigned mulHi and mulLo to APInt
This addresses issue #84207
---
llvm/include/llvm/ADT/APInt.h | 6 +++
.../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 16 ++----
llvm/lib/Support/APInt.cpp | 16 ++++++
llvm/unittests/ADT/APIntTest.cpp | 53 +++++++++++++++++++
.../Support/DivisionByConstantTest.cpp | 18 ++-----
llvm/unittests/Support/KnownBitsTest.cpp | 10 +---
6 files changed, 84 insertions(+), 35 deletions(-)
diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index 1fc3c7b2236a17..0f075fa7140d7e 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -2193,6 +2193,12 @@ inline const APInt absdiff(const APInt &A, const APInt &B) {
return A.uge(B) ? (A - B) : (B - A);
}
+/// Return the upper half of the double-width signed product of C1 and C2.
+APInt mulhs(const APInt &C1, const APInt &C2);
+
+/// Return the upper half of the double-width unsigned product of C1 and C2.
+APInt mulhu(const APInt &C1, const APInt &C2);
+
/// Compute GCD of two unsigned APInt values.
///
/// This function returns the greatest common divisor of the two APInt values
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index f7ace79e8c51d4..913e14ee4bfac4 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6009,18 +6009,10 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
if (!C2.getBoolValue())
break;
return C1.srem(C2);
- case ISD::MULHS: {
- unsigned FullWidth = C1.getBitWidth() * 2;
- APInt C1Ext = C1.sext(FullWidth);
- APInt C2Ext = C2.sext(FullWidth);
- return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
- }
- case ISD::MULHU: {
- unsigned FullWidth = C1.getBitWidth() * 2;
- APInt C1Ext = C1.zext(FullWidth);
- APInt C2Ext = C2.zext(FullWidth);
- return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
- }
+ case ISD::MULHS:
+ return APIntOps::mulhs(C1, C2);
+ case ISD::MULHU:
+ return APIntOps::mulhu(C1, C2);
case ISD::AVGFLOORS: {
unsigned FullWidth = C1.getBitWidth() + 1;
APInt C1Ext = C1.sext(FullWidth);
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index e686b976523302..ae8d1a2a13911e 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -3094,3 +3094,19 @@ void llvm::LoadIntFromMemory(APInt &IntVal, const uint8_t *Src,
memcpy(Dst + sizeof(uint64_t) - LoadBytes, Src, LoadBytes);
}
}
+
+APInt APIntOps::mulhs(const APInt &C1, const APInt &C2) {
+ assert(C1.getBitWidth() == C2.getBitWidth() && "Operand mismatch");
+ // Sign-extend both operands to twice the width so the product cannot
+ // overflow, then return the upper half of that product.
+ unsigned Bits = C1.getBitWidth();
+ return (C1.sext(2 * Bits) * C2.sext(2 * Bits)).extractBits(Bits, Bits);
+}
+
+APInt APIntOps::mulhu(const APInt &C1, const APInt &C2) {
+ assert(C1.getBitWidth() == C2.getBitWidth() && "Operand mismatch");
+ // Zero-extend both operands to twice the width so the product cannot
+ // overflow, then return the upper half of that product.
+ unsigned Bits = C1.getBitWidth();
+ return (C1.zext(2 * Bits) * C2.zext(2 * Bits)).extractBits(Bits, Bits);
+}
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 24324822356bf6..b6eedb31a277b3 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -10,6 +10,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Alignment.h"
#include "gtest/gtest.h"
@@ -2805,6 +2806,58 @@ TEST(APIntTest, multiply) {
EXPECT_EQ(64U, i96.countr_zero());
}
+TEST(APIntOpsTest, Mulh) {
+ // Spot-checks mulhu/mulhs on fixed 32-, 64- and 128-bit operands.
+ // Unsigned high multiply (mulhu)
+
+ // 32-bit operands
+ APInt i32a(32, 0x0001'E235);
+ APInt i32b(32, 0xF623'55AD);
+ EXPECT_EQ(0x0001'CFA1, APIntOps::mulhu(i32a, i32b));
+
+ // 64-bit operands
+ APInt i64a(64, 0x1234'5678'90AB'CDEF);
+ APInt i64b(64, 0xFEDC'BA09'8765'4321);
+ EXPECT_EQ(0x121F'A000'A372'3A57, APIntOps::mulhu(i64a, i64b));
+
+ // 128-bit operands (built from hex strings)
+ APInt i128a(128, "1234567890ABCDEF1234567890ABCDEF", 16);
+ APInt i128b(128, "FEDCBA0987654321FEDCBA0987654321", 16);
+ APInt i128Res = APIntOps::mulhu(i128a, i128b);
+ EXPECT_EQ(APInt(128, "121FA000A3723A57E68984312C3A8D7E", 16), i128Res);
+
+ // Signed high multiply (mulhs)
+
+ // 32-bit operands: positive*positive, positive*negative, negative*negative
+ APInt i32c(32, 0x1234'5678); // positive
+ APInt i32d(32, 0x10AB'CDEF); // positive
+ APInt i32e(32, 0xFEDC'BA09); // negative (sign bit set)
+
+ EXPECT_EQ(0x012F'7D02, APIntOps::mulhs(i32c, i32d));
+ EXPECT_EQ(0xFFEB'4988, APIntOps::mulhs(i32c, i32e));
+ EXPECT_EQ(0x0001'4B68, APIntOps::mulhs(i32e, i32e));
+
+ // 64-bit operands: positive*positive, positive*negative, negative*negative
+ APInt i64c(64, 0x1234'5678'90AB'CDEF); // positive
+ APInt i64d(64, 0x1234'5678'90FE'DCBA); // positive
+ APInt i64e(64, 0xFEDC'BA09'8765'4321); // negative (sign bit set)
+
+ EXPECT_EQ(0x014B'66DC'328E'10C1, APIntOps::mulhs(i64c, i64d));
+ EXPECT_EQ(0xFFEB'4988'12C6'6C68, APIntOps::mulhs(i64c, i64e));
+ EXPECT_EQ(0x0001'4B68'2174'FA18, APIntOps::mulhs(i64e, i64e));
+
+ // 128-bit operands: positive*positive and positive*negative only
+ APInt i128c(128, "1234567890ABCDEF1234567890ABCDEF", 16); // positive
+ APInt i128d(128, "1234567890FEDCBA1234567890FEDCBA", 16); // positive
+ APInt i128e(128, "FEDCBA0987654321FEDCBA0987654321", 16); // negative (sign bit set)
+
+ i128Res = APIntOps::mulhs(i128c, i128d);
+ EXPECT_EQ(APInt(128, "14B66DC328E10C1FE303DF9EA0B2529", 16), i128Res);
+
+ i128Res = APIntOps::mulhs(i128c, i128e);
+ EXPECT_EQ(APInt(128, "FFEB498812C66C68D4552DB89B8EBF8F", 16), i128Res);
+}
+
TEST(APIntTest, RoundingUDiv) {
for (uint64_t Ai = 1; Ai <= 255; Ai++) {
APInt A(8, Ai);
diff --git a/llvm/unittests/Support/DivisionByConstantTest.cpp b/llvm/unittests/Support/DivisionByConstantTest.cpp
index 2b17f98bb75b2f..7fd8a6cd34f579 100644
--- a/llvm/unittests/Support/DivisionByConstantTest.cpp
+++ b/llvm/unittests/Support/DivisionByConstantTest.cpp
@@ -21,12 +21,6 @@ template <typename Fn> static void EnumerateAPInts(unsigned Bits, Fn TestFn) {
} while (++N != 0);
}
-APInt MULHS(APInt X, APInt Y) {
- unsigned Bits = X.getBitWidth();
- unsigned WideBits = 2 * Bits;
- return (X.sext(WideBits) * Y.sext(WideBits)).lshr(Bits).trunc(Bits);
-}
-
APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor,
SignedDivisionByConstantInfo Magics) {
unsigned Bits = Numerator.getBitWidth();
@@ -48,7 +42,7 @@ APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor,
}
// Multiply the numerator by the magic value.
- APInt Q = MULHS(Numerator, Magics.Magic);
+ APInt Q = APIntOps::mulhs(Numerator, Magics.Magic);
// (Optionally) Add/subtract the numerator using Factor.
Factor = Numerator * Factor;
@@ -89,12 +83,6 @@ TEST(SignedDivisionByConstantTest, Test) {
}
}
-APInt MULHU(APInt X, APInt Y) {
- unsigned Bits = X.getBitWidth();
- unsigned WideBits = 2 * Bits;
- return (X.zext(WideBits) * Y.zext(WideBits)).lshr(Bits).trunc(Bits);
-}
-
APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor,
bool LZOptimization,
bool AllowEvenDivisorOptimization, bool ForceNPQ,
@@ -129,7 +117,7 @@ APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor,
APInt Q = Numerator.lshr(PreShift);
// Multiply the numerator by the magic value.
- Q = MULHU(Q, Magics.Magic);
+ Q = APIntOps::mulhu(Q, Magics.Magic);
if (UseNPQ || ForceNPQ) {
APInt NPQ = Numerator - Q;
@@ -138,7 +126,7 @@ APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor,
// MULHU to act as a SRL-by-1 for NPQ, else multiply by zero.
APInt NPQ_Scalar = NPQ.lshr(1);
(void)NPQ_Scalar;
- NPQ = MULHU(NPQ, NPQFactor);
+ NPQ = APIntOps::mulhu(NPQ, NPQFactor);
assert(!UseNPQ || NPQ == NPQ_Scalar);
Q = NPQ + Q;
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 658f3796721c4e..65bb228cbc73c3 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -537,19 +537,13 @@ TEST(KnownBitsTest, BinaryExhaustive) {
[](const KnownBits &Known1, const KnownBits &Known2) {
return KnownBits::mulhs(Known1, Known2);
},
- [](const APInt &N1, const APInt &N2) {
- unsigned Bits = N1.getBitWidth();
- return (N1.sext(2 * Bits) * N2.sext(2 * Bits)).extractBits(Bits, Bits);
- },
+ [](const APInt &N1, const APInt &N2) { return APIntOps::mulhs(N1, N2); },
checkCorrectnessOnlyBinary);
testBinaryOpExhaustive(
[](const KnownBits &Known1, const KnownBits &Known2) {
return KnownBits::mulhu(Known1, Known2);
},
- [](const APInt &N1, const APInt &N2) {
- unsigned Bits = N1.getBitWidth();
- return (N1.zext(2 * Bits) * N2.zext(2 * Bits)).extractBits(Bits, Bits);
- },
+ [](const APInt &N1, const APInt &N2) { return APIntOps::mulhu(N1, N2); },
checkCorrectnessOnlyBinary);
}
More information about the llvm-commits
mailing list