[llvm] [mlir] [ADT] Add implementations for mulhs and mulhu to APInt (PR #84609)

Shourya Goel via llvm-commits llvm-commits at lists.llvm.org
Sat Mar 9 13:50:21 PST 2024


https://github.com/Sh0g0-1758 updated https://github.com/llvm/llvm-project/pull/84609

>From 547887311ada7a38cc49d58668ad264bbbd7d64d Mon Sep 17 00:00:00 2001
From: Sh0g0-1758 <shouryagoel10000 at gmail.com>
Date: Sat, 9 Mar 2024 11:40:32 +0530
Subject: [PATCH 01/21] Add APIntOps::mulhs / APIntOps::mulhu

---
 llvm/include/llvm/ADT/APInt.h                 |  6 ++++++
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp |  6 ++++++
 llvm/lib/Support/APInt.cpp                    | 16 +++++++++++++++
 llvm/unittests/ADT/APIntTest.cpp              | 20 +++++++++++++++++++
 .../Support/DivisionByConstantTest.cpp        | 20 ++++---------------
 5 files changed, 52 insertions(+), 16 deletions(-)

diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index 1fc3c7b2236a17..711eb3a9c3fc6d 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -2193,6 +2193,12 @@ inline const APInt absdiff(const APInt &A, const APInt &B) {
   return A.uge(B) ? (A - B) : (B - A);
 }
 
+/// Compute the higher order bits of unsigned multiplication of two APInts
+APInt mulhu(const APInt &C1, const APInt &C2);
+
+/// Compute the higher order bits of signed multiplication of two APInts
+APInt mulhs(const APInt &C1, const APInt &C2);
+
 /// Compute GCD of two unsigned APInt values.
 ///
 /// This function returns the greatest common divisor of the two APInt values
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 50f53bbb04b62d..054fe4ec37da24 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6027,6 +6027,12 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
     APInt C2Ext = C2.zext(FullWidth);
     return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
   }
+  case ISD::MULHU: {
+    return APIntOps::mulhu(C1, C2);
+  }
+  case ISD::MULHS: {
+    return APIntOps::mulhs(C1, C2);
+  }
   case ISD::AVGFLOORS: {
     unsigned FullWidth = C1.getBitWidth() + 1;
     APInt C1Ext = C1.sext(FullWidth);
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index e686b976523302..9ce10ada67e9e1 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -3067,6 +3067,22 @@ void llvm::StoreIntToMemory(const APInt &IntVal, uint8_t *Dst,
   }
 }
 
+APInt APIntOps::mulhu(const APInt &C1, const APInt &C2) {
+  // Return higher order bits for unsigned (C1 * C2)
+  unsigned FullWidth = C1.getBitWidth() * 2;
+  APInt C1Ext = C1.zext(FullWidth);
+  APInt C2Ext = C2.zext(FullWidth);
+  return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
+}
+
+APInt APIntOps::mulhs(const APInt &C1, const APInt &C2) {
+  // Return higher order bits for signed (C1 * C2)
+  unsigned FullWidth = C1.getBitWidth() * 2;
+  APInt C1Ext = C1.sext(FullWidth);
+  APInt C2Ext = C2.sext(FullWidth);
+  return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
+}
+
 /// LoadIntFromMemory - Loads the integer stored in the LoadBytes bytes starting
 /// from Src into IntVal, which is assumed to be wide enough and to hold zero.
 void llvm::LoadIntFromMemory(APInt &IntVal, const uint8_t *Src,
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 24324822356bf6..9995aeaff92871 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -2805,6 +2805,26 @@ TEST(APIntTest, multiply) {
   EXPECT_EQ(64U, i96.countr_zero());
 }
 
+TEST(APIntTest, Hmultiply) {
+  APInt i1048576(32, 1048576);
+
+  EXPECT_EQ(APInt(32, 256), APIntOps::mulhu(i1048576, i1048576));
+
+  APInt i16777216(32, 16777216);
+  APInt i32768(32, 32768);
+
+  EXPECT_EQ(APInt(32, 128), APIntOps::mulhu(i16777216, i32768));
+  EXPECT_EQ(APInt(32, 128), APIntOps::mulhu(i32768, i16777216));
+
+  APInt i2097152(32, -2097152);
+  APInt i524288(32, 524288);
+
+  EXPECT_EQ(APInt(32, 1024), APIntOps::mulhs(i2097152, i2097152));
+
+  EXPECT_EQ(APInt(32, -256), APIntOps::mulhs(i2097152, i524288));
+  EXPECT_EQ(APInt(32, -256), APIntOps::mulhs(i524288, i2097152));
+}
+
 TEST(APIntTest, RoundingUDiv) {
   for (uint64_t Ai = 1; Ai <= 255; Ai++) {
     APInt A(8, Ai);
diff --git a/llvm/unittests/Support/DivisionByConstantTest.cpp b/llvm/unittests/Support/DivisionByConstantTest.cpp
index 2b17f98bb75b2f..8e0c78fe85654a 100644
--- a/llvm/unittests/Support/DivisionByConstantTest.cpp
+++ b/llvm/unittests/Support/DivisionByConstantTest.cpp
@@ -21,12 +21,6 @@ template <typename Fn> static void EnumerateAPInts(unsigned Bits, Fn TestFn) {
   } while (++N != 0);
 }
 
-APInt MULHS(APInt X, APInt Y) {
-  unsigned Bits = X.getBitWidth();
-  unsigned WideBits = 2 * Bits;
-  return (X.sext(WideBits) * Y.sext(WideBits)).lshr(Bits).trunc(Bits);
-}
-
 APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor,
                              SignedDivisionByConstantInfo Magics) {
   unsigned Bits = Numerator.getBitWidth();
@@ -48,7 +42,7 @@ APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor,
   }
 
   // Multiply the numerator by the magic value.
-  APInt Q = MULHS(Numerator, Magics.Magic);
+  APInt Q = APIntOps::mulhs(Numerator, Magics.Magic);
 
   // (Optionally) Add/subtract the numerator using Factor.
   Factor = Numerator * Factor;
@@ -89,12 +83,6 @@ TEST(SignedDivisionByConstantTest, Test) {
   }
 }
 
