[llvm] [X86] `combinePMULH` - combine `mulhu` + `srl` (PR #132548)

Abhishek Kaushik via llvm-commits llvm-commits at lists.llvm.org
Sat Mar 22 05:39:11 PDT 2025


https://github.com/abhishek-kaushik22 created https://github.com/llvm/llvm-project/pull/132548

Fixes #132166

From df60a87bc176465ef22d8f98d71b540885519a96 Mon Sep 17 00:00:00 2001
From: abhishek-kaushik22 <abhishek.kaushik at intel.com>
Date: Sat, 22 Mar 2025 18:06:39 +0530
Subject: [PATCH] [X86] combinePMULH - combine mulhu + srl

Fixes #132166
---
 llvm/lib/Target/X86/X86ISelLowering.cpp | 18 ++++++++++++++----
 llvm/test/CodeGen/X86/pmulh.ll          | 20 ++++++++++++++++++++
 2 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 02398923ebc90..ec0af8d53b76e 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -54021,7 +54021,7 @@ static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG,
 }
 
 // Try to form a MULHU or MULHS node by looking for
-// (trunc (srl (mul ext, ext), 16))
+// (trunc (srl (mul ext, ext), >= 16))
 // TODO: This is X86 specific because we want to be able to handle wide types
 // before type legalization. But we can only do it if the vector will be
 // legalized via widening/splitting. Type legalization can't handle promotion
@@ -54046,10 +54046,16 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
 
   // First instruction should be a right shift by 16 of a multiply.
   SDValue LHS, RHS;
+  APInt ShiftAmt;
   if (!sd_match(Src,
-                m_Srl(m_Mul(m_Value(LHS), m_Value(RHS)), m_SpecificInt(16))))
+                m_Srl(m_Mul(m_Value(LHS), m_Value(RHS)), m_ConstInt(ShiftAmt))))
+    return SDValue();
+
+  if (ShiftAmt.ult(16))
     return SDValue();
 
+  APInt AdditionalShift = (ShiftAmt - 16).trunc(16);
+
   // Count leading sign/zero bits on both inputs - if there are enough then
   // truncation back to vXi16 will be cheap - either as a pack/shuffle
   // sequence or using AVX512 truncations. If the inputs are sext/zext then the
@@ -54087,7 +54093,9 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
                                 InVT.getSizeInBits() / 16);
     SDValue Res = DAG.getNode(ISD::MULHU, DL, BCVT, DAG.getBitcast(BCVT, LHS),
                               DAG.getBitcast(BCVT, RHS));
-    return DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getBitcast(InVT, Res));
+    Res = DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getBitcast(InVT, Res));
+    return DAG.getNode(ISD::SRL, DL, VT, Res,
+                       DAG.getConstant(AdditionalShift, DL, VT));
   }
 
   // Truncate back to source type.
@@ -54095,7 +54103,9 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
   RHS = DAG.getNode(ISD::TRUNCATE, DL, VT, RHS);
 
   unsigned Opc = IsSigned ? ISD::MULHS : ISD::MULHU;
-  return DAG.getNode(Opc, DL, VT, LHS, RHS);
+  SDValue Res = DAG.getNode(Opc, DL, VT, LHS, RHS);
+  return DAG.getNode(ISD::SRL, DL, VT, Res,
+                     DAG.getConstant(AdditionalShift, DL, VT));
 }
 
 // Attempt to match PMADDUBSW, which multiplies corresponding unsigned bytes
diff --git a/llvm/test/CodeGen/X86/pmulh.ll b/llvm/test/CodeGen/X86/pmulh.ll
index 300da68d9a3b3..8ecc3c1575367 100644
--- a/llvm/test/CodeGen/X86/pmulh.ll
+++ b/llvm/test/CodeGen/X86/pmulh.ll
@@ -2166,3 +2166,23 @@ define <8 x i16> @sse2_pmulhu_w_const(<8 x i16> %a0, <8 x i16> %a1) {
 }
 declare <8 x i16> @llvm.x86.sse2.pmulhu.w(<8 x i16>, <8 x i16>)
 
+define <8 x i16> @mul_and_shift17(<8 x i16> %a, <8 x i16> %b) {
+; SSE-LABEL: mul_and_shift17:
+; SSE:       # %bb.0:
+; SSE-NEXT:    pmulhuw %xmm1, %xmm0
+; SSE-NEXT:    psrlw $1, %xmm0
+; SSE-NEXT:    retq
+;
+; AVX-LABEL: mul_and_shift17:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpmulhuw %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    vpsrlw $1, %xmm0, %xmm0
+; AVX-NEXT:    retq
+  %a.ext = zext <8 x i16> %a to <8 x i32>
+  %b.ext = zext <8 x i16> %b to <8 x i32>
+  %mul = mul <8 x i32> %a.ext, %b.ext
+  %shift = lshr <8 x i32> %mul, splat(i32 17)
+  %trunc = trunc <8 x i32> %shift to <8 x i16>
+  ret <8 x i16> %trunc
+}
+



More information about the llvm-commits mailing list