[llvm] b320d37 - [X86] Add handling for select(icmp_uge(amt,BW),0,shift_logical(x,amt)) -> avx2 shift(x,amt)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 15 08:34:39 PDT 2024


Author: Simon Pilgrim
Date: 2024-07-15T16:34:21+01:00
New Revision: b320d3733dfb76c1b7d78fc499490d34b99e2284

URL: https://github.com/llvm/llvm-project/commit/b320d3733dfb76c1b7d78fc499490d34b99e2284
DIFF: https://github.com/llvm/llvm-project/commit/b320d3733dfb76c1b7d78fc499490d34b99e2284.diff

LOG: [X86] Add handling for select(icmp_uge(amt,BW),0,shift_logical(x,amt)) -> avx2 shift(x,amt)

We need to catch this pattern explicitly; otherwise, pre-AVX512 targets will fold it to and(icmp_ult(amt,BW),shift_logical(x,amt)).

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/combine-shl.ll
    llvm/test/CodeGen/X86/combine-srl.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 91a5526a82bbe..93876fc0876dc 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -46190,13 +46190,13 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
   // to bitwidth-1 for unsigned shifts, effectively performing a maximum left
   // shift of bitwidth-1 positions. and returns zero for unsigned right shifts
   // exceeding bitwidth-1.
-  if (N->getOpcode() == ISD::VSELECT &&
-      (LHS.getOpcode() == ISD::SRL || LHS.getOpcode() == ISD::SHL) &&
-      supportedVectorVarShift(VT, Subtarget, LHS.getOpcode())) {
+  if (N->getOpcode() == ISD::VSELECT) {
     using namespace llvm::SDPatternMatch;
     // fold select(icmp_ult(amt,BW),shl(x,amt),0) -> avx2 psllv(x,amt)
     // fold select(icmp_ult(amt,BW),srl(x,amt),0) -> avx2 psrlv(x,amt)
-    if (ISD::isConstantSplatVectorAllZeros(RHS.getNode()) &&
+    if ((LHS.getOpcode() == ISD::SRL || LHS.getOpcode() == ISD::SHL) &&
+        supportedVectorVarShift(VT, Subtarget, LHS.getOpcode()) &&
+        ISD::isConstantSplatVectorAllZeros(RHS.getNode()) &&
         sd_match(Cond, m_SetCC(m_Specific(LHS.getOperand(1)),
                                m_SpecificInt(VT.getScalarSizeInBits()),
                                m_SpecificCondCode(ISD::SETULT)))) {
@@ -46204,6 +46204,18 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
                                                      : X86ISD::VSHLV,
                          DL, VT, LHS.getOperand(0), LHS.getOperand(1));
     }
+    // fold select(icmp_uge(amt,BW),0,shl(x,amt)) -> avx2 psllv(x,amt)
+    // fold select(icmp_uge(amt,BW),0,srl(x,amt)) -> avx2 psrlv(x,amt)
+    if ((RHS.getOpcode() == ISD::SRL || RHS.getOpcode() == ISD::SHL) &&
+        supportedVectorVarShift(VT, Subtarget, RHS.getOpcode()) &&
+        ISD::isConstantSplatVectorAllZeros(LHS.getNode()) &&
+        sd_match(Cond, m_SetCC(m_Specific(RHS.getOperand(1)),
+                               m_SpecificInt(VT.getScalarSizeInBits()),
+                               m_SpecificCondCode(ISD::SETUGE)))) {
+      return DAG.getNode(RHS.getOpcode() == ISD::SRL ? X86ISD::VSRLV
+                                                     : X86ISD::VSHLV,
+                         DL, VT, RHS.getOperand(0), RHS.getOperand(1));
+    }
   }
 
   // Early exit check

diff  --git a/llvm/test/CodeGen/X86/combine-shl.ll b/llvm/test/CodeGen/X86/combine-shl.ll
index 8d8c1d26fc5ca..1ce10c3708d58 100644
--- a/llvm/test/CodeGen/X86/combine-shl.ll
+++ b/llvm/test/CodeGen/X86/combine-shl.ll
@@ -1086,19 +1086,10 @@ define <4 x i32> @combine_vec_shl_commuted_clamped1(<4 x i32> %sh, <4 x i32> %am
 ; SSE41-NEXT:    pand %xmm2, %xmm0
 ; SSE41-NEXT:    retq
 ;
-; AVX2-LABEL: combine_vec_shl_commuted_clamped1:
-; AVX2:       # %bb.0:
-; AVX2-NEXT:    vpbroadcastd {{.*#+}} xmm2 = [31,31,31,31]
-; AVX2-NEXT:    vpsllvd %xmm1, %xmm0, %xmm0
-; AVX2-NEXT:    vpminud %xmm2, %xmm1, %xmm2
-; AVX2-NEXT:    vpcmpeqd %xmm2, %xmm1, %xmm1
-; AVX2-NEXT:    vpand %xmm0, %xmm1, %xmm0
-; AVX2-NEXT:    retq
-;
-; AVX512-LABEL: combine_vec_shl_commuted_clamped1:
-; AVX512:       # %bb.0:
-; AVX512-NEXT:    vpsllvd %xmm1, %xmm0, %xmm0
-; AVX512-NEXT:    retq
+; AVX-LABEL: combine_vec_shl_commuted_clamped1:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpsllvd %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    retq
   %cmp.i = icmp uge <4 x i32> %amt, <i32 32, i32 32, i32 32, i32 32>
   %shl = shl <4 x i32> %sh, %amt
   %1 = select <4 x i1> %cmp.i, <4 x i32> zeroinitializer, <4 x i32> %shl

diff  --git a/llvm/test/CodeGen/X86/combine-srl.ll b/llvm/test/CodeGen/X86/combine-srl.ll
index f2a9aa217f7ec..7bc90534dcc6e 100644
--- a/llvm/test/CodeGen/X86/combine-srl.ll
+++ b/llvm/test/CodeGen/X86/combine-srl.ll
@@ -828,19 +828,10 @@ define <4 x i32> @combine_vec_lshr_commuted_clamped1(<4 x i32> %sh, <4 x i32> %a
 ; SSE41-NEXT:    pand %xmm2, %xmm0
 ; SSE41-NEXT:    retq
 ;
-; AVX2-LABEL: combine_vec_lshr_commuted_clamped1:
-; AVX2:       # %bb.0:
-; AVX2-NEXT:    vpbroadcastd {{.*#+}} xmm2 = [31,31,31,31]
-; AVX2-NEXT:    vpsrlvd %xmm1, %xmm0, %xmm0
-; AVX2-NEXT:    vpminud %xmm2, %xmm1, %xmm2
-; AVX2-NEXT:    vpcmpeqd %xmm2, %xmm1, %xmm1
-; AVX2-NEXT:    vpand %xmm0, %xmm1, %xmm0
-; AVX2-NEXT:    retq
-;
-; AVX512-LABEL: combine_vec_lshr_commuted_clamped1:
-; AVX512:       # %bb.0:
-; AVX512-NEXT:    vpsrlvd %xmm1, %xmm0, %xmm0
-; AVX512-NEXT:    retq
+; AVX-LABEL: combine_vec_lshr_commuted_clamped1:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpsrlvd %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    retq
   %cmp.i = icmp uge <4 x i32> %amt, <i32 32, i32 32, i32 32, i32 32>
   %shr = lshr <4 x i32> %sh, %amt
   %1 = select <4 x i1> %cmp.i, <4 x i32> zeroinitializer, <4 x i32> %shr


        


More information about the llvm-commits mailing list