[llvm] [ADT] Add signed and unsigned mulHi and mulLo to APInt (PR #84719)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Mar 10 22:02:11 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-adt
Author: Atousa Duprat (Atousa)
<details>
<summary>Changes</summary>
This addresses issue #84207.
---
Full diff: https://github.com/llvm/llvm-project/pull/84719.diff
4 Files Affected:
- (modified) llvm/include/llvm/ADT/APInt.h (+12)
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+5-13)
- (modified) llvm/lib/Support/APInt.cpp (+28)
- (modified) llvm/unittests/ADT/APIntTest.cpp (+84)
``````````diff
diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index 1fc3c7b2236a17..e3ce01af20e7c6 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -2193,6 +2193,18 @@ inline const APInt absdiff(const APInt &A, const APInt &B) {
return A.uge(B) ? (A - B) : (B - A);
}
+/// Performs (2*N)-bit multiplication on sign-extended operands.
+/// Returns the high N bits of the multiplication result.
+/// C1 and C2 must have the same bit width N.
+APInt mulHiS(const APInt &C1, const APInt &C2);
+
+/// Performs (2*N)-bit multiplication on zero-extended operands.
+/// Returns the high N bits of the multiplication result.
+/// C1 and C2 must have the same bit width N.
+APInt mulHiU(const APInt &C1, const APInt &C2);
+
+/// Returns the low N bits of the product of C1 and C2, which must have the
+/// same bit width N. The low bits of a product do not depend on whether the
+/// operands are treated as signed or unsigned, so this is equivalent to
+/// mulLoU() and to the wrapping product C1 * C2.
+APInt mulLoS(const APInt &C1, const APInt &C2);
+
+/// Returns the low N bits of the product of C1 and C2, which must have the
+/// same bit width N. Equivalent to mulLoS() and to the wrapping product
+/// C1 * C2.
+APInt mulLoU(const APInt &C1, const APInt &C2);
+
/// Compute GCD of two unsigned APInt values.
///
/// This function returns the greatest common divisor of the two APInt values
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index f7ace79e8c51d4..53697e1ffb5b71 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6009,18 +6009,10 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
if (!C2.getBoolValue())
break;
return C1.srem(C2);
- case ISD::MULHS: {
- unsigned FullWidth = C1.getBitWidth() * 2;
- APInt C1Ext = C1.sext(FullWidth);
- APInt C2Ext = C2.sext(FullWidth);
- return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
- }
- case ISD::MULHU: {
- unsigned FullWidth = C1.getBitWidth() * 2;
- APInt C1Ext = C1.zext(FullWidth);
- APInt C2Ext = C2.zext(FullWidth);
- return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
- }
+ case ISD::MULHS:
+ return APIntOps::mulHiS(C1, C2);
+ case ISD::MULHU:
+ return APIntOps::mulHiU(C1, C2);
case ISD::AVGFLOORS: {
unsigned FullWidth = C1.getBitWidth() + 1;
APInt C1Ext = C1.sext(FullWidth);
@@ -6706,8 +6698,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
break;
case ISD::UDIV:
case ISD::UREM:
- case ISD::MULHU:
case ISD::MULHS:
+ case ISD::MULHU:
case ISD::SDIV:
case ISD::SREM:
case ISD::SADDSAT:
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index e686b976523302..46c469cafa88cc 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -3094,3 +3094,31 @@ void llvm::LoadIntFromMemory(APInt &IntVal, const uint8_t *Src,
memcpy(Dst + sizeof(uint64_t) - LoadBytes, Src, LoadBytes);
}
}
+
+APInt APIntOps::mulHiS(const APInt &C1, const APInt &C2) {
+  assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
+  // Sign-extend to 2N bits so the full signed product fits without
+  // wrapping, then keep the top N bits.
+  unsigned FullWidth = C1.getBitWidth() * 2;
+  APInt C1Ext = C1.sext(FullWidth);
+  APInt C2Ext = C2.sext(FullWidth);
+  return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
+}
+
+APInt APIntOps::mulHiU(const APInt &C1, const APInt &C2) {
+  assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
+  // Zero-extend to 2N bits so the full unsigned product fits without
+  // wrapping, then keep the top N bits.
+  unsigned FullWidth = C1.getBitWidth() * 2;
+  APInt C1Ext = C1.zext(FullWidth);
+  APInt C2Ext = C2.zext(FullWidth);
+  return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
+}
+
+APInt APIntOps::mulLoS(const APInt &C1, const APInt &C2) {
+  assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
+  // The low N bits of a product are identical whether the operands are
+  // sign- or zero-extended first, so the plain N-bit (wrapping) multiply
+  // gives the same result without materializing a 2N-bit product.
+  return C1 * C2;
+}
+
+APInt APIntOps::mulLoU(const APInt &C1, const APInt &C2) {
+  assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
+  // Truncating the zero-extended 2N-bit product to N bits is exactly the
+  // N-bit wrapping product, so multiply at the native width directly.
+  return C1 * C2;
+}
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 24324822356bf6..1597ac6f331d47 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -10,6 +10,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Alignment.h"
#include "gtest/gtest.h"
@@ -2805,6 +2806,89 @@ TEST(APIntTest, multiply) {
EXPECT_EQ(64U, i96.countr_zero());
}
+TEST(APIntOpsTest, MulHiLo) {
+  // Render a value as an upper-case hex C literal with ' digit separators so
+  // that wide (128-bit) results can be compared as readable strings.
+  auto Hex = [](const APInt &V) {
+    return toString(V, 16, /*Signed=*/false, /*formatAsCLiteral=*/true,
+                    /*UpperCase=*/true, /*InsertSeparators=*/true);
+  };
+
+  // Unsigned
+
+  // 32 bits
+  APInt i32a(32, 0x0001'E235);
+  APInt i32b(32, 0xF623'55AD);
+  EXPECT_EQ(0x0001'CFA1, APIntOps::mulHiU(i32a, i32b));
+  EXPECT_EQ(0x7CA0'76D1, APIntOps::mulLoU(i32a, i32b));
+
+  // 64 bits
+  APInt i64a(64, 0x1234'5678'90AB'CDEF);
+  APInt i64b(64, 0xFEDC'BA09'8765'4321);
+  EXPECT_EQ(0x121F'A000'A372'3A57, APIntOps::mulHiU(i64a, i64b));
+  EXPECT_EQ(0xC24A'442F'E556'18CF, APIntOps::mulLoU(i64a, i64b));
+
+  // 128 bits
+  APInt i128a(128, "1234567890ABCDEF1234567890ABCDEF", 16);
+  APInt i128b(128, "FEDCBA0987654321FEDCBA0987654321", 16);
+  EXPECT_EQ("0x121F'A000'A372'3A57'E689'8431'2C3A'8D7E",
+            Hex(APIntOps::mulHiU(i128a, i128b)));
+  EXPECT_EQ("0x96B4'2860'6E1E'6BF5'C24A'442F'E556'18CF",
+            Hex(APIntOps::mulLoU(i128a, i128b)));
+  // The low half is just the N-bit wrapping product.
+  EXPECT_EQ(i128a * i128b, APIntOps::mulLoU(i128a, i128b));
+
+  // Signed
+
+  // 32 bits
+  APInt i32c(32, 0x1234'5678); // +ve
+  APInt i32d(32, 0x10AB'CDEF); // +ve
+  APInt i32e(32, 0xFEDC'BA09); // -ve
+
+  EXPECT_EQ(0x012F'7D02, APIntOps::mulHiS(i32c, i32d));
+  EXPECT_EQ(0x2A42'D208, APIntOps::mulLoS(i32c, i32d));
+
+  EXPECT_EQ(0xFFEB'4988, APIntOps::mulHiS(i32c, i32e));
+  EXPECT_EQ(0x09CA'3A38, APIntOps::mulLoS(i32c, i32e));
+
+  EXPECT_EQ(0x0001'4B68, APIntOps::mulHiS(i32e, i32e));
+  EXPECT_EQ(0x22A9'1451, APIntOps::mulLoS(i32e, i32e));
+
+  // 64 bits
+  APInt i64c(64, 0x1234'5678'90AB'CDEF); // +ve
+  APInt i64d(64, 0x1234'5678'90FE'DCBA); // +ve
+  APInt i64e(64, 0xFEDC'BA09'8765'4321); // -ve
+
+  EXPECT_EQ(0x014B'66DC'328E'10C1, APIntOps::mulHiS(i64c, i64d));
+  EXPECT_EQ(0xFB99'7041'84EF'03A6, APIntOps::mulLoS(i64c, i64d));
+
+  EXPECT_EQ(0xFFEB'4988'12C6'6C68, APIntOps::mulHiS(i64c, i64e));
+  EXPECT_EQ(0xC24A'442F'E556'18CF, APIntOps::mulLoS(i64c, i64e));
+
+  EXPECT_EQ(0x0001'4B68'2174'FA18, APIntOps::mulHiS(i64e, i64e));
+  EXPECT_EQ(0xCEFE'A12C'D7A4'4A41, APIntOps::mulLoS(i64e, i64e));
+
+  // Signedness does not affect the low half.
+  EXPECT_EQ(APIntOps::mulLoU(i64c, i64e), APIntOps::mulLoS(i64c, i64e));
+
+  // 128 bits
+  APInt i128c(128, "1234567890ABCDEF1234567890ABCDEF", 16); // +ve
+  APInt i128d(128, "1234567890FEDCBA1234567890FEDCBA", 16); // +ve
+  APInt i128e(128, "FEDCBA0987654321FEDCBA0987654321", 16); // -ve
+
+  EXPECT_EQ("0x14B'66DC'328E'10C1'FE30'3DF9'EA0B'2529",
+            Hex(APIntOps::mulHiS(i128c, i128d)));
+  EXPECT_EQ("0xF87E'475F'3C6C'180D'FB99'7041'84EF'03A6",
+            Hex(APIntOps::mulLoS(i128c, i128d)));
+
+  EXPECT_EQ("0xFFEB'4988'12C6'6C68'D455'2DB8'9B8E'BF8F",
+            Hex(APIntOps::mulHiS(i128c, i128e)));
+  EXPECT_EQ("0x96B4'2860'6E1E'6BF5'C24A'442F'E556'18CF",
+            Hex(APIntOps::mulLoS(i128c, i128e)));
+
+  EXPECT_EQ("0x1'4B68'2174'FA18'CCBA'AC10'2958'C4B5",
+            Hex(APIntOps::mulHiS(i128e, i128e)));
+  EXPECT_EQ("0x9BB8'01D4'DF88'14DC'CEFE'A12C'D7A4'4A41",
+            Hex(APIntOps::mulLoS(i128e, i128e)));
+
+  // Boundary values at 32 bits.
+  APInt Min32 = APInt::getSignedMinValue(32);
+  APInt Ones32 = APInt::getAllOnes(32);
+  // (-2^31) * (-2^31) = 2^62.
+  EXPECT_EQ(0x4000'0000, APIntOps::mulHiS(Min32, Min32));
+  EXPECT_EQ(0U, APIntOps::mulLoS(Min32, Min32));
+  // (-2^31) * (-1) = 2^31, which only fits in the wide product.
+  EXPECT_EQ(0U, APIntOps::mulHiS(Min32, Ones32));
+  EXPECT_EQ(0x8000'0000, APIntOps::mulLoS(Min32, Ones32));
+  // (-1) * (-1) = 1.
+  EXPECT_EQ(0U, APIntOps::mulHiS(Ones32, Ones32));
+  EXPECT_EQ(1U, APIntOps::mulLoS(Ones32, Ones32));
+  // (2^32 - 1)^2 = 2^64 - 2^33 + 1.
+  EXPECT_EQ(0xFFFF'FFFE, APIntOps::mulHiU(Ones32, Ones32));
+  EXPECT_EQ(1U, APIntOps::mulLoU(Ones32, Ones32));
+
+  // Exhaustive 8-bit cross-check against native 16-bit products.
+  for (unsigned A = 0; A < 256; ++A) {
+    for (unsigned B = 0; B < 256; ++B) {
+      APInt A8(8, A);
+      APInt B8(8, B);
+      unsigned ProdU = A * B;
+      EXPECT_EQ(APInt(8, ProdU >> 8), APIntOps::mulHiU(A8, B8));
+      EXPECT_EQ(APInt(8, ProdU & 0xFF), APIntOps::mulLoU(A8, B8));
+      int SA = A < 128 ? int(A) : int(A) - 256;
+      int SB = B < 128 ? int(B) : int(B) - 256;
+      // Two's complement bits of the 16-bit signed product.
+      unsigned ProdS = static_cast<unsigned>(SA * SB) & 0xFFFF;
+      EXPECT_EQ(APInt(8, ProdS >> 8), APIntOps::mulHiS(A8, B8));
+      EXPECT_EQ(APInt(8, ProdS & 0xFF), APIntOps::mulLoS(A8, B8));
+    }
+  }
+}
+
TEST(APIntTest, RoundingUDiv) {
for (uint64_t Ai = 1; Ai <= 255; Ai++) {
APInt A(8, Ai);
``````````
</details>
https://github.com/llvm/llvm-project/pull/84719
More information about the llvm-commits mailing list