-APInt MULHU(APInt X, APInt Y) {
-  unsigned Bits = X.getBitWidth();
-  unsigned WideBits = 2 * Bits;
-  return (X.zext(WideBits) * Y.zext(WideBits)).lshr(Bits).trunc(Bits);
-}
-
 APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor,
                                bool LZOptimization,
                                bool AllowEvenDivisorOptimization, bool ForceNPQ,
@@ -129,16 +117,16 @@ APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor,
   APInt Q = Numerator.lshr(PreShift);
 
   // Multiply the numerator by the magic value.
-  Q = MULHU(Q, Magics.Magic);
+  Q = APIntOps::mulhu(Q, Magics.Magic);
 
   if (UseNPQ || ForceNPQ) {
     APInt NPQ = Numerator - Q;
 
     // For vectors we might have a mix of non-NPQ/NPQ paths, so use
-    // MULHU to act as a SRL-by-1 for NPQ, else multiply by zero.
+    // mulhu to act as a SRL-by-1 for NPQ, else multiply by zero.
     APInt NPQ_Scalar = NPQ.lshr(1);
     (void)NPQ_Scalar;
-    NPQ = MULHU(NPQ, NPQFactor);
+    NPQ = APIntOps::mulhu(NPQ, NPQFactor);
     assert(!UseNPQ || NPQ == NPQ_Scalar);
 
     Q = NPQ + Q;

>From abdc2f8bb263de27dc575b5dda53b85f2b415365 Mon Sep 17 00:00:00 2001
From: Sh0g0-1758 <shouryagoel10000 at gmail.com>
Date: Sat, 9 Mar 2024 11:50:51 +0530
Subject: [PATCH 02/21] Removed older case statements

---
 llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 12 ------------
 1 file changed, 12 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 054fe4ec37da24..4f533b7d055129 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6015,18 +6015,6 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
     if (!C2.getBoolValue())
       break;
     return C1.srem(C2);
-  case ISD::MULHS: {
-    unsigned FullWidth = C1.getBitWidth() * 2;
-    APInt C1Ext = C1.sext(FullWidth);
-    APInt C2Ext = C2.sext(FullWidth);
-    return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
-  }
-  case ISD::MULHU: {
-    unsigned FullWidth = C1.getBitWidth() * 2;
-    APInt C1Ext = C1.zext(FullWidth);
-    APInt C2Ext = C2.zext(FullWidth);
-    return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
-  }
   case ISD::MULHU: {
     return APIntOps::mulhu(C1, C2);
   }

>From c81a474ab12643ed9373717f2afa2ce614fdd0b9 Mon Sep 17 00:00:00 2001
From: Sh0g0-1758 <shouryagoel10000 at gmail.com>
Date: Sat, 9 Mar 2024 16:08:33 +0530
Subject: [PATCH 03/21] Removed Braces

---
 llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 4f533b7d055129..da468f3a97eaf8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6015,12 +6015,8 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
     if (!C2.getBoolValue())
       break;
     return C1.srem(C2);
-  case ISD::MULHU: {
-    return APIntOps::mulhu(C1, C2);
-  }
-  case ISD::MULHS: {
-    return APIntOps::mulhs(C1, C2);
-  }
+  case ISD::MULHU: return APIntOps::mulhu(C1, C2);
+  case ISD::MULHS: return APIntOps::mulhs(C1, C2);
   case ISD::AVGFLOORS: {
     unsigned FullWidth = C1.getBitWidth() + 1;
     APInt C1Ext = C1.sext(FullWidth);

>From bd7cd462db4cf7da882d3a4eeec34e36676aecc0 Mon Sep 17 00:00:00 2001
From: Sh0g0-1758 <shouryagoel10000 at gmail.com>
Date: Sat, 9 Mar 2024 16:09:00 +0530
Subject: [PATCH 04/21] Ran clang-format

---
 llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index da468f3a97eaf8..e1fcb6f84ede2b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6015,8 +6015,10 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
     if (!C2.getBoolValue())
       break;
     return C1.srem(C2);
-  case ISD::MULHU: return APIntOps::mulhu(C1, C2);
-  case ISD::MULHS: return APIntOps::mulhs(C1, C2);
+  case ISD::MULHU:
+    return APIntOps::mulhu(C1, C2);
+  case ISD::MULHS:
+    return APIntOps::mulhs(C1, C2);
   case ISD::AVGFLOORS: {
     unsigned FullWidth = C1.getBitWidth() + 1;
     APInt C1Ext = C1.sext(FullWidth);

>From 32ec3436885452011a563157b73d8e42e31aea1e Mon Sep 17 00:00:00 2001
From: Sh0g0-1758 <shouryagoel10000 at gmail.com>
Date: Sat, 9 Mar 2024 18:01:27 +0530
Subject: [PATCH 05/21] Refactored KnownBitsTest

---
 llvm/unittests/Support/KnownBitsTest.cpp | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 658f3796721c4e..40ea8ce0b597a6 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -538,8 +538,7 @@ TEST(KnownBitsTest, BinaryExhaustive) {
         return KnownBits::mulhs(Known1, Known2);
       },
       [](const APInt &N1, const APInt &N2) {
-        unsigned Bits = N1.getBitWidth();
-        return (N1.sext(2 * Bits) * N2.sext(2 * Bits)).extractBits(Bits, Bits);
+        return APIntOps::mulhs(N1, N2);
       },
       checkCorrectnessOnlyBinary);
   testBinaryOpExhaustive(
@@ -547,8 +546,7 @@ TEST(KnownBitsTest, BinaryExhaustive) {
         return KnownBits::mulhu(Known1, Known2);
       },
       [](const APInt &N1, const APInt &N2) {
-        unsigned Bits = N1.getBitWidth();
-        return (N1.zext(2 * Bits) * N2.zext(2 * Bits)).extractBits(Bits, Bits);
+        return APIntOps::mulhu(N1, N2);
       },
       checkCorrectnessOnlyBinary);
 }

>From 3b7b3a48040cf478cb0b659d0d2e6a893a306832 Mon Sep 17 00:00:00 2001
From: Sh0g0-1758 <shouryagoel10000 at gmail.com>
Date: Sat, 9 Mar 2024 18:01:48 +0530
Subject: [PATCH 06/21] Ran clang-format

---
 llvm/unittests/Support/KnownBitsTest.cpp | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 40ea8ce0b597a6..65bb228cbc73c3 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -537,17 +537,13 @@ TEST(KnownBitsTest, BinaryExhaustive) {
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::mulhs(Known1, Known2);
       },
-      [](const APInt &N1, const APInt &N2) {
-        return APIntOps::mulhs(N1, N2);
-      },
+      [](const APInt &N1, const APInt &N2) { return APIntOps::mulhs(N1, N2); },
       checkCorrectnessOnlyBinary);
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::mulhu(Known1, Known2);
       },
