[llvm] [X86][AVX] Fix handling of out-of-bounds shift amounts in AVX2 vector logical shift nodes #83840 (PR #86922)

via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 5 03:10:57 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-x86

Author: SahilPatidar

<details>
<summary>Changes</summary>

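Resolves #<!-- -->83840.

Adds DAG combines in `combineSelect` and `combineShiftRightLogical` that fold a vector logical right shift guarded by a `select` on `amt u< <element bit width>` (with an all-zeros fallback value) into a single `X86ISD::VSRLV` node, on targets where `supportedVectorVarShift` reports support for variable `ISD::SRL` on the vector type.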

---
Full diff: https://github.com/llvm/llvm-project/pull/86922.diff


2 Files Affected:

- (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+29) 
- (modified) llvm/test/CodeGen/X86/combine-srl.ll (+116) 


``````````diff
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index b9a87f9024c7de..d36b96bd661038 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -45566,6 +45566,19 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
     }
   }
 
+  if (N->getOpcode() == ISD::VSELECT && LHS.getOpcode() == ISD::SRL &&
+      supportedVectorVarShift(VT, Subtarget, ISD::SRL)) {
+    APInt SV;
+    if (Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == LHS.getOperand(1) &&
+        cast<CondCodeSDNode>(Cond.getOperand(2))->get() == ISD::SETULT &&
+        ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), SV) &&
+        ISD::isConstantSplatVectorAllZeros(RHS.getNode()) &&
+        SV == VT.getScalarSizeInBits()) {
+      SDLoc DL(LHS);
+      return DAG.getNode(X86ISD::VSRLV, DL, LHS->getVTList(), LHS.getOperand(0), LHS.getOperand(1));
+    }
+  }
+
   // Early exit check
   if (!TLI.isTypeLegal(VT) || isSoftF16(VT, Subtarget))
     return SDValue();
@@ -47443,6 +47456,22 @@ static SDValue combineShiftRightLogical(SDNode *N, SelectionDAG &DAG,
   if (SDValue V = combineShiftToPMULH(N, DAG, Subtarget))
     return V;
 
+  if (N0.getOpcode() == ISD::VSELECT &&
+      supportedVectorVarShift(VT, Subtarget, ISD::SRL)) {
+    SDValue Cond = N0.getOperand(0);
+    SDValue N00 = N0.getOperand(1);
+    SDValue N01 = N0.getOperand(2);
+    APInt SV;
+    if (Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == N1 &&
+        cast<CondCodeSDNode>(Cond.getOperand(2))->get() == ISD::SETULT &&
+        ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), SV) &&
+        ISD::isConstantSplatVectorAllZeros(N01.getNode()) &&
+        SV == VT.getScalarSizeInBits()) {
+      SDLoc DL(N);
+      return DAG.getNode(X86ISD::VSRLV, DL, N->getVTList(), N00, N1);
+    }
+  }
+
   // Only do this on the last DAG combine as it can interfere with other
   // combines.
   if (!DCI.isAfterLegalizeDAG())
diff --git a/llvm/test/CodeGen/X86/combine-srl.ll b/llvm/test/CodeGen/X86/combine-srl.ll
index 33649e6d87b915..eeeff2f8eb25bd 100644
--- a/llvm/test/CodeGen/X86/combine-srl.ll
+++ b/llvm/test/CodeGen/X86/combine-srl.ll
@@ -606,3 +606,119 @@ define <4 x i32> @combine_vec_lshr_trunc_and(<4 x i32> %x, <4 x i64> %y) {
   %3 = lshr <4 x i32> %x, %2
   ret <4 x i32> %3
 }
+
+define <4 x i32> @combine_vec_lshr_clamped1(<4 x i32> %sh, <4 x i32> %amt) {
+; SSE2-LABEL: combine_vec_lshr_clamped1:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movdqa {{.*#+}} xmm2 = [2147483648,2147483648,2147483648,2147483648]
+; SSE2-NEXT:    pxor %xmm1, %xmm2
+; SSE2-NEXT:    pcmpgtd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm2
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm3 = xmm1[2,3,3,3,4,5,6,7]
+; SSE2-NEXT:    movdqa %xmm0, %xmm4
+; SSE2-NEXT:    psrld %xmm3, %xmm4
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm3 = xmm1[0,1,1,1,4,5,6,7]
+; SSE2-NEXT:    movdqa %xmm0, %xmm5
+; SSE2-NEXT:    psrld %xmm3, %xmm5
+; SSE2-NEXT:    punpcklqdq {{.*#+}} xmm5 = xmm5[0],xmm4[0]
+; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm1[2,3,2,3]
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm3 = xmm1[2,3,3,3,4,5,6,7]
+; SSE2-NEXT:    movdqa %xmm0, %xmm4
+; SSE2-NEXT:    psrld %xmm3, %xmm4
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm1 = xmm1[0,1,1,1,4,5,6,7]
+; SSE2-NEXT:    psrld %xmm1, %xmm0
+; SSE2-NEXT:    punpckhqdq {{.*#+}} xmm0 = xmm0[1],xmm4[1]
+; SSE2-NEXT:    shufps {{.*#+}} xmm5 = xmm5[0,3],xmm0[0,3]
+; SSE2-NEXT:    pandn %xmm5, %xmm2
+; SSE2-NEXT:    movdqa %xmm2, %xmm0
+; SSE2-NEXT:    retq
+;
+; SSE41-LABEL: combine_vec_lshr_clamped1:
+; SSE41:       # %bb.0:
+; SSE41-NEXT:    pmovsxbd {{.*#+}} xmm2 = [31,31,31,31]
+; SSE41-NEXT:    pminud %xmm1, %xmm2
+; SSE41-NEXT:    pcmpeqd %xmm1, %xmm2
+; SSE41-NEXT:    pshuflw {{.*#+}} xmm3 = xmm1[2,3,3,3,4,5,6,7]
+; SSE41-NEXT:    movdqa %xmm0, %xmm4
+; SSE41-NEXT:    psrld %xmm3, %xmm4
+; SSE41-NEXT:    pshufd {{.*#+}} xmm3 = xmm1[2,3,2,3]
+; SSE41-NEXT:    pshuflw {{.*#+}} xmm5 = xmm3[2,3,3,3,4,5,6,7]
+; SSE41-NEXT:    movdqa %xmm0, %xmm6
+; SSE41-NEXT:    psrld %xmm5, %xmm6
+; SSE41-NEXT:    pblendw {{.*#+}} xmm6 = xmm4[0,1,2,3],xmm6[4,5,6,7]
+; SSE41-NEXT:    pshuflw {{.*#+}} xmm1 = xmm1[0,1,1,1,4,5,6,7]
+; SSE41-NEXT:    movdqa %xmm0, %xmm4
+; SSE41-NEXT:    psrld %xmm1, %xmm4
+; SSE41-NEXT:    pshuflw {{.*#+}} xmm1 = xmm3[0,1,1,1,4,5,6,7]
+; SSE41-NEXT:    psrld %xmm1, %xmm0
+; SSE41-NEXT:    pblendw {{.*#+}} xmm0 = xmm4[0,1,2,3],xmm0[4,5,6,7]
+; SSE41-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1],xmm6[2,3],xmm0[4,5],xmm6[6,7]
+; SSE41-NEXT:    pand %xmm2, %xmm0
+; SSE41-NEXT:    retq
+;
+; AVX-LABEL: combine_vec_lshr_clamped1:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpsrlvd %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    retq
+  %cmp.i = icmp ult <4 x i32> %amt, <i32 32, i32 32, i32 32, i32 32>
+  %shr = lshr <4 x i32> %sh, %amt
+  %1 = select <4 x i1> %cmp.i, <4 x i32> %shr, <4 x i32> zeroinitializer
+  ret <4 x i32> %1
+}
+
+define <4 x i32> @combine_vec_lshr_clamped2(<4 x i32> %sh, <4 x i32> %amt) {
+; SSE2-LABEL: combine_vec_lshr_clamped2:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movdqa {{.*#+}} xmm2 = [2147483648,2147483648,2147483648,2147483648]
+; SSE2-NEXT:    pxor %xmm1, %xmm2
+; SSE2-NEXT:    pcmpgtd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm2
+; SSE2-NEXT:    pandn %xmm0, %xmm2
+; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm1[2,3,2,3]
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm3 = xmm0[2,3,3,3,4,5,6,7]
+; SSE2-NEXT:    movdqa %xmm2, %xmm4
+; SSE2-NEXT:    psrld %xmm3, %xmm4
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm0 = xmm0[0,1,1,1,4,5,6,7]
+; SSE2-NEXT:    movdqa %xmm2, %xmm3
+; SSE2-NEXT:    psrld %xmm0, %xmm3
+; SSE2-NEXT:    punpckhqdq {{.*#+}} xmm3 = xmm3[1],xmm4[1]
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm0 = xmm1[2,3,3,3,4,5,6,7]
+; SSE2-NEXT:    movdqa %xmm2, %xmm4
+; SSE2-NEXT:    psrld %xmm0, %xmm4
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm0 = xmm1[0,1,1,1,4,5,6,7]
+; SSE2-NEXT:    psrld %xmm0, %xmm2
+; SSE2-NEXT:    punpcklqdq {{.*#+}} xmm2 = xmm2[0],xmm4[0]
+; SSE2-NEXT:    shufps {{.*#+}} xmm2 = xmm2[0,3],xmm3[0,3]
+; SSE2-NEXT:    movaps %xmm2, %xmm0
+; SSE2-NEXT:    retq
+;
+; SSE41-LABEL: combine_vec_lshr_clamped2:
+; SSE41:       # %bb.0:
+; SSE41-NEXT:    pmovsxbd {{.*#+}} xmm2 = [31,31,31,31]
+; SSE41-NEXT:    pminud %xmm1, %xmm2
+; SSE41-NEXT:    pcmpeqd %xmm1, %xmm2
+; SSE41-NEXT:    pand %xmm2, %xmm0
+; SSE41-NEXT:    pshufd {{.*#+}} xmm2 = xmm1[2,3,2,3]
+; SSE41-NEXT:    pshuflw {{.*#+}} xmm3 = xmm2[2,3,3,3,4,5,6,7]
+; SSE41-NEXT:    movdqa %xmm0, %xmm4
+; SSE41-NEXT:    psrld %xmm3, %xmm4
+; SSE41-NEXT:    pshuflw {{.*#+}} xmm3 = xmm1[2,3,3,3,4,5,6,7]
+; SSE41-NEXT:    movdqa %xmm0, %xmm5
+; SSE41-NEXT:    psrld %xmm3, %xmm5
+; SSE41-NEXT:    pblendw {{.*#+}} xmm5 = xmm5[0,1,2,3],xmm4[4,5,6,7]
+; SSE41-NEXT:    pshuflw {{.*#+}} xmm2 = xmm2[0,1,1,1,4,5,6,7]
+; SSE41-NEXT:    movdqa %xmm0, %xmm3
+; SSE41-NEXT:    psrld %xmm2, %xmm3
+; SSE41-NEXT:    pshuflw {{.*#+}} xmm1 = xmm1[0,1,1,1,4,5,6,7]
+; SSE41-NEXT:    psrld %xmm1, %xmm0
+; SSE41-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1,2,3],xmm3[4,5,6,7]
+; SSE41-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1],xmm5[2,3],xmm0[4,5],xmm5[6,7]
+; SSE41-NEXT:    retq
+;
+; AVX-LABEL: combine_vec_lshr_clamped2:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpsrlvd %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    retq
+  %cmp.i = icmp ult <4 x i32> %amt, <i32 32, i32 32, i32 32, i32 32>
+  %1 = select <4 x i1> %cmp.i, <4 x i32> %sh, <4 x i32> zeroinitializer
+  %shr = lshr <4 x i32> %1, %amt
+  ret <4 x i32> %shr
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/86922


More information about the llvm-commits mailing list