[llvm] [ADT] Add implementations for mulhs and mulhu to APInt (PR #84609)
Shourya Goel via llvm-commits
llvm-commits at lists.llvm.org
Sat Mar 9 04:32:06 PST 2024
https://github.com/Sh0g0-1758 updated https://github.com/llvm/llvm-project/pull/84609
>From 547887311ada7a38cc49d58668ad264bbbd7d64d Mon Sep 17 00:00:00 2001
From: Sh0g0-1758 <shouryagoel10000 at gmail.com>
Date: Sat, 9 Mar 2024 11:40:32 +0530
Subject: [PATCH 1/6] Add APIntOps::mulhs / APIntOps::mulhu
---
llvm/include/llvm/ADT/APInt.h | 6 ++++++
.../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 6 ++++++
llvm/lib/Support/APInt.cpp | 16 +++++++++++++++
llvm/unittests/ADT/APIntTest.cpp | 20 +++++++++++++++++++
.../Support/DivisionByConstantTest.cpp | 20 ++++---------------
5 files changed, 52 insertions(+), 16 deletions(-)
diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index 1fc3c7b2236a17..711eb3a9c3fc6d 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -2193,6 +2193,12 @@ inline const APInt absdiff(const APInt &A, const APInt &B) {
return A.uge(B) ? (A - B) : (B - A);
}
+/// Return the high N bits of the 2N-bit product of zero-extended C1 and C2.
+APInt mulhu(const APInt &C1, const APInt &C2);
+
+/// Return the high N bits of the 2N-bit product of sign-extended C1 and C2.
+APInt mulhs(const APInt &C1, const APInt &C2);
+
/// Compute GCD of two unsigned APInt values.
///
/// This function returns the greatest common divisor of the two APInt values
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 50f53bbb04b62d..054fe4ec37da24 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6027,6 +6027,12 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
APInt C2Ext = C2.zext(FullWidth);
return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
}
+ case ISD::MULHU: {
+ return APIntOps::mulhu(C1, C2);
+ }
+ case ISD::MULHS: {
+ return APIntOps::mulhs(C1, C2);
+ }
case ISD::AVGFLOORS: {
unsigned FullWidth = C1.getBitWidth() + 1;
APInt C1Ext = C1.sext(FullWidth);
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index e686b976523302..9ce10ada67e9e1 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -3067,6 +3067,22 @@ void llvm::StoreIntToMemory(const APInt &IntVal, uint8_t *Dst,
}
}
+APInt APIntOps::mulhu(const APInt &C1, const APInt &C2) {
+  // Return the high N bits of the 2N-bit unsigned product of two N-bit values.
+  assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
+  unsigned Bits = C1.getBitWidth();
+  APInt Product = C1.zext(2 * Bits) * C2.zext(2 * Bits);
+  return Product.extractBits(Bits, Bits);
+}
+
+APInt APIntOps::mulhs(const APInt &C1, const APInt &C2) {
+  // Return the high N bits of the 2N-bit signed product of two N-bit values.
+  assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
+  unsigned Bits = C1.getBitWidth();
+  APInt Product = C1.sext(2 * Bits) * C2.sext(2 * Bits);
+  return Product.extractBits(Bits, Bits);
+}
+
/// LoadIntFromMemory - Loads the integer stored in the LoadBytes bytes starting
/// from Src into IntVal, which is assumed to be wide enough and to hold zero.
void llvm::LoadIntFromMemory(APInt &IntVal, const uint8_t *Src,
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 24324822356bf6..9995aeaff92871 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -2805,6 +2805,26 @@ TEST(APIntTest, multiply) {
EXPECT_EQ(64U, i96.countr_zero());
}
+TEST(APIntTest, Hmultiply) {
+  APInt i1048576(32, 1048576);
+  // 2^20 * 2^20 = 2^40, whose high 32 bits are 2^8.
+  EXPECT_EQ(APInt(32, 256), APIntOps::mulhu(i1048576, i1048576));
+  // 2^24 * 2^15 = 2^39, whose high 32 bits are 2^7 in either operand order.
+  APInt i16777216(32, 16777216);
+  APInt i32768(32, 32768);
+
+  EXPECT_EQ(APInt(32, 128), APIntOps::mulhu(i16777216, i32768));
+  EXPECT_EQ(APInt(32, 128), APIntOps::mulhu(i32768, i16777216));
+  // Negative constants are built with isSigned=true, not implicit truncation.
+  APInt i2097152(32, -2097152, /*isSigned=*/true);
+  APInt i524288(32, 524288);
+  // (-2^21) * (-2^21) = 2^42, whose high 32 bits are 2^10.
+  EXPECT_EQ(APInt(32, 1024), APIntOps::mulhs(i2097152, i2097152));
+  // (-2^21) * 2^19 = -2^40, whose high 32 bits are -2^8.
+  EXPECT_EQ(APInt(32, -256, true), APIntOps::mulhs(i2097152, i524288));
+  EXPECT_EQ(APInt(32, -256, true), APIntOps::mulhs(i524288, i2097152));
+}
+
TEST(APIntTest, RoundingUDiv) {
for (uint64_t Ai = 1; Ai <= 255; Ai++) {
APInt A(8, Ai);
diff --git a/llvm/unittests/Support/DivisionByConstantTest.cpp b/llvm/unittests/Support/DivisionByConstantTest.cpp
index 2b17f98bb75b2f..8e0c78fe85654a 100644
--- a/llvm/unittests/Support/DivisionByConstantTest.cpp
+++ b/llvm/unittests/Support/DivisionByConstantTest.cpp
@@ -21,12 +21,6 @@ template <typename Fn> static void EnumerateAPInts(unsigned Bits, Fn TestFn) {
} while (++N != 0);
}
-APInt MULHS(APInt X, APInt Y) {
- unsigned Bits = X.getBitWidth();
- unsigned WideBits = 2 * Bits;
- return (X.sext(WideBits) * Y.sext(WideBits)).lshr(Bits).trunc(Bits);
-}
-
APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor,
SignedDivisionByConstantInfo Magics) {
unsigned Bits = Numerator.getBitWidth();
@@ -48,7 +42,7 @@ APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor,
}
// Multiply the numerator by the magic value.
- APInt Q = MULHS(Numerator, Magics.Magic);
+ APInt Q = APIntOps::mulhs(Numerator, Magics.Magic);
// (Optionally) Add/subtract the numerator using Factor.
Factor = Numerator * Factor;
@@ -89,12 +83,6 @@ TEST(SignedDivisionByConstantTest, Test) {
}
}
-APInt MULHU(APInt X, APInt Y) {
- unsigned Bits = X.getBitWidth();
- unsigned WideBits = 2 * Bits;
- return (X.zext(WideBits) * Y.zext(WideBits)).lshr(Bits).trunc(Bits);
-}
-
APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor,
bool LZOptimization,
bool AllowEvenDivisorOptimization, bool ForceNPQ,
@@ -129,16 +117,16 @@ APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor,
APInt Q = Numerator.lshr(PreShift);
// Multiply the numerator by the magic value.
- Q = MULHU(Q, Magics.Magic);
+ Q = APIntOps::mulhu(Q, Magics.Magic);
if (UseNPQ || ForceNPQ) {
APInt NPQ = Numerator - Q;
// For vectors we might have a mix of non-NPQ/NPQ paths, so use
- // MULHU to act as a SRL-by-1 for NPQ, else multiply by zero.
+ // mulhu to act as a SRL-by-1 for NPQ, else multiply by zero.
APInt NPQ_Scalar = NPQ.lshr(1);
(void)NPQ_Scalar;
- NPQ = MULHU(NPQ, NPQFactor);
+ NPQ = APIntOps::mulhu(NPQ, NPQFactor);
assert(!UseNPQ || NPQ == NPQ_Scalar);
Q = NPQ + Q;
>From abdc2f8bb263de27dc575b5dda53b85f2b415365 Mon Sep 17 00:00:00 2001
From: Sh0g0-1758 <shouryagoel10000 at gmail.com>
Date: Sat, 9 Mar 2024 11:50:51 +0530
Subject: [PATCH 2/6] Remove the old inline MULHS/MULHU case statements
---
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 12 ------------
1 file changed, 12 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 054fe4ec37da24..4f533b7d055129 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6015,18 +6015,6 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
if (!C2.getBoolValue())
break;
return C1.srem(C2);
- case ISD::MULHS: {
- unsigned FullWidth = C1.getBitWidth() * 2;
- APInt C1Ext = C1.sext(FullWidth);
- APInt C2Ext = C2.sext(FullWidth);
- return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
- }
- case ISD::MULHU: {
- unsigned FullWidth = C1.getBitWidth() * 2;
- APInt C1Ext = C1.zext(FullWidth);
- APInt C2Ext = C2.zext(FullWidth);
- return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
- }
case ISD::MULHU: {
return APIntOps::mulhu(C1, C2);
}
>From c81a474ab12643ed9373717f2afa2ce614fdd0b9 Mon Sep 17 00:00:00 2001
From: Sh0g0-1758 <shouryagoel10000 at gmail.com>
Date: Sat, 9 Mar 2024 16:08:33 +0530
Subject: [PATCH 3/6] Remove braces around MULHU/MULHS cases
---
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 8 ++------
1 file changed, 2 insertions(+), 6 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 4f533b7d055129..da468f3a97eaf8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6015,12 +6015,8 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
if (!C2.getBoolValue())
break;
return C1.srem(C2);
- case ISD::MULHU: {
- return APIntOps::mulhu(C1, C2);
- }
- case ISD::MULHS: {
- return APIntOps::mulhs(C1, C2);
- }
+ case ISD::MULHU: return APIntOps::mulhu(C1, C2);
+ case ISD::MULHS: return APIntOps::mulhs(C1, C2);
case ISD::AVGFLOORS: {
unsigned FullWidth = C1.getBitWidth() + 1;
APInt C1Ext = C1.sext(FullWidth);
>From bd7cd462db4cf7da882d3a4eeec34e36676aecc0 Mon Sep 17 00:00:00 2001
From: Sh0g0-1758 <shouryagoel10000 at gmail.com>
Date: Sat, 9 Mar 2024 16:09:00 +0530
Subject: [PATCH 4/6] Apply clang-format
---
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index da468f3a97eaf8..e1fcb6f84ede2b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6015,8 +6015,10 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
if (!C2.getBoolValue())
break;
return C1.srem(C2);
- case ISD::MULHU: return APIntOps::mulhu(C1, C2);
- case ISD::MULHS: return APIntOps::mulhs(C1, C2);
+ case ISD::MULHU:
+ return APIntOps::mulhu(C1, C2);
+ case ISD::MULHS:
+ return APIntOps::mulhs(C1, C2);
case ISD::AVGFLOORS: {
unsigned FullWidth = C1.getBitWidth() + 1;
APInt C1Ext = C1.sext(FullWidth);
>From 32ec3436885452011a563157b73d8e42e31aea1e Mon Sep 17 00:00:00 2001
From: Sh0g0-1758 <shouryagoel10000 at gmail.com>
Date: Sat, 9 Mar 2024 18:01:27 +0530
Subject: [PATCH 5/6] Use APIntOps::mulhs/mulhu in KnownBitsTest
---
llvm/unittests/Support/KnownBitsTest.cpp | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 658f3796721c4e..40ea8ce0b597a6 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -538,8 +538,7 @@ TEST(KnownBitsTest, BinaryExhaustive) {
return KnownBits::mulhs(Known1, Known2);
},
[](const APInt &N1, const APInt &N2) {
- unsigned Bits = N1.getBitWidth();
- return (N1.sext(2 * Bits) * N2.sext(2 * Bits)).extractBits(Bits, Bits);
+ return APIntOps::mulhs(N1, N2);
},
checkCorrectnessOnlyBinary);
testBinaryOpExhaustive(
@@ -547,8 +546,7 @@ TEST(KnownBitsTest, BinaryExhaustive) {
return KnownBits::mulhu(Known1, Known2);
},
[](const APInt &N1, const APInt &N2) {
- unsigned Bits = N1.getBitWidth();
- return (N1.zext(2 * Bits) * N2.zext(2 * Bits)).extractBits(Bits, Bits);
+ return APIntOps::mulhu(N1, N2);
},
checkCorrectnessOnlyBinary);
}
>From 3b7b3a48040cf478cb0b659d0d2e6a893a306832 Mon Sep 17 00:00:00 2001
From: Sh0g0-1758 <shouryagoel10000 at gmail.com>
Date: Sat, 9 Mar 2024 18:01:48 +0530
Subject: [PATCH 6/6] Apply clang-format
---
llvm/unittests/Support/KnownBitsTest.cpp | 8 ++------
1 file changed, 2 insertions(+), 6 deletions(-)
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 40ea8ce0b597a6..65bb228cbc73c3 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -537,17 +537,13 @@ TEST(KnownBitsTest, BinaryExhaustive) {
[](const KnownBits &Known1, const KnownBits &Known2) {
return KnownBits::mulhs(Known1, Known2);
},
- [](const APInt &N1, const APInt &N2) {
- return APIntOps::mulhs(N1, N2);
- },
+ [](const APInt &N1, const APInt &N2) { return APIntOps::mulhs(N1, N2); },
checkCorrectnessOnlyBinary);
testBinaryOpExhaustive(
[](const KnownBits &Known1, const KnownBits &Known2) {
return KnownBits::mulhu(Known1, Known2);
},
- [](const APInt &N1, const APInt &N2) {
- return APIntOps::mulhu(N1, N2);
- },
+ [](const APInt &N1, const APInt &N2) { return APIntOps::mulhu(N1, N2); },
checkCorrectnessOnlyBinary);
}
More information about the llvm-commits
mailing list