-      [](const APInt &N1, const APInt &N2) {
-        return APIntOps::mulhu(N1, N2);
-      },
+      [](const APInt &N1, const APInt &N2) { return APIntOps::mulhu(N1, N2); },
       checkCorrectnessOnlyBinary);
 }
 

>From 2d6804a5e45782402558d6caaf548dfd5de1cbff Mon Sep 17 00:00:00 2001
From: Sh0g0-1758 <shouryagoel10000 at gmail.com>
Date: Sat, 9 Mar 2024 21:52:39 +0530
Subject: [PATCH 07/21] added same-bitwidth assertion

---
 llvm/lib/Support/APInt.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index 9ce10ada67e9e1..df20d51acbae38 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -3069,6 +3069,7 @@ void llvm::StoreIntToMemory(const APInt &IntVal, uint8_t *Dst,
 
 APInt APIntOps::mulhu(const APInt &C1, const APInt &C2) {
   // Return higher order bits for unsigned (C1 * C2)
+  assert(C1.BitWidth == C2.BitWidth && "Bit widths must be the same");
   unsigned FullWidth = C1.getBitWidth() * 2;
   APInt C1Ext = C1.zext(FullWidth);
   APInt C2Ext = C2.zext(FullWidth);
@@ -3077,6 +3078,7 @@ APInt APIntOps::mulhu(const APInt &C1, const APInt &C2) {
 
 APInt APIntOps::mulhs(const APInt &C1, const APInt &C2) {
   // Return higher order bits for signed (C1 * C2)
+  assert(C1.BitWidth == C2.BitWidth && "Bit widths must be the same");
   unsigned FullWidth = C1.getBitWidth() * 2;
   APInt C1Ext = C1.sext(FullWidth);
   APInt C2Ext = C2.sext(FullWidth);

>From b80b94de77a1c5fd50ebfaad38bb3bd84cecf668 Mon Sep 17 00:00:00 2001
From: Sh0g0-1758 <shouryagoel10000 at gmail.com>
Date: Sat, 9 Mar 2024 22:14:19 +0530
Subject: [PATCH 08/21] Use getBitWidth() in same-bitwidth assertions

---
 llvm/lib/Support/APInt.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index df20d51acbae38..a1b396d184b497 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -3069,7 +3069,7 @@ void llvm::StoreIntToMemory(const APInt &IntVal, uint8_t *Dst,
 
 APInt APIntOps::mulhu(const APInt &C1, const APInt &C2) {
   // Return higher order bits for unsigned (C1 * C2)
-  assert(C1.BitWidth == C2.BitWidth && "Bit widths must be the same");
+  assert(C1.getBitWidth() == C2.getBitWidth() && "Bit widths must be the same");
   unsigned FullWidth = C1.getBitWidth() * 2;
   APInt C1Ext = C1.zext(FullWidth);
   APInt C2Ext = C2.zext(FullWidth);
@@ -3078,7 +3078,7 @@ APInt APIntOps::mulhu(const APInt &C1, const APInt &C2) {
 
 APInt APIntOps::mulhs(const APInt &C1, const APInt &C2) {
   // Return higher order bits for signed (C1 * C2)
-  assert(C1.BitWidth == C2.BitWidth && "Bit widths must be the same");
+  assert(C1.getBitWidth() == C2.getBitWidth() && "Bit widths must be the same");
   unsigned FullWidth = C1.getBitWidth() * 2;
   APInt C1Ext = C1.sext(FullWidth);
   APInt C2Ext = C2.sext(FullWidth);

>From 2bd2d18fa54d9eb91efa397fd2f864a7ee9a7994 Mon Sep 17 00:00:00 2001
From: Sh0g0-1758 <shouryagoel10000 at gmail.com>
Date: Sun, 10 Mar 2024 00:09:08 +0530
Subject: [PATCH 09/21] Revert same-bitwidth assertions

---
 llvm/lib/Support/APInt.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index a1b396d184b497..9ce10ada67e9e1 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -3069,7 +3069,6 @@ void llvm::StoreIntToMemory(const APInt &IntVal, uint8_t *Dst,
 
 APInt APIntOps::mulhu(const APInt &C1, const APInt &C2) {
   // Return higher order bits for unsigned (C1 * C2)
-  assert(C1.getBitWidth() == C2.getBitWidth() && "Bit widths must be the same");
   unsigned FullWidth = C1.getBitWidth() * 2;
   APInt C1Ext = C1.zext(FullWidth);
   APInt C2Ext = C2.zext(FullWidth);
@@ -3078,7 +3077,6 @@ APInt APIntOps::mulhu(const APInt &C1, const APInt &C2) {
 
 APInt APIntOps::mulhs(const APInt &C1, const APInt &C2) {
   // Return higher order bits for signed (C1 * C2)
-  assert(C1.getBitWidth() == C2.getBitWidth() && "Bit widths must be the same");
   unsigned FullWidth = C1.getBitWidth() * 2;
   APInt C1Ext = C1.sext(FullWidth);
   APInt C2Ext = C2.sext(FullWidth);

>From 0d3edd5d6d00968ec1a8e6b38917fd5551491e30 Mon Sep 17 00:00:00 2001
From: Sh0g0-1758 <shouryagoel10000 at gmail.com>
Date: Sun, 10 Mar 2024 02:21:01 +0530
Subject: [PATCH 10/21] Added mulhu/mulhs docs and passed them directly in KnownBitsTest

---
 llvm/include/llvm/ADT/APInt.h            | 4 ++++
 llvm/unittests/Support/KnownBitsTest.cpp | 4 ++--
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index 711eb3a9c3fc6d..e104aceb0b32a5 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -2194,9 +2194,13 @@ inline const APInt absdiff(const APInt &A, const APInt &B) {
 }
 
 /// Compute the higher order bits of unsigned multiplication of two APInts
+/// Mathematically, this computes the value: (C1 * C2) >> C2.getBitWidth()
+/// where (C1 * C2) has double the bit width of the original values.
 APInt mulhu(const APInt &C1, const APInt &C2);
 
 /// Compute the higher order bits of signed multiplication of two APInts
+/// Mathematically, this is similar to mulhu but for signed values.
+/// Example: mulhs(-2097152,524288) == -256
 APInt mulhs(const APInt &C1, const APInt &C2);
 
 /// Compute GCD of two unsigned APInt values.
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 65bb228cbc73c3..b4ecfa97989e3b 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -537,13 +537,13 @@ TEST(KnownBitsTest, BinaryExhaustive) {
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::mulhs(Known1, Known2);
       },
-      [](const APInt &N1, const APInt &N2) { return APIntOps::mulhs(N1, N2); },
+      &APIntOps::mulhs,
       checkCorrectnessOnlyBinary);
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::mulhu(Known1, Known2);
       },
-      [](const APInt &N1, const APInt &N2) { return APIntOps::mulhu(N1, N2); },
+      &APIntOps::mulhu,
       checkCorrectnessOnlyBinary);
 }
 

>From 6fb4b4a904520858a84e1ebef33578ded142df02 Mon Sep 17 00:00:00 2001
From: Sh0g0-1758 <shouryagoel10000 at gmail.com>
Date: Sun, 10 Mar 2024 02:23:26 +0530
Subject: [PATCH 11/21] Ran clang-format

---
 llvm/unittests/Support/KnownBitsTest.cpp | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index b4ecfa97989e3b..e44775c1de6aaa 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -537,14 +537,12 @@ TEST(KnownBitsTest, BinaryExhaustive) {
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::mulhs(Known1, Known2);
       },
-      &APIntOps::mulhs,
-      checkCorrectnessOnlyBinary);
+      &APIntOps::mulhs, checkCorrectnessOnlyBinary);
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::mulhu(Known1, Known2);
       },
-      &APIntOps::mulhu,
-      checkCorrectnessOnlyBinary);
+      &APIntOps::mulhu, checkCorrectnessOnlyBinary);
 }
 
 TEST(KnownBitsTest, UnaryExhaustive) {

>From 02d896e1c122b10f61f78e24512f991adeead3d0 Mon Sep 17 00:00:00 2001
From: Sh0g0-1758 <shouryagoel10000 at gmail.com>
Date: Sun, 10 Mar 2024 02:32:33 +0530
Subject: [PATCH 12/21] Use APIntOps::mulhs/mulhu in MLIR constant folders

---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp              | 8 ++------
 mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp | 7 ++-----
 2 files changed, 4 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 0f71c19c23b654..82908988e0fb5a 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -434,9 +434,7 @@ arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
     // Invoke the constant fold helper again to calculate the 'high' result.
     Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
         adaptor.getOperands(), [](const APInt &a, const APInt &b) {
-          unsigned bitWidth = a.getBitWidth();
-          APInt fullProduct = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
-          return fullProduct.extractBits(bitWidth, bitWidth);
+          return APIntOps::mulhs(a,b);
         });
     assert(highAttr && "Unexpected constant-folding failure");
 
@@ -491,9 +489,7 @@ arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
     // Invoke the constant fold helper again to calculate the 'high' result.
     Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
         adaptor.getOperands(), [](const APInt &a, const APInt &b) {
-          unsigned bitWidth = a.getBitWidth();
-          APInt fullProduct = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
-          return fullProduct.extractBits(bitWidth, bitWidth);
+          return APIntOps::mulhu(a,b);
         });
     assert(highAttr && "Unexpected constant-folding failure");
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 4c62289a1e9458..35b248235b4b43 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -250,14 +250,11 @@ struct MulExtendedFold final : OpRewritePattern<MulOp> {
 
     auto highBits = constFoldBinaryOp<IntegerAttr>(
         {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
-          unsigned bitWidth = a.getBitWidth();
-          APInt c;
           if (IsSigned) {
-            c = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
+            return APIntOps::mulhs(a,b);
           } else {
-            c = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
+            return APIntOps::mulhu(a,b);
           }
-          return c.extractBits(bitWidth, bitWidth); // Extract high result
         });
 
     if (!highBits)

>From abce2ad8d5fe8b60ba72458ebb6075519b5c6255 Mon Sep 17 00:00:00 2001
From: Sh0g0-1758 <shouryagoel10000 at gmail.com>
Date: Sun, 10 Mar 2024 02:32:52 +0530
Subject: [PATCH 13/21] Ran clang-format

---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp              | 10 ++++------
 mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp |  4 ++--
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 82908988e0fb5a..271a0a3053a18f 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -433,9 +433,8 @@ arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
           [](const APInt &a, const APInt &b) { return a * b; })) {
     // Invoke the constant fold helper again to calculate the 'high' result.
     Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
-        adaptor.getOperands(), [](const APInt &a, const APInt &b) {
-          return APIntOps::mulhs(a,b);
-        });
+        adaptor.getOperands(),
+        [](const APInt &a, const APInt &b) { return APIntOps::mulhs(a, b); });
     assert(highAttr && "Unexpected constant-folding failure");
 
     results.push_back(lowAttr);
