[llvm] [X86][AVX] Fix handling of out-of-bounds shift amounts in AVX2 vector shift nodes (PR #84426)

via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 12 00:59:37 PDT 2024


https://github.com/SahilPatidar updated https://github.com/llvm/llvm-project/pull/84426

>From 87562698ac6af2cce161dcd37623a10c315d8085 Mon Sep 17 00:00:00 2001
From: SahilPatidar <patidarsahil2001 at gmail.com>
Date: Thu, 7 Mar 2024 21:26:31 +0530
Subject: [PATCH] [X86][AVX] Fix handling of out-of-bounds shift amounts in
 AVX2 vector shift nodes #83840

---
 llvm/lib/Target/X86/X86ISelLowering.cpp |  11 +
 llvm/test/CodeGen/X86/combine-sra.ll    | 275 ++++++++++++++++++++++++
 2 files changed, 286 insertions(+)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index a74901958ac056..bde6b202538477 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -47334,6 +47334,17 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
   if (SDValue V = combineShiftToPMULH(N, DAG, Subtarget))
     return V;
 
+  APInt ShiftAmt;
+  SDNode *UMinNode = N1.getNode();
+  if (supportedVectorVarShift(VT, Subtarget, ISD::SRA) &&
+      UMinNode->getOpcode() == ISD::UMIN &&
+      ISD::isConstantSplatVector(UMinNode->getOperand(1).getNode(), ShiftAmt) &&
+      ShiftAmt == VT.getScalarSizeInBits() - 1) {
+    SDValue ShrAmtVal = UMinNode->getOperand(0);
+    SDLoc DL(N);
+    return DAG.getNode(X86ISD::VSRAV, DL, N->getVTList(), N0, ShrAmtVal);
+  }
+
   // fold (ashr (shl, a, [56,48,32,24,16]), SarConst)
   // into (shl, (sext (a), [56,48,32,24,16] - SarConst)) or
   // into (lshr, (sext (a), SarConst - [56,48,32,24,16]))
diff --git a/llvm/test/CodeGen/X86/combine-sra.ll b/llvm/test/CodeGen/X86/combine-sra.ll
index 0675ced68d7a7a..57e22d67049476 100644
--- a/llvm/test/CodeGen/X86/combine-sra.ll
+++ b/llvm/test/CodeGen/X86/combine-sra.ll
@@ -521,3 +521,278 @@ define <4 x i32> @combine_vec_ashr_positive_splat(<4 x i32> %x, <4 x i32> %y) {
   %2 = ashr <4 x i32> %1, <i32 10, i32 10, i32 10, i32 10>
   ret <4 x i32> %2
 }
+
+define <8 x i16> @combine_vec16_ashr_out_of_bound(<8 x i16> %x, <8 x i16> %y) {
+; SSE2-LABEL: combine_vec16_ashr_out_of_bound:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movdqa %xmm1, %xmm2
+; SSE2-NEXT:    psubusw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm2
+; SSE2-NEXT:    psubw %xmm2, %xmm1
+; SSE2-NEXT:    psllw $12, %xmm1
+; SSE2-NEXT:    movdqa %xmm1, %xmm2
+; SSE2-NEXT:    psraw $15, %xmm2
+; SSE2-NEXT:    movdqa %xmm2, %xmm3
+; SSE2-NEXT:    pandn %xmm0, %xmm3
+; SSE2-NEXT:    psraw $8, %xmm0
+; SSE2-NEXT:    pand %xmm2, %xmm0
+; SSE2-NEXT:    por %xmm3, %xmm0
+; SSE2-NEXT:    paddw %xmm1, %xmm1
+; SSE2-NEXT:    movdqa %xmm1, %xmm2
+; SSE2-NEXT:    psraw $15, %xmm2
+; SSE2-NEXT:    movdqa %xmm2, %xmm3
+; SSE2-NEXT:    pandn %xmm0, %xmm3
+; SSE2-NEXT:    psraw $4, %xmm0
+; SSE2-NEXT:    pand %xmm2, %xmm0
+; SSE2-NEXT:    por %xmm3, %xmm0
+; SSE2-NEXT:    paddw %xmm1, %xmm1
+; SSE2-NEXT:    movdqa %xmm1, %xmm2
+; SSE2-NEXT:    psraw $15, %xmm2
+; SSE2-NEXT:    movdqa %xmm2, %xmm3
+; SSE2-NEXT:    pandn %xmm0, %xmm3
+; SSE2-NEXT:    psraw $2, %xmm0
+; SSE2-NEXT:    pand %xmm2, %xmm0
+; SSE2-NEXT:    por %xmm3, %xmm0
+; SSE2-NEXT:    paddw %xmm1, %xmm1
+; SSE2-NEXT:    psraw $15, %xmm1
+; SSE2-NEXT:    movdqa %xmm1, %xmm2
+; SSE2-NEXT:    pandn %xmm0, %xmm2
+; SSE2-NEXT:    psraw $1, %xmm0
+; SSE2-NEXT:    pand %xmm1, %xmm0
+; SSE2-NEXT:    por %xmm2, %xmm0
+; SSE2-NEXT:    retq
+;
+; SSE41-LABEL: combine_vec16_ashr_out_of_bound:
+; SSE41:       # %bb.0:
+; SSE41-NEXT:    movdqa %xmm0, %xmm2
+; SSE41-NEXT:    pminuw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
+; SSE41-NEXT:    movdqa %xmm1, %xmm0
+; SSE41-NEXT:    psllw $12, %xmm0
+; SSE41-NEXT:    psllw $4, %xmm1
+; SSE41-NEXT:    por %xmm1, %xmm0
+; SSE41-NEXT:    movdqa %xmm0, %xmm1
+; SSE41-NEXT:    paddw %xmm0, %xmm1
+; SSE41-NEXT:    movdqa %xmm2, %xmm3
+; SSE41-NEXT:    psraw $8, %xmm3
+; SSE41-NEXT:    pblendvb %xmm0, %xmm3, %xmm2
+; SSE41-NEXT:    movdqa %xmm2, %xmm3
+; SSE41-NEXT:    psraw $4, %xmm3
+; SSE41-NEXT:    movdqa %xmm1, %xmm0
+; SSE41-NEXT:    pblendvb %xmm0, %xmm3, %xmm2
+; SSE41-NEXT:    movdqa %xmm2, %xmm3
+; SSE41-NEXT:    psraw $2, %xmm3
+; SSE41-NEXT:    paddw %xmm1, %xmm1
+; SSE41-NEXT:    movdqa %xmm1, %xmm0
+; SSE41-NEXT:    pblendvb %xmm0, %xmm3, %xmm2
+; SSE41-NEXT:    movdqa %xmm2, %xmm3
+; SSE41-NEXT:    psraw $1, %xmm3
+; SSE41-NEXT:    paddw %xmm1, %xmm1
+; SSE41-NEXT:    movdqa %xmm1, %xmm0
+; SSE41-NEXT:    pblendvb %xmm0, %xmm3, %xmm2
+; SSE41-NEXT:    movdqa %xmm2, %xmm0
+; SSE41-NEXT:    retq
+;
+; AVX2-LABEL: combine_vec16_ashr_out_of_bound:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vpminuw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
+; AVX2-NEXT:    vpmovsxwd %xmm0, %ymm0
+; AVX2-NEXT:    vpmovzxwd {{.*#+}} ymm1 = xmm1[0],zero,xmm1[1],zero,xmm1[2],zero,xmm1[3],zero,xmm1[4],zero,xmm1[5],zero,xmm1[6],zero,xmm1[7],zero
+; AVX2-NEXT:    vpsravd %ymm1, %ymm0, %ymm0
+; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm1
+; AVX2-NEXT:    vpackssdw %xmm1, %xmm0, %xmm0
+; AVX2-NEXT:    vzeroupper
+; AVX2-NEXT:    retq
+;
+; AVX512-LABEL: combine_vec16_ashr_out_of_bound:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpminuw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
+; AVX512-NEXT:    vpsravw %xmm1, %xmm0, %xmm0
+; AVX512-NEXT:    retq
+  %1 = tail call <8 x i16> @llvm.umin.v8i16(<8 x i16> %y, <8 x i16> <i16 31, i16 31, i16 31, i16 31, i16 31, i16 31, i16 31, i16 31>)
+  %2 = ashr <8 x i16> %x, %1
+  ret <8 x i16> %2
+}
+
+define <4 x i32> @combine_vec32_ashr_out_of_bound(<4 x i32> %x, <4 x i32> %y) {
+; SSE2-LABEL: combine_vec32_ashr_out_of_bound:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movdqa {{.*#+}} xmm2 = [2147483648,2147483648,2147483648,2147483648]
+; SSE2-NEXT:    pxor %xmm1, %xmm2
+; SSE2-NEXT:    pcmpgtd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm2
+; SSE2-NEXT:    movdqa %xmm2, %xmm3
+; SSE2-NEXT:    pandn %xmm1, %xmm3
+; SSE2-NEXT:    psrld $27, %xmm2
+; SSE2-NEXT:    por %xmm3, %xmm2
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm1 = xmm2[2,3,3,3,4,5,6,7]
+; SSE2-NEXT:    movdqa %xmm0, %xmm3
+; SSE2-NEXT:    psrad %xmm1, %xmm3
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm4 = xmm2[0,1,1,1,4,5,6,7]
+; SSE2-NEXT:    movdqa %xmm0, %xmm1
+; SSE2-NEXT:    psrad %xmm4, %xmm1
+; SSE2-NEXT:    punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm3[0]
+; SSE2-NEXT:    pshufd {{.*#+}} xmm2 = xmm2[2,3,2,3]
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm3 = xmm2[2,3,3,3,4,5,6,7]
+; SSE2-NEXT:    movdqa %xmm0, %xmm4
+; SSE2-NEXT:    psrad %xmm3, %xmm4
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm2 = xmm2[0,1,1,1,4,5,6,7]
+; SSE2-NEXT:    psrad %xmm2, %xmm0
+; SSE2-NEXT:    punpckhqdq {{.*#+}} xmm0 = xmm0[1],xmm4[1]
+; SSE2-NEXT:    shufps {{.*#+}} xmm1 = xmm1[0,3],xmm0[0,3]
+; SSE2-NEXT:    movaps %xmm1, %xmm0
+; SSE2-NEXT:    retq
+;
+; SSE41-LABEL: combine_vec32_ashr_out_of_bound:
+; SSE41:       # %bb.0:
+; SSE41-NEXT:    pminud {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
+; SSE41-NEXT:    pshuflw {{.*#+}} xmm2 = xmm1[2,3,3,3,4,5,6,7]
+; SSE41-NEXT:    movdqa %xmm0, %xmm3
+; SSE41-NEXT:    psrad %xmm2, %xmm3
+; SSE41-NEXT:    pshufd {{.*#+}} xmm2 = xmm1[2,3,2,3]
+; SSE41-NEXT:    pshuflw {{.*#+}} xmm4 = xmm2[2,3,3,3,4,5,6,7]
+; SSE41-NEXT:    movdqa %xmm0, %xmm5
+; SSE41-NEXT:    psrad %xmm4, %xmm5
+; SSE41-NEXT:    pblendw {{.*#+}} xmm5 = xmm3[0,1,2,3],xmm5[4,5,6,7]
+; SSE41-NEXT:    pshuflw {{.*#+}} xmm1 = xmm1[0,1,1,1,4,5,6,7]
+; SSE41-NEXT:    movdqa %xmm0, %xmm3
+; SSE41-NEXT:    psrad %xmm1, %xmm3
+; SSE41-NEXT:    pshuflw {{.*#+}} xmm1 = xmm2[0,1,1,1,4,5,6,7]
+; SSE41-NEXT:    psrad %xmm1, %xmm0
+; SSE41-NEXT:    pblendw {{.*#+}} xmm0 = xmm3[0,1,2,3],xmm0[4,5,6,7]
+; SSE41-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1],xmm5[2,3],xmm0[4,5],xmm5[6,7]
+; SSE41-NEXT:    retq
+;
+; AVX-LABEL: combine_vec32_ashr_out_of_bound:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpsravd %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    retq
+  %1 = tail call <4 x i32> @llvm.umin.v4i32(<4 x i32> %y, <4 x i32> <i32 31, i32 31, i32 31, i32 31>)
+  %2 = ashr <4 x i32> %x, %1
+  ret <4 x i32> %2
+}
+
+define <4 x i64> @combine_vec64_ashr_out_of_bound(<4 x i64> %x, <4 x i64> %y) {
+; SSE2-LABEL: combine_vec64_ashr_out_of_bound:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movdqa {{.*#+}} xmm5 = [9223372039002259456,9223372039002259456]
+; SSE2-NEXT:    movdqa %xmm3, %xmm4
+; SSE2-NEXT:    pxor %xmm5, %xmm4
+; SSE2-NEXT:    pshufd {{.*#+}} xmm6 = xmm4[0,0,2,2]
+; SSE2-NEXT:    movdqa {{.*#+}} xmm7 = [2147483679,2147483679,2147483679,2147483679]
+; SSE2-NEXT:    movdqa %xmm7, %xmm8
+; SSE2-NEXT:    pcmpgtd %xmm6, %xmm8
+; SSE2-NEXT:    pshufd {{.*#+}} xmm4 = xmm4[1,1,3,3]
+; SSE2-NEXT:    pcmpeqd %xmm5, %xmm4
+; SSE2-NEXT:    pand %xmm8, %xmm4
+; SSE2-NEXT:    movdqa {{.*#+}} xmm6 = [31,31]
+; SSE2-NEXT:    pand %xmm4, %xmm3
+; SSE2-NEXT:    pandn %xmm6, %xmm4
+; SSE2-NEXT:    por %xmm3, %xmm4
+; SSE2-NEXT:    movdqa %xmm2, %xmm3
+; SSE2-NEXT:    pxor %xmm5, %xmm3
+; SSE2-NEXT:    pshufd {{.*#+}} xmm8 = xmm3[0,0,2,2]
+; SSE2-NEXT:    pcmpgtd %xmm8, %xmm7
+; SSE2-NEXT:    pshufd {{.*#+}} xmm3 = xmm3[1,1,3,3]
+; SSE2-NEXT:    pcmpeqd %xmm5, %xmm3
+; SSE2-NEXT:    pand %xmm7, %xmm3
+; SSE2-NEXT:    pand %xmm3, %xmm2
+; SSE2-NEXT:    pandn %xmm6, %xmm3
+; SSE2-NEXT:    por %xmm2, %xmm3
+; SSE2-NEXT:    movdqa {{.*#+}} xmm2 = [9223372036854775808,9223372036854775808]
+; SSE2-NEXT:    movdqa %xmm2, %xmm5
+; SSE2-NEXT:    psrlq %xmm3, %xmm5
+; SSE2-NEXT:    pshufd {{.*#+}} xmm6 = xmm3[2,3,2,3]
+; SSE2-NEXT:    movdqa %xmm2, %xmm7
+; SSE2-NEXT:    psrlq %xmm6, %xmm7
+; SSE2-NEXT:    movsd {{.*#+}} xmm7 = xmm5[0],xmm7[1]
+; SSE2-NEXT:    movdqa %xmm0, %xmm5
+; SSE2-NEXT:    psrlq %xmm3, %xmm5
+; SSE2-NEXT:    psrlq %xmm6, %xmm0
+; SSE2-NEXT:    movsd {{.*#+}} xmm0 = xmm5[0],xmm0[1]
+; SSE2-NEXT:    xorpd %xmm7, %xmm0
+; SSE2-NEXT:    psubq %xmm7, %xmm0
+; SSE2-NEXT:    movdqa %xmm2, %xmm3
+; SSE2-NEXT:    psrlq %xmm4, %xmm3
+; SSE2-NEXT:    pshufd {{.*#+}} xmm5 = xmm4[2,3,2,3]
+; SSE2-NEXT:    psrlq %xmm5, %xmm2
+; SSE2-NEXT:    movsd {{.*#+}} xmm2 = xmm3[0],xmm2[1]
+; SSE2-NEXT:    movdqa %xmm1, %xmm3
+; SSE2-NEXT:    psrlq %xmm4, %xmm3
+; SSE2-NEXT:    psrlq %xmm5, %xmm1
+; SSE2-NEXT:    movsd {{.*#+}} xmm1 = xmm3[0],xmm1[1]
+; SSE2-NEXT:    xorpd %xmm2, %xmm1
+; SSE2-NEXT:    psubq %xmm2, %xmm1
+; SSE2-NEXT:    retq
+;
+; SSE41-LABEL: combine_vec64_ashr_out_of_bound:
+; SSE41:       # %bb.0:
+; SSE41-NEXT:    movdqa %xmm0, %xmm4
+; SSE41-NEXT:    movdqa {{.*#+}} xmm7 = [9223372039002259456,9223372039002259456]
+; SSE41-NEXT:    movdqa %xmm3, %xmm0
+; SSE41-NEXT:    pxor %xmm7, %xmm0
+; SSE41-NEXT:    movdqa {{.*#+}} xmm8 = [9223372039002259487,9223372039002259487]
+; SSE41-NEXT:    movdqa %xmm8, %xmm6
+; SSE41-NEXT:    pcmpeqd %xmm0, %xmm6
+; SSE41-NEXT:    pshufd {{.*#+}} xmm9 = xmm0[0,0,2,2]
+; SSE41-NEXT:    movdqa {{.*#+}} xmm5 = [2147483679,2147483679,2147483679,2147483679]
+; SSE41-NEXT:    movdqa %xmm5, %xmm0
+; SSE41-NEXT:    pcmpgtd %xmm9, %xmm0
+; SSE41-NEXT:    pand %xmm6, %xmm0
+; SSE41-NEXT:    movapd {{.*#+}} xmm9 = [31,31]
+; SSE41-NEXT:    movapd %xmm9, %xmm6
+; SSE41-NEXT:    blendvpd %xmm0, %xmm3, %xmm6
+; SSE41-NEXT:    pxor %xmm2, %xmm7
+; SSE41-NEXT:    pcmpeqd %xmm7, %xmm8
+; SSE41-NEXT:    pshufd {{.*#+}} xmm0 = xmm7[0,0,2,2]
+; SSE41-NEXT:    pcmpgtd %xmm0, %xmm5
+; SSE41-NEXT:    pand %xmm8, %xmm5
+; SSE41-NEXT:    movdqa %xmm5, %xmm0
+; SSE41-NEXT:    blendvpd %xmm0, %xmm2, %xmm9
+; SSE41-NEXT:    movdqa {{.*#+}} xmm0 = [9223372036854775808,9223372036854775808]
+; SSE41-NEXT:    movdqa %xmm0, %xmm2
+; SSE41-NEXT:    psrlq %xmm9, %xmm2
+; SSE41-NEXT:    pshufd {{.*#+}} xmm3 = xmm9[2,3,2,3]
+; SSE41-NEXT:    movdqa %xmm0, %xmm5
+; SSE41-NEXT:    psrlq %xmm3, %xmm5
+; SSE41-NEXT:    pblendw {{.*#+}} xmm5 = xmm2[0,1,2,3],xmm5[4,5,6,7]
+; SSE41-NEXT:    movdqa %xmm4, %xmm2
+; SSE41-NEXT:    psrlq %xmm9, %xmm2
+; SSE41-NEXT:    psrlq %xmm3, %xmm4
+; SSE41-NEXT:    pblendw {{.*#+}} xmm4 = xmm2[0,1,2,3],xmm4[4,5,6,7]
+; SSE41-NEXT:    pxor %xmm5, %xmm4
+; SSE41-NEXT:    psubq %xmm5, %xmm4
+; SSE41-NEXT:    movdqa %xmm0, %xmm2
+; SSE41-NEXT:    psrlq %xmm6, %xmm2
+; SSE41-NEXT:    pshufd {{.*#+}} xmm3 = xmm6[2,3,2,3]
+; SSE41-NEXT:    psrlq %xmm3, %xmm0
+; SSE41-NEXT:    pblendw {{.*#+}} xmm0 = xmm2[0,1,2,3],xmm0[4,5,6,7]
+; SSE41-NEXT:    movdqa %xmm1, %xmm2
+; SSE41-NEXT:    psrlq %xmm6, %xmm2
+; SSE41-NEXT:    psrlq %xmm3, %xmm1
+; SSE41-NEXT:    pblendw {{.*#+}} xmm1 = xmm2[0,1,2,3],xmm1[4,5,6,7]
+; SSE41-NEXT:    pxor %xmm0, %xmm1
+; SSE41-NEXT:    psubq %xmm0, %xmm1
+; SSE41-NEXT:    movdqa %xmm4, %xmm0
+; SSE41-NEXT:    retq
+;
+; AVX2-LABEL: combine_vec64_ashr_out_of_bound:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vpbroadcastq {{.*#+}} ymm2 = [9223372036854775808,9223372036854775808,9223372036854775808,9223372036854775808]
+; AVX2-NEXT:    vpxor %ymm2, %ymm1, %ymm3
+; AVX2-NEXT:    vpbroadcastq {{.*#+}} ymm4 = [9223372036854775838,9223372036854775838,9223372036854775838,9223372036854775838]
+; AVX2-NEXT:    vpcmpgtq %ymm4, %ymm3, %ymm3
+; AVX2-NEXT:    vbroadcastsd {{.*#+}} ymm4 = [31,31,31,31]
+; AVX2-NEXT:    vblendvpd %ymm3, %ymm4, %ymm1, %ymm1
+; AVX2-NEXT:    vpsrlvq %ymm1, %ymm2, %ymm2
+; AVX2-NEXT:    vpsrlvq %ymm1, %ymm0, %ymm0
+; AVX2-NEXT:    vpxor %ymm2, %ymm0, %ymm0
+; AVX2-NEXT:    vpsubq %ymm2, %ymm0, %ymm0
+; AVX2-NEXT:    retq
+;
+; AVX512-LABEL: combine_vec64_ashr_out_of_bound:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpminuq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to4}, %ymm1, %ymm1
+; AVX512-NEXT:    vpsravq %ymm1, %ymm0, %ymm0
+; AVX512-NEXT:    retq
+  %1 = tail call <4 x i64> @llvm.umin.v4i64(<4 x i64> %y, <4 x i64> <i64 31, i64 31, i64 31, i64 31>)
+  %2 = ashr <4 x i64> %x, %1
+  ret <4 x i64> %2
+}



More information about the llvm-commits mailing list