[llvm] [X86] `combinePMULH` - combine `mulhu` + `srl` (PR #132548)
via llvm-commits
llvm-commits at lists.llvm.org
Sat Mar 22 05:39:49 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-x86
Author: Abhishek Kaushik (abhishek-kaushik22)
<details>
<summary>Changes</summary>
Fixes #132166
---
Full diff: https://github.com/llvm/llvm-project/pull/132548.diff
2 Files Affected:
- (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+14-4)
- (modified) llvm/test/CodeGen/X86/pmulh.ll (+20)
``````````diff
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 02398923ebc90..ec0af8d53b76e 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -54021,7 +54021,7 @@ static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG,
}
// Try to form a MULHU or MULHS node by looking for
-// (trunc (srl (mul ext, ext), 16))
+// (trunc (srl (mul ext, ext), >= 16))
// TODO: This is X86 specific because we want to be able to handle wide types
// before type legalization. But we can only do it if the vector will be
// legalized via widening/splitting. Type legalization can't handle promotion
@@ -54046,10 +54046,16 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
// First instruction should be a right shift by 16 of a multiply.
SDValue LHS, RHS;
+ APInt ShiftAmt;
if (!sd_match(Src,
- m_Srl(m_Mul(m_Value(LHS), m_Value(RHS)), m_SpecificInt(16))))
+ m_Srl(m_Mul(m_Value(LHS), m_Value(RHS)), m_ConstInt(ShiftAmt))))
+ return SDValue();
+
+ if (ShiftAmt.ult(16))
return SDValue();
+ APInt AdditionalShift = (ShiftAmt - 16).trunc(16);
+
// Count leading sign/zero bits on both inputs - if there are enough then
// truncation back to vXi16 will be cheap - either as a pack/shuffle
// sequence or using AVX512 truncations. If the inputs are sext/zext then the
@@ -54087,7 +54093,9 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
InVT.getSizeInBits() / 16);
SDValue Res = DAG.getNode(ISD::MULHU, DL, BCVT, DAG.getBitcast(BCVT, LHS),
DAG.getBitcast(BCVT, RHS));
- return DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getBitcast(InVT, Res));
+ Res = DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getBitcast(InVT, Res));
+ return DAG.getNode(ISD::SRL, DL, VT, Res,
+ DAG.getConstant(AdditionalShift, DL, VT));
}
// Truncate back to source type.
@@ -54095,7 +54103,9 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
RHS = DAG.getNode(ISD::TRUNCATE, DL, VT, RHS);
unsigned Opc = IsSigned ? ISD::MULHS : ISD::MULHU;
- return DAG.getNode(Opc, DL, VT, LHS, RHS);
+ SDValue Res = DAG.getNode(Opc, DL, VT, LHS, RHS);
+ return DAG.getNode(ISD::SRL, DL, VT, Res,
+ DAG.getConstant(AdditionalShift, DL, VT));
}
// Attempt to match PMADDUBSW, which multiplies corresponding unsigned bytes
diff --git a/llvm/test/CodeGen/X86/pmulh.ll b/llvm/test/CodeGen/X86/pmulh.ll
index 300da68d9a3b3..8ecc3c1575367 100644
--- a/llvm/test/CodeGen/X86/pmulh.ll
+++ b/llvm/test/CodeGen/X86/pmulh.ll
@@ -2166,3 +2166,23 @@ define <8 x i16> @sse2_pmulhu_w_const(<8 x i16> %a0, <8 x i16> %a1) {
}
declare <8 x i16> @llvm.x86.sse2.pmulhu.w(<8 x i16>, <8 x i16>)
+define <8 x i16> @mul_and_shift17(<8 x i16> %a, <8 x i16> %b) {
+; SSE-LABEL: mul_and_shift17:
+; SSE: # %bb.0:
+; SSE-NEXT: pmulhuw %xmm1, %xmm0
+; SSE-NEXT: psrlw $1, %xmm0
+; SSE-NEXT: retq
+;
+; AVX-LABEL: mul_and_shift17:
+; AVX: # %bb.0:
+; AVX-NEXT: vpmulhuw %xmm1, %xmm0, %xmm0
+; AVX-NEXT: vpsrlw $1, %xmm0, %xmm0
+; AVX-NEXT: retq
+ %a.ext = zext <8 x i16> %a to <8 x i32>
+ %b.ext = zext <8 x i16> %b to <8 x i32>
+ %mul = mul <8 x i32> %a.ext, %b.ext
+ %shift = lshr <8 x i32> %mul, splat(i32 17)
+ %trunc = trunc <8 x i32> %shift to <8 x i16>
+ ret <8 x i16> %trunc
+}
+
``````````
</details>
https://github.com/llvm/llvm-project/pull/132548
More information about the llvm-commits mailing list