@@ -488,9 +487,8 @@ arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
           [](const APInt &a, const APInt &b) { return a * b; })) {
     // Invoke the constant fold helper again to calculate the 'high' result.
     Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
-        adaptor.getOperands(), [](const APInt &a, const APInt &b) {
-          return APIntOps::mulhu(a,b);
-        });
+        adaptor.getOperands(),
+        [](const APInt &a, const APInt &b) { return APIntOps::mulhu(a, b); });
     assert(highAttr && "Unexpected constant-folding failure");
 
     results.push_back(lowAttr);
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 35b248235b4b43..2c28905dc48aef 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -251,9 +251,9 @@ struct MulExtendedFold final : OpRewritePattern<MulOp> {
     auto highBits = constFoldBinaryOp<IntegerAttr>(
         {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
           if (IsSigned) {
-            return APIntOps::mulhs(a,b);
+            return APIntOps::mulhs(a, b);
           } else {
-            return APIntOps::mulhu(a,b);
+            return APIntOps::mulhu(a, b);
           }
         });
 

>From 6816c6a2cc92f4ddea9d767b12b96e423042cca7 Mon Sep 17 00:00:00 2001
From: Sh0g0-1758 <shouryagoel10000 at gmail.com>
Date: Sun, 10 Mar 2024 02:54:42 +0530
Subject: [PATCH 14/21] Qualify APIntOps with llvm:: in MLIR; restore lambdas in KnownBitsTest

---
 llvm/unittests/Support/KnownBitsTest.cpp            | 6 ++++--
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp              | 4 ++--
 mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp | 4 ++--
 3 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index e44775c1de6aaa..65bb228cbc73c3 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -537,12 +537,14 @@ TEST(KnownBitsTest, BinaryExhaustive) {
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::mulhs(Known1, Known2);
       },
-      &APIntOps::mulhs, checkCorrectnessOnlyBinary);
+      [](const APInt &N1, const APInt &N2) { return APIntOps::mulhs(N1, N2); },
+      checkCorrectnessOnlyBinary);
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::mulhu(Known1, Known2);
       },
-      &APIntOps::mulhu, checkCorrectnessOnlyBinary);
+      [](const APInt &N1, const APInt &N2) { return APIntOps::mulhu(N1, N2); },
+      checkCorrectnessOnlyBinary);
 }
 
 TEST(KnownBitsTest, UnaryExhaustive) {
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 271a0a3053a18f..5c52511c50f3f5 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -434,7 +434,7 @@ arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
     // Invoke the constant fold helper again to calculate the 'high' result.
     Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
         adaptor.getOperands(),
-        [](const APInt &a, const APInt &b) { return APIntOps::mulhs(a, b); });
+        [](const APInt &a, const APInt &b) { return llvm::APIntOps::mulhs(a, b); });
     assert(highAttr && "Unexpected constant-folding failure");
 
     results.push_back(lowAttr);
@@ -488,7 +488,7 @@ arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
     // Invoke the constant fold helper again to calculate the 'high' result.
     Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
         adaptor.getOperands(),
-        [](const APInt &a, const APInt &b) { return APIntOps::mulhu(a, b); });
+        [](const APInt &a, const APInt &b) { return llvm::APIntOps::mulhu(a, b); });
     assert(highAttr && "Unexpected constant-folding failure");
 
     results.push_back(lowAttr);
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 2c28905dc48aef..eb1e97e7ecc908 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -251,9 +251,9 @@ struct MulExtendedFold final : OpRewritePattern<MulOp> {
     auto highBits = constFoldBinaryOp<IntegerAttr>(
         {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
           if (IsSigned) {
-            return APIntOps::mulhs(a, b);
+            return llvm::APIntOps::mulhs(a, b);
           } else {
-            return APIntOps::mulhu(a, b);
+            return llvm::APIntOps::mulhu(a, b);
           }
         });
 

>From 4f0a94a5209d1c010d2f633590362ee1c377b687 Mon Sep 17 00:00:00 2001
From: Sh0g0-1758 <shouryagoel10000 at gmail.com>
Date: Sun, 10 Mar 2024 02:55:02 +0530
Subject: [PATCH 15/21] Ran clang-format

---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 5c52511c50f3f5..c705051f0f440e 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -433,8 +433,9 @@ arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
           [](const APInt &a, const APInt &b) { return a * b; })) {
     // Invoke the constant fold helper again to calculate the 'high' result.
     Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
-        adaptor.getOperands(),
-        [](const APInt &a, const APInt &b) { return llvm::APIntOps::mulhs(a, b); });
+        adaptor.getOperands(), [](const APInt &a, const APInt &b) {
+          return llvm::APIntOps::mulhs(a, b);
+        });
     assert(highAttr && "Unexpected constant-folding failure");
 
     results.push_back(lowAttr);
