[llvm] bc32a1d - [DAG] Add non-uniform vector support to (shl (sr[la] exact X, C1), C2) folds

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 12 04:58:05 PDT 2022


Author: Simon Pilgrim
Date: 2022-04-12T12:57:56+01:00
New Revision: bc32a1dd7679e74a29b6738c0d3beff9c4223796

URL: https://github.com/llvm/llvm-project/commit/bc32a1dd7679e74a29b6738c0d3beff9c4223796
DIFF: https://github.com/llvm/llvm-project/commit/bc32a1dd7679e74a29b6738c0d3beff9c4223796.diff

LOG: [DAG] Add non-uniform vector support to (shl (sr[la] exact X, C1), C2) folds

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/test/CodeGen/X86/combine-shl.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 17e128149a604..88f68338ffa22 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -8888,23 +8888,36 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
   }
 
   // fold (shl (sr[la] exact X,  C1), C2) -> (shl    X, (C2-C1)) if C1 <= C2
-  // fold (shl (sr[la] exact X,  C1), C2) -> (sr[la] X, (C2-C1)) if C1  > C2
-  // TODO - support non-uniform vector shift amounts.
-  ConstantSDNode *N1C = isConstOrConstSplat(N1);
-  if (N1C && (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SRA) &&
+  // fold (shl (sr[la] exact X,  C1), C2) -> (sr[la] X, (C2-C1)) if C1 >= C2
+  if ((N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SRA) &&
       N0->getFlags().hasExact()) {
-    if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
-      uint64_t C1 = N0C1->getZExtValue();
-      uint64_t C2 = N1C->getZExtValue();
+    auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
+                                           ConstantSDNode *RHS) {
+      const APInt &LHSC = LHS->getAPIntValue();
+      const APInt &RHSC = RHS->getAPIntValue();
+      return LHSC.ult(OpSizeInBits) && RHSC.ult(OpSizeInBits) &&
+             LHSC.getZExtValue() <= RHSC.getZExtValue();
+    };
+    if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
+                                  /*AllowUndefs*/ false,
+                                  /*AllowTypeMismatch*/ true)) {
       SDLoc DL(N);
-      if (C1 <= C2)
-        return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0),
-                           DAG.getConstant(C2 - C1, DL, ShiftVT));
-      return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0),
-                         DAG.getConstant(C1 - C2, DL, ShiftVT));
+      SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
+      SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
+      return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
+    }
+    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
+                                  /*AllowUndefs*/ false,
+                                  /*AllowTypeMismatch*/ true)) {
+      SDLoc DL(N);
+      SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
+      SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
+      return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0), Diff);
     }
   }
 
+  ConstantSDNode *N1C = isConstOrConstSplat(N1);
+
   // fold (shl (srl x, c1), c2) -> (and (shl x, (sub c2, c1), MASK) or
   //                               (and (srl x, (sub c1, c2), MASK)
   // Only fold this if the inner shift has no other uses -- if it does, folding

