[llvm] [mlir] [ADT] Add signed and unsigned mulh to APInt (PR #84719)

Atousa Duprat via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 14 22:05:46 PDT 2024


https://github.com/Atousa updated https://github.com/llvm/llvm-project/pull/84719

>From 362caa1cb7c8bb290d0e1062290533cdd30d17d8 Mon Sep 17 00:00:00 2001
From: Atousa Duprat <atousa.p at gmail.com>
Date: Thu, 7 Mar 2024 19:22:51 -0800
Subject: [PATCH 1/2] [ADT] Add implementations for avgFloor and avgCeil to
 APInt

Supports both signed and unsigned expansions.
SelectionDAG now calls the APInt implementation of these functions.
---
 llvm/include/llvm/ADT/APInt.h                 | 12 +++
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 32 ++-----
 llvm/lib/Support/APInt.cpp                    | 36 ++++++++
 llvm/unittests/ADT/APIntTest.cpp              | 86 +++++++++++++++++++
 4 files changed, 142 insertions(+), 24 deletions(-)

diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index 1fc3c7b2236a17..3a58fff3be1d4c 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -2193,6 +2193,18 @@ inline const APInt absdiff(const APInt &A, const APInt &B) {
   return A.uge(B) ? (A - B) : (B - A);
 }
 
+/// Compute the floor of the signed average of C1 and C2
+APInt avgFloorS(const APInt &C1, const APInt &C2);
+
+/// Compute the floor of the unsigned average of C1 and C2
+APInt avgFloorU(const APInt &C1, const APInt &C2);
+
+/// Compute the ceil of the signed average of C1 and C2
+APInt avgCeilS(const APInt &C1, const APInt &C2);
+
+/// Compute the ceil of the unsigned average of C1 and C2
+APInt avgCeilU(const APInt &C1, const APInt &C2);
+
 /// Compute GCD of two unsigned APInt values.
 ///
 /// This function returns the greatest common divisor of the two APInt values
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index f7ace79e8c51d4..a844a00be16291 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6021,30 +6021,14 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
     APInt C2Ext = C2.zext(FullWidth);
     return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
   }
-  case ISD::AVGFLOORS: {
-    unsigned FullWidth = C1.getBitWidth() + 1;
-    APInt C1Ext = C1.sext(FullWidth);
-    APInt C2Ext = C2.sext(FullWidth);
-    return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1);
-  }
-  case ISD::AVGFLOORU: {
-    unsigned FullWidth = C1.getBitWidth() + 1;
-    APInt C1Ext = C1.zext(FullWidth);
-    APInt C2Ext = C2.zext(FullWidth);
-    return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1);
-  }
-  case ISD::AVGCEILS: {
-    unsigned FullWidth = C1.getBitWidth() + 1;
-    APInt C1Ext = C1.sext(FullWidth);
-    APInt C2Ext = C2.sext(FullWidth);
-    return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1);
-  }
-  case ISD::AVGCEILU: {
-    unsigned FullWidth = C1.getBitWidth() + 1;
-    APInt C1Ext = C1.zext(FullWidth);
-    APInt C2Ext = C2.zext(FullWidth);
-    return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1);
-  }
+  case ISD::AVGFLOORS:
+    return APIntOps::avgFloorS(C1, C2);
+  case ISD::AVGFLOORU:
+    return APIntOps::avgFloorU(C1, C2);
+  case ISD::AVGCEILS:
+    return APIntOps::avgCeilS(C1, C2);
+  case ISD::AVGCEILU:
+    return APIntOps::avgCeilU(C1, C2);
   case ISD::ABDS:
     return APIntOps::smax(C1, C2) - APIntOps::smin(C1, C2);
   case ISD::ABDU:
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index e686b976523302..7053f3b87682f8 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -3094,3 +3094,39 @@ void llvm::LoadIntFromMemory(APInt &IntVal, const uint8_t *Src,
     memcpy(Dst + sizeof(uint64_t) - LoadBytes, Src, LoadBytes);
   }
 }
+
+APInt APIntOps::avgFloorS(const APInt &C1, const APInt &C2) {
+  // Return floor((C1 + C2)/2)
+  assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
+  unsigned FullWidth = C1.getBitWidth() + 1;
+  APInt C1Ext = C1.sext(FullWidth);
+  APInt C2Ext = C2.sext(FullWidth);
+  return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1);
+}
+
+APInt APIntOps::avgFloorU(const APInt &C1, const APInt &C2) {
+  // Return floor((C1 + C2)/2)
+  assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
+  unsigned FullWidth = C1.getBitWidth() + 1;
+  APInt C1Ext = C1.zext(FullWidth);
+  APInt C2Ext = C2.zext(FullWidth);
+  return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1);
+}
+
+APInt APIntOps::avgCeilS(const APInt &C1, const APInt &C2) {
+  // Return ceil((C1 + C2)/2)
+  assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
+  unsigned FullWidth = C1.getBitWidth() + 1;
+  APInt C1Ext = C1.sext(FullWidth);
+  APInt C2Ext = C2.sext(FullWidth);
+  return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1);
+}
+
+APInt APIntOps::avgCeilU(const APInt &C1, const APInt &C2) {
+  // Return ceil((C1 + C2)/2)
+  assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
+  unsigned FullWidth = C1.getBitWidth() + 1;
+  APInt C1Ext = C1.zext(FullWidth);
+  APInt C2Ext = C2.zext(FullWidth);
+  return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1);
+}
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 24324822356bf6..3e49b5876e0b83 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -14,6 +14,7 @@
 #include "llvm/Support/Alignment.h"
 #include "gtest/gtest.h"
 #include <array>
+#include <climits>
 #include <optional>
 
 using namespace llvm;
@@ -2877,6 +2878,91 @@ TEST(APIntTest, RoundingSDiv) {
   }
 }
 
+TEST(APIntTest, Average) {
+  APInt A0(32, 0);
+  APInt A2(32, 2);
+  APInt A100(32, 100);
+  APInt A101(32, 101);
+  APInt A200(32, 200, false);
+  APInt ApUMax(32, UINT_MAX, false);
+
+  EXPECT_EQ(APInt(32, 150), APIntOps::avgFloorU(A100, A200));
+  EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A200, A2, APInt::Rounding::DOWN),
+            APIntOps::avgFloorU(A100, A200));
+  EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A200, A2, APInt::Rounding::UP),
+            APIntOps::avgCeilU(A100, A200));
+  EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A101, A2, APInt::Rounding::DOWN),
+            APIntOps::avgFloorU(A100, A101));
+  EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A101, A2, APInt::Rounding::UP),
+            APIntOps::avgCeilU(A100, A101));
+  EXPECT_EQ(A0, APIntOps::avgFloorU(A0, A0));
+  EXPECT_EQ(A0, APIntOps::avgCeilU(A0, A0));
+  EXPECT_EQ(ApUMax, APIntOps::avgFloorU(ApUMax, ApUMax));
+  EXPECT_EQ(ApUMax, APIntOps::avgCeilU(ApUMax, ApUMax));
+  EXPECT_EQ(APIntOps::RoundingUDiv(ApUMax, A2, APInt::Rounding::DOWN),
+            APIntOps::avgFloorU(A0, ApUMax));
+  EXPECT_EQ(APIntOps::RoundingUDiv(ApUMax, A2, APInt::Rounding::UP),
+            APIntOps::avgCeilU(A0, ApUMax));
+
+  APInt Ap100(32, +100);
+  APInt Ap101(32, +101);
+  APInt Ap200(32, +200);
+  APInt Am1(32, -1);
+  APInt Am100(32, -100);
+  APInt Am101(32, -101);
+  APInt Am200(32, -200);
+  APInt AmSMin(32, INT_MIN);
+  APInt ApSMax(32, INT_MAX);
+
+  EXPECT_EQ(APInt(32, +150), APIntOps::avgFloorS(Ap100, Ap200));
+  EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap200, A2, APInt::Rounding::DOWN),
+            APIntOps::avgFloorS(Ap100, Ap200));
+  EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap200, A2, APInt::Rounding::UP),
+            APIntOps::avgCeilS(Ap100, Ap200));
+
+  EXPECT_EQ(APInt(32, -150), APIntOps::avgFloorS(Am100, Am200));
+  EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am200, A2, APInt::Rounding::DOWN),
+            APIntOps::avgFloorS(Am100, Am200));
+  EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am200, A2, APInt::Rounding::UP),
+            APIntOps::avgCeilS(Am100, Am200));
+
+  EXPECT_EQ(APInt(32, +100), APIntOps::avgFloorS(Ap100, Ap101));
+  EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap101, A2, APInt::Rounding::DOWN),
+            APIntOps::avgFloorS(Ap100, Ap101));
+  EXPECT_EQ(APInt(32, +101), APIntOps::avgCeilS(Ap100, Ap101));
+  EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap101, A2, APInt::Rounding::UP),
+            APIntOps::avgCeilS(Ap100, Ap101));
+
+  EXPECT_EQ(APInt(32, -101), APIntOps::avgFloorS(Am100, Am101));
+  EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am101, A2, APInt::Rounding::DOWN),
+            APIntOps::avgFloorS(Am100, Am101));
+  EXPECT_EQ(APInt(32, -100), APIntOps::avgCeilS(Am100, Am101));
+  EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am101, A2, APInt::Rounding::UP),
+            APIntOps::avgCeilS(Am100, Am101));
+
+  EXPECT_EQ(AmSMin, APIntOps::avgFloorS(AmSMin, AmSMin));
+  EXPECT_EQ(AmSMin, APIntOps::avgCeilS(AmSMin, AmSMin));
+
+  EXPECT_EQ(APIntOps::RoundingSDiv(AmSMin, A2, APInt::Rounding::DOWN),
+            APIntOps::avgFloorS(A0, AmSMin));
+  EXPECT_EQ(APIntOps::RoundingSDiv(AmSMin, A2, APInt::Rounding::UP),
+            APIntOps::avgCeilS(A0, AmSMin));
+
+  EXPECT_EQ(A0, APIntOps::avgFloorS(A0, A0));
+  EXPECT_EQ(A0, APIntOps::avgCeilS(A0, A0));
+
+  EXPECT_EQ(Am1, APIntOps::avgFloorS(AmSMin, ApSMax));
+  EXPECT_EQ(A0, APIntOps::avgCeilS(AmSMin, ApSMax));
+
+  EXPECT_EQ(APIntOps::RoundingSDiv(ApSMax, A2, APInt::Rounding::DOWN),
+            APIntOps::avgFloorS(A0, ApSMax));
+  EXPECT_EQ(APIntOps::RoundingSDiv(ApSMax, A2, APInt::Rounding::UP),
+            APIntOps::avgCeilS(A0, ApSMax));
+
+  EXPECT_EQ(ApSMax, APIntOps::avgFloorS(ApSMax, ApSMax));
+  EXPECT_EQ(ApSMax, APIntOps::avgCeilS(ApSMax, ApSMax));
+}
+
 TEST(APIntTest, umul_ov) {
   const std::pair<uint64_t, uint64_t> Overflows[] = {
       {0x8000000000000000, 2},

>From 3a2283aaf18d748a88d49f26ec7974ba02c37a7b Mon Sep 17 00:00:00 2001
From: Atousa Duprat <atousa.p at gmail.com>
Date: Sun, 10 Mar 2024 21:53:04 -0700
Subject: [PATCH 2/2] [ADT] Add signed and unsigned mulh to APInt

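Add APIntOps::mulhs and APIntOps::mulhu, which return the high N bits of
the (2*N)-bit product of two sign- or zero-extended N-bit operands.
SelectionDAG constant folding, KnownBitsTest, and the MLIR Arith and SPIR-V
folders now call these helpers instead of open-coding the expansion.

This addresses issue #84207.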
---
 llvm/include/llvm/ADT/APInt.h                 |  8 +++
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 23 ++------
 llvm/lib/Support/APInt.cpp                    | 17 ++++++
 llvm/unittests/ADT/APIntTest.cpp              | 52 +++++++++++++++++++
 llvm/unittests/Support/KnownBitsTest.cpp      | 10 +---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        |  8 +--
 .../SPIRV/IR/SPIRVCanonicalization.cpp        |  7 +--
 7 files changed, 87 insertions(+), 38 deletions(-)

diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index 3a58fff3be1d4c..5f51b1e0d6ece2 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -2205,6 +2205,14 @@ APInt avgCeilS(const APInt &C1, const APInt &C2);
 /// Compute the ceil of the unsigned average of C1 and C2
 APInt avgCeilU(const APInt &C1, const APInt &C2);
 
+/// Performs (2*N)-bit multiplication on sign-extended operands.
+/// Returns the high N bits of the multiplication result.
+APInt mulhs(const APInt &C1, const APInt &C2);
+
+/// Performs (2*N)-bit multiplication on zero-extended operands.
+/// Returns the high N bits of the multiplication result.
+APInt mulhu(const APInt &C1, const APInt &C2);
+
 /// Compute GCD of two unsigned APInt values.
 ///
 /// This function returns the greatest common divisor of the two APInt values
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index a844a00be16291..43c3765b8ac371 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6009,18 +6009,6 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
     if (!C2.getBoolValue())
       break;
     return C1.srem(C2);
-  case ISD::MULHS: {
-    unsigned FullWidth = C1.getBitWidth() * 2;
-    APInt C1Ext = C1.sext(FullWidth);
-    APInt C2Ext = C2.sext(FullWidth);
-    return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
-  }
-  case ISD::MULHU: {
-    unsigned FullWidth = C1.getBitWidth() * 2;
-    APInt C1Ext = C1.zext(FullWidth);
-    APInt C2Ext = C2.zext(FullWidth);
-    return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
-  }
   case ISD::AVGFLOORS:
     return APIntOps::avgFloorS(C1, C2);
   case ISD::AVGFLOORU:
@@ -6029,13 +6017,17 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
     return APIntOps::avgCeilS(C1, C2);
   case ISD::AVGCEILU:
     return APIntOps::avgCeilU(C1, C2);
+  case ISD::MULHS:
+    return APIntOps::mulhs(C1, C2);
+  case ISD::MULHU:
+    return APIntOps::mulhu(C1, C2);
   case ISD::ABDS:
     return APIntOps::smax(C1, C2) - APIntOps::smin(C1, C2);
   case ISD::ABDU:
     return APIntOps::umax(C1, C2) - APIntOps::umin(C1, C2);
   }
   return std::nullopt;
 }
 
 // Handle constant folding with UNDEF.
 // TODO: Handle more cases.
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index 7053f3b87682f8..3678c08ce71d18 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -3130,3 +3130,20 @@ APInt APIntOps::avgCeilU(const APInt &C1, const APInt &C2) {
   APInt C2Ext = C2.zext(FullWidth);
   return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1);
 }
+
+APInt APIntOps::mulhs(const APInt &C1, const APInt &C2) {
+  assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
+  unsigned FullWidth = C1.getBitWidth() * 2;
+  APInt C1Ext = C1.sext(FullWidth);
+  APInt C2Ext = C2.sext(FullWidth);
+  return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
+}
+
+APInt APIntOps::mulhu(const APInt &C1, const APInt &C2) {
+  assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
+  unsigned FullWidth = C1.getBitWidth() * 2;
+  APInt C1Ext = C1.zext(FullWidth);
+  APInt C2Ext = C2.zext(FullWidth);
+  return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
+}
+
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 3e49b5876e0b83..8d0a225812c681 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -2806,6 +2806,58 @@ TEST(APIntTest, multiply) {
   EXPECT_EQ(64U, i96.countr_zero());
 }
 
+TEST(APIntOpsTest, Mulh) {
+
+  // Unsigned
+
+  // 32 bits
+  APInt i32a(32, 0x0001'E235);
+  APInt i32b(32, 0xF623'55AD);
+  EXPECT_EQ(0x0001'CFA1, APIntOps::mulhu(i32a, i32b));
+
+  // 64 bits
+  APInt i64a(64, 0x1234'5678'90AB'CDEF);
+  APInt i64b(64, 0xFEDC'BA09'8765'4321);
+  EXPECT_EQ(0x121F'A000'A372'3A57, APIntOps::mulhu(i64a, i64b));
+
+  // 128 bits
+  APInt i128a(128, "1234567890ABCDEF1234567890ABCDEF", 16);
+  APInt i128b(128, "FEDCBA0987654321FEDCBA0987654321", 16);
+  APInt i128Res = APIntOps::mulhu(i128a, i128b);
+  EXPECT_EQ(APInt(128, "121FA000A3723A57E68984312C3A8D7E", 16), i128Res);
+
+  // Signed
+
+  // 32 bits
+  APInt i32c(32, 0x1234'5678); // +ve
+  APInt i32d(32, 0x10AB'CDEF); // +ve
+  APInt i32e(32, 0xFEDC'BA09); // -ve
+
+  EXPECT_EQ(0x012F'7D02, APIntOps::mulhs(i32c, i32d));
+  EXPECT_EQ(0xFFEB'4988, APIntOps::mulhs(i32c, i32e));
+  EXPECT_EQ(0x0001'4B68, APIntOps::mulhs(i32e, i32e));
+
+  // 64 bits
+  APInt i64c(64, 0x1234'5678'90AB'CDEF); // +ve
+  APInt i64d(64, 0x1234'5678'90FE'DCBA); // +ve
+  APInt i64e(64, 0xFEDC'BA09'8765'4321); // -ve
+
+  EXPECT_EQ(0x014B'66DC'328E'10C1, APIntOps::mulhs(i64c, i64d));
+  EXPECT_EQ(0xFFEB'4988'12C6'6C68, APIntOps::mulhs(i64c, i64e));
+  EXPECT_EQ(0x0001'4B68'2174'FA18, APIntOps::mulhs(i64e, i64e));
+
+  // 128 bits
+  APInt i128c(128, "1234567890ABCDEF1234567890ABCDEF", 16); // +ve
+  APInt i128d(128, "1234567890FEDCBA1234567890FEDCBA", 16); // +ve
+  APInt i128e(128, "FEDCBA0987654321FEDCBA0987654321", 16); // -ve
+
+  i128Res = APIntOps::mulhs(i128c, i128d);
+  EXPECT_EQ(APInt(128, "14B66DC328E10C1FE303DF9EA0B2529", 16), i128Res);
+
+  i128Res = APIntOps::mulhs(i128c, i128e);
+  EXPECT_EQ(APInt(128, "FFEB498812C66C68D4552DB89B8EBF8F", 16), i128Res);
+}
+
 TEST(APIntTest, RoundingUDiv) {
   for (uint64_t Ai = 1; Ai <= 255; Ai++) {
     APInt A(8, Ai);
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 658f3796721c4e..65bb228cbc73c3 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -537,19 +537,13 @@ TEST(KnownBitsTest, BinaryExhaustive) {
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::mulhs(Known1, Known2);
       },
-      [](const APInt &N1, const APInt &N2) {
-        unsigned Bits = N1.getBitWidth();
-        return (N1.sext(2 * Bits) * N2.sext(2 * Bits)).extractBits(Bits, Bits);
-      },
+      [](const APInt &N1, const APInt &N2) { return APIntOps::mulhs(N1, N2); },
       checkCorrectnessOnlyBinary);
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::mulhu(Known1, Known2);
       },
-      [](const APInt &N1, const APInt &N2) {
-        unsigned Bits = N1.getBitWidth();
-        return (N1.zext(2 * Bits) * N2.zext(2 * Bits)).extractBits(Bits, Bits);
-      },
+      [](const APInt &N1, const APInt &N2) { return APIntOps::mulhu(N1, N2); },
       checkCorrectnessOnlyBinary);
 }
 
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 0f71c19c23b654..c705051f0f440e 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -434,9 +434,7 @@ arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
     // Invoke the constant fold helper again to calculate the 'high' result.
     Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
         adaptor.getOperands(), [](const APInt &a, const APInt &b) {
-          unsigned bitWidth = a.getBitWidth();
-          APInt fullProduct = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
-          return fullProduct.extractBits(bitWidth, bitWidth);
+          return llvm::APIntOps::mulhs(a, b);
         });
     assert(highAttr && "Unexpected constant-folding failure");
 
@@ -491,9 +489,7 @@ arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
     // Invoke the constant fold helper again to calculate the 'high' result.
     Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
         adaptor.getOperands(), [](const APInt &a, const APInt &b) {
-          unsigned bitWidth = a.getBitWidth();
-          APInt fullProduct = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
-          return fullProduct.extractBits(bitWidth, bitWidth);
+          return llvm::APIntOps::mulhu(a, b);
         });
     assert(highAttr && "Unexpected constant-folding failure");
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 4c62289a1e9458..eb1e97e7ecc908 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -250,14 +250,11 @@ struct MulExtendedFold final : OpRewritePattern<MulOp> {
 
     auto highBits = constFoldBinaryOp<IntegerAttr>(
         {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
-          unsigned bitWidth = a.getBitWidth();
-          APInt c;
           if (IsSigned) {
-            c = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
+            return llvm::APIntOps::mulhs(a, b);
           } else {
-            c = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
+            return llvm::APIntOps::mulhu(a, b);
           }
-          return c.extractBits(bitWidth, bitWidth); // Extract high result
         });
 
     if (!highBits)



More information about the llvm-commits mailing list