@@ -487,8 +488,9 @@ arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
           [](const APInt &a, const APInt &b) { return a * b; })) {
     // Invoke the constant fold helper again to calculate the 'high' result.
     Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
-        adaptor.getOperands(),
-        [](const APInt &a, const APInt &b) { return llvm::APIntOps::mulhu(a, b); });
+        adaptor.getOperands(), [](const APInt &a, const APInt &b) {
+          return llvm::APIntOps::mulhu(a, b);
+        });
     assert(highAttr && "Unexpected constant-folding failure");
 
     results.push_back(lowAttr);

>From 6e2450514c5b1166ce3dbe83ba7447c17e7dc744 Mon Sep 17 00:00:00 2001
From: Shourya Goel <shouryagoel10000 at gmail.com>
Date: Sun, 10 Mar 2024 03:00:58 +0530
Subject: [PATCH 16/21] Update llvm/include/llvm/ADT/APInt.h

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
 llvm/include/llvm/ADT/APInt.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index e104aceb0b32a5..b60aec92eaa720 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -2194,8 +2194,8 @@ inline const APInt absdiff(const APInt &A, const APInt &B) {
 }
 
 /// Compute the higher order bits of unsigned multiplication of two APInts
-/// Mathematically, this computes the value: (C1 * C2) >> C2.getBitWidth()
-/// where (C1 * C2) has double the bit width of the original values.
+/// Mathematically, this computes the value: `(C1 * C2) >> C2.getBitWidth()`
+/// where `(C1 * C2)` has double the bit width of the original values.
 APInt mulhu(const APInt &C1, const APInt &C2);
 
 /// Compute the higher order bits of signed multiplication of two APInts

>From f94a27996041edf616ee1f539609e76acbcd83a5 Mon Sep 17 00:00:00 2001
From: Shourya Goel <shouryagoel10000 at gmail.com>
Date: Sun, 10 Mar 2024 03:01:06 +0530
Subject: [PATCH 17/21] Update llvm/include/llvm/ADT/APInt.h

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
 llvm/include/llvm/ADT/APInt.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index b60aec92eaa720..c92fe800af407a 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -2200,7 +2200,7 @@ APInt mulhu(const APInt &C1, const APInt &C2);
 
 /// Compute the higher order bits of signed multiplication of two APInts
 /// Mathematically, this is similar to mulhu but for signed values.