diff  --git a/llvm/test/CodeGen/X86/combine-shl.ll b/llvm/test/CodeGen/X86/combine-shl.ll
index 6c90ba76fa4b7..5d7d9730f1252 100644
--- a/llvm/test/CodeGen/X86/combine-shl.ll
+++ b/llvm/test/CodeGen/X86/combine-shl.ll
@@ -418,39 +418,21 @@ define <4 x i32> @combine_vec_shl_ge_ashr_exact1(<4 x i32> %x) {
 ; SSE2-LABEL: combine_vec_shl_ge_ashr_exact1:
 ; SSE2:       # %bb.0:
 ; SSE2-NEXT:    movdqa %xmm0, %xmm1
-; SSE2-NEXT:    psrad $3, %xmm1
-; SSE2-NEXT:    movdqa %xmm0, %xmm2
-; SSE2-NEXT:    psrad $5, %xmm2
-; SSE2-NEXT:    movsd {{.*#+}} xmm2 = xmm1[0],xmm2[1]
-; SSE2-NEXT:    movdqa %xmm0, %xmm1
-; SSE2-NEXT:    psrad $8, %xmm1
-; SSE2-NEXT:    psrad $4, %xmm0
-; SSE2-NEXT:    shufps {{.*#+}} xmm0 = xmm0[1,1],xmm1[3,3]
-; SSE2-NEXT:    pmuludq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
-; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[0,2,2,3]
-; SSE2-NEXT:    pmuludq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm2
-; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm2[0,2,2,3]
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
+; SSE2-NEXT:    pslld $2, %xmm1
+; SSE2-NEXT:    shufps {{.*#+}} xmm0 = xmm0[3,0],xmm1[2,0]
+; SSE2-NEXT:    shufps {{.*#+}} xmm1 = xmm1[0,1],xmm0[2,0]
+; SSE2-NEXT:    movaps %xmm1, %xmm0
 ; SSE2-NEXT:    retq
 ;
 ; SSE41-LABEL: combine_vec_shl_ge_ashr_exact1:
 ; SSE41:       # %bb.0:
 ; SSE41-NEXT:    movdqa %xmm0, %xmm1
-; SSE41-NEXT:    psrad $8, %xmm1
-; SSE41-NEXT:    movdqa %xmm0, %xmm2
-; SSE41-NEXT:    psrad $4, %xmm2
-; SSE41-NEXT:    pblendw {{.*#+}} xmm2 = xmm2[0,1,2,3],xmm1[4,5,6,7]
-; SSE41-NEXT:    movdqa %xmm0, %xmm1
-; SSE41-NEXT:    psrad $5, %xmm1
-; SSE41-NEXT:    psrad $3, %xmm0
-; SSE41-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1,2,3],xmm1[4,5,6,7]
-; SSE41-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1],xmm2[2,3],xmm0[4,5],xmm2[6,7]
-; SSE41-NEXT:    pmulld {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
+; SSE41-NEXT:    pslld $2, %xmm1
+; SSE41-NEXT:    pblendw {{.*#+}} xmm0 = xmm1[0,1,2,3,4,5],xmm0[6,7]
 ; SSE41-NEXT:    retq
 ;
 ; AVX-LABEL: combine_vec_shl_ge_ashr_exact1:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vpsravd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
 ; AVX-NEXT:    vpsllvd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = ashr exact <4 x i32> %x, <i32 3, i32 4, i32 5, i32 8>
@@ -495,40 +477,22 @@ define <4 x i32> @combine_vec_shl_lt_ashr_exact1(<4 x i32> %x) {
 ; SSE2-LABEL: combine_vec_shl_lt_ashr_exact1:
 ; SSE2:       # %bb.0:
 ; SSE2-NEXT:    movdqa %xmm0, %xmm1
-; SSE2-NEXT:    psrad $5, %xmm1
-; SSE2-NEXT:    movdqa %xmm0, %xmm2
-; SSE2-NEXT:    psrad $7, %xmm2
-; SSE2-NEXT:    movsd {{.*#+}} xmm2 = xmm1[0],xmm2[1]
-; SSE2-NEXT:    movdqa %xmm0, %xmm1
-; SSE2-NEXT:    psrad $8, %xmm1
-; SSE2-NEXT:    psrad $6, %xmm0
-; SSE2-NEXT:    shufps {{.*#+}} xmm0 = xmm0[1,1],xmm1[3,3]
-; SSE2-NEXT:    pmuludq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
-; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[0,2,2,3]
-; SSE2-NEXT:    pmuludq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm2
-; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm2[0,2,2,3]
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
+; SSE2-NEXT:    psrad $2, %xmm1
+; SSE2-NEXT:    shufps {{.*#+}} xmm0 = xmm0[3,0],xmm1[2,0]
+; SSE2-NEXT:    shufps {{.*#+}} xmm1 = xmm1[0,1],xmm0[2,0]
+; SSE2-NEXT:    movaps %xmm1, %xmm0
 ; SSE2-NEXT:    retq
 ;
 ; SSE41-LABEL: combine_vec_shl_lt_ashr_exact1:
 ; SSE41:       # %bb.0:
 ; SSE41-NEXT:    movdqa %xmm0, %xmm1
-; SSE41-NEXT:    psrad $8, %xmm1
-; SSE41-NEXT:    movdqa %xmm0, %xmm2
-; SSE41-NEXT:    psrad $6, %xmm2
-; SSE41-NEXT:    pblendw {{.*#+}} xmm2 = xmm2[0,1,2,3],xmm1[4,5,6,7]
-; SSE41-NEXT:    movdqa %xmm0, %xmm1
-; SSE41-NEXT:    psrad $7, %xmm1
-; SSE41-NEXT:    psrad $5, %xmm0
-; SSE41-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1,2,3],xmm1[4,5,6,7]
-; SSE41-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1],xmm2[2,3],xmm0[4,5],xmm2[6,7]
-; SSE41-NEXT:    pmulld {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
+; SSE41-NEXT:    psrad $2, %xmm1
+; SSE41-NEXT:    pblendw {{.*#+}} xmm0 = xmm1[0,1,2,3,4,5],xmm0[6,7]
 ; SSE41-NEXT:    retq
 ;
 ; AVX-LABEL: combine_vec_shl_lt_ashr_exact1:
 ; AVX:       # %bb.0:
 ; AVX-NEXT:    vpsravd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; AVX-NEXT:    vpsllvd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = ashr exact <4 x i32> %x, <i32 5, i32 6, i32 7, i32 8>
   %2 = shl <4 x i32> %1, <i32 3, i32 4, i32 5, i32 8>


        


More information about the llvm-commits mailing list