-/// Example: mulhs(-2097152,524288) == -256
+/// Example: `mulhs(-2097152, 524288) == -256`
 APInt mulhs(const APInt &C1, const APInt &C2);
 
 /// Compute GCD of two unsigned APInt values.

>From db910c4076c045750e293f0e5ea5d34bc394b9a0 Mon Sep 17 00:00:00 2001
From: Sh0g0-1758 <shouryagoel10000 at gmail.com>
Date: Sun, 10 Mar 2024 03:08:15 +0530
Subject: [PATCH 18/21] Replaced Lambda Functions

---
 llvm/unittests/Support/KnownBitsTest.cpp | 4 ++--
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp   | 9 +++------
 2 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 65bb228cbc73c3..b4ecfa97989e3b 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -537,13 +537,13 @@ TEST(KnownBitsTest, BinaryExhaustive) {
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::mulhs(Known1, Known2);
       },
-      [](const APInt &N1, const APInt &N2) { return APIntOps::mulhs(N1, N2); },
+      &APIntOps::mulhs,
       checkCorrectnessOnlyBinary);
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::mulhu(Known1, Known2);
       },
-      [](const APInt &N1, const APInt &N2) { return APIntOps::mulhu(N1, N2); },
+      &APIntOps::mulhu,
       checkCorrectnessOnlyBinary);
 }
 
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index c705051f0f440e..fade231d5418dd 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -433,9 +433,8 @@ arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
           [](const APInt &a, const APInt &b) { return a * b; })) {
     // Invoke the constant fold helper again to calculate the 'high' result.
     Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
-        adaptor.getOperands(), [](const APInt &a, const APInt &b) {
-          return llvm::APIntOps::mulhs(a, b);
-        });
+        adaptor.getOperands(), &llvm::APIntOps::mulhs(a, b)
+        );
     assert(highAttr && "Unexpected constant-folding failure");
 
     results.push_back(lowAttr);
@@ -488,9 +487,7 @@ arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
           [](const APInt &a, const APInt &b) { return a * b; })) {
     // Invoke the constant fold helper again to calculate the 'high' result.
     Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
-        adaptor.getOperands(), [](const APInt &a, const APInt &b) {
-          return llvm::APIntOps::mulhu(a, b);
-        });
+        adaptor.getOperands(), &llvm::APIntOps::mulhu(a, b));
     assert(highAttr && "Unexpected constant-folding failure");
 
     results.push_back(lowAttr);

>From 58f0613cf42a69b2aac4239eca792b3a27f71007 Mon Sep 17 00:00:00 2001
From: Sh0g0-1758 <shouryagoel10000 at gmail.com>
Date: Sun, 10 Mar 2024 03:08:38 +0530
Subject: [PATCH 19/21] Ran clang-format

---
 llvm/unittests/Support/KnownBitsTest.cpp | 6 ++----
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp   | 3 +--
 2 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index b4ecfa97989e3b..e44775c1de6aaa 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -537,14 +537,12 @@ TEST(KnownBitsTest, BinaryExhaustive) {
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::mulhs(Known1, Known2);
       },
-      &APIntOps::mulhs,
-      checkCorrectnessOnlyBinary);
+      &APIntOps::mulhs, checkCorrectnessOnlyBinary);
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::mulhu(Known1, Known2);
       },
-      &APIntOps::mulhu,
-      checkCorrectnessOnlyBinary);
+      &APIntOps::mulhu, checkCorrectnessOnlyBinary);
 }
 
 TEST(KnownBitsTest, UnaryExhaustive) {
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index fade231d5418dd..2a60c844e4ce51 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -433,8 +433,7 @@ arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
           [](const APInt &a, const APInt &b) { return a * b; })) {
     // Invoke the constant fold helper again to calculate the 'high' result.
     Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
-        adaptor.getOperands(), &llvm::APIntOps::mulhs(a, b)
-        );
+        adaptor.getOperands(), &llvm::APIntOps::mulhs(a, b));
     assert(highAttr && "Unexpected constant-folding failure");
 
     results.push_back(lowAttr);

>From 8aee6cd97663d3c166119a9a6425c24d66f32aaa Mon Sep 17 00:00:00 2001
From: Sh0g0-1758 <shouryagoel10000 at gmail.com>
Date: Sun, 10 Mar 2024 03:19:25 +0530
Subject: [PATCH 20/21] Bringing back lambda functions

---
 llvm/include/llvm/ADT/APInt.h                       | 8 ++++----
 llvm/unittests/Support/KnownBitsTest.cpp            | 6 ++++--
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp              | 6 ++++--
 mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp | 4 ++--
 4 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index c92fe800af407a..cbe6c2e91f2d9f 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -2193,14 +2193,14 @@ inline const APInt absdiff(const APInt &A, const APInt &B) {
   return A.uge(B) ? (A - B) : (B - A);
 }
 
-/// Compute the higher order bits of unsigned multiplication of two APInts
+/// Compute the higher order bits of unsigned multiplication of two APInts.
 /// Mathematically, this computes the value: `(C1 * C2) >> C2.getBitWidth()`
 /// where `(C1 * C2)` has double the bit width of the original values.
 APInt mulhu(const APInt &C1, const APInt &C2);
 
-/// Compute the higher order bits of signed multiplication of two APInts
-/// Mathematically, this is similar to mulhu but for signed values.
-/// Example: `mulhs(-2097152, 524288) == -256`
+/// Compute the higher order bits of signed multiplication of two APInts.
+/// Mathematically, this is `(C1 * C2) >> C2.getBitWidth()` while preserving
+/// the signed bit. Example: `mulhs(-2097152, 524288) == -256`
 APInt mulhs(const APInt &C1, const APInt &C2);
 
 /// Compute GCD of two unsigned APInt values.
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index e44775c1de6aaa..65bb228cbc73c3 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -537,12 +537,14 @@ TEST(KnownBitsTest, BinaryExhaustive) {
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::mulhs(Known1, Known2);
       },
-      &APIntOps::mulhs, checkCorrectnessOnlyBinary);
+      [](const APInt &N1, const APInt &N2) { return APIntOps::mulhs(N1, N2); },
+      checkCorrectnessOnlyBinary);
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::mulhu(Known1, Known2);
       },
-      &APIntOps::mulhu, checkCorrectnessOnlyBinary);
+      [](const APInt &N1, const APInt &N2) { return APIntOps::mulhu(N1, N2); },
+      checkCorrectnessOnlyBinary);
 }
 
 TEST(KnownBitsTest, UnaryExhaustive) {
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 2a60c844e4ce51..5c52511c50f3f5 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -433,7 +433,8 @@ arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
           [](const APInt &a, const APInt &b) { return a * b; })) {
     // Invoke the constant fold helper again to calculate the 'high' result.
     Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
-        adaptor.getOperands(), &llvm::APIntOps::mulhs(a, b));
+        adaptor.getOperands(),
+        [](const APInt &a, const APInt &b) { return llvm::APIntOps::mulhs(a, b); });
     assert(highAttr && "Unexpected constant-folding failure");
 
     results.push_back(lowAttr);
@@ -486,7 +487,8 @@ arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
           [](const APInt &a, const APInt &b) { return a * b; })) {
     // Invoke the constant fold helper again to calculate the 'high' result.
     Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
-        adaptor.getOperands(), &llvm::APIntOps::mulhu(a, b));
+        adaptor.getOperands(),
+        [](const APInt &a, const APInt &b) { return llvm::APIntOps::mulhu(a, b); });
     assert(highAttr && "Unexpected constant-folding failure");
 
     results.push_back(lowAttr);
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index eb1e97e7ecc908..dd18b6ec7165e2 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -251,9 +251,9 @@ struct MulExtendedFold final : OpRewritePattern<MulOp> {
     auto highBits = constFoldBinaryOp<IntegerAttr>(
         {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
           if (IsSigned) {
-            return llvm::APIntOps::mulhs(a, b);
+            return llvm::APIntOps::mulhs(a,b);
           } else {
-            return llvm::APIntOps::mulhu(a, b);
+            return llvm::APIntOps::mulhu(a,b);
           }
         });
 

>From a60ffdc3ea9b343133c8739844317a959bbed088 Mon Sep 17 00:00:00 2001
From: Sh0g0-1758 <shouryagoel10000 at gmail.com>
Date: Sun, 10 Mar 2024 03:19:58 +0530
Subject: [PATCH 21/21] Ran clang-format

---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp              | 10 ++++++----
 mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp |  4 ++--
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 5c52511c50f3f5..c705051f0f440e 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -433,8 +433,9 @@ arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
           [](const APInt &a, const APInt &b) { return a * b; })) {
     // Invoke the constant fold helper again to calculate the 'high' result.
     Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
-        adaptor.getOperands(),
-        [](const APInt &a, const APInt &b) { return llvm::APIntOps::mulhs(a, b); });
+        adaptor.getOperands(), [](const APInt &a, const APInt &b) {
+          return llvm::APIntOps::mulhs(a, b);
+        });
     assert(highAttr && "Unexpected constant-folding failure");
 
     results.push_back(lowAttr);
@@ -487,8 +488,9 @@ arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
           [](const APInt &a, const APInt &b) { return a * b; })) {
     // Invoke the constant fold helper again to calculate the 'high' result.
     Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
-        adaptor.getOperands(),
-        [](const APInt &a, const APInt &b) { return llvm::APIntOps::mulhu(a, b); });
+        adaptor.getOperands(), [](const APInt &a, const APInt &b) {
+          return llvm::APIntOps::mulhu(a, b);
+        });
     assert(highAttr && "Unexpected constant-folding failure");
 
     results.push_back(lowAttr);
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index dd18b6ec7165e2..eb1e97e7ecc908 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -251,9 +251,9 @@ struct MulExtendedFold final : OpRewritePattern<MulOp> {
     auto highBits = constFoldBinaryOp<IntegerAttr>(
         {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
           if (IsSigned) {
-            return llvm::APIntOps::mulhs(a,b);
+            return llvm::APIntOps::mulhs(a, b);
           } else {
-            return llvm::APIntOps::mulhu(a,b);
+            return llvm::APIntOps::mulhu(a, b);
           }
         });
 



More information about the llvm-commits mailing list