[llvm] r340478 - [X86] Teach combineLoopSADPattern to handle cases where there is no loop and the add has two absolute difference inputs

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 22 16:19:01 PDT 2018


Author: ctopper
Date: Wed Aug 22 16:19:01 2018
New Revision: 340478

URL: http://llvm.org/viewvc/llvm-project?rev=340478&view=rev
Log:
[X86] Teach combineLoopSADPattern to handle cases where there is no loop and the add has two absolute difference inputs

Previously we assumed a vector reduction add is part of a loop and that one of its inputs is a phi. But the code in SelectionDAGBuilder that sets the vector reduction flag handles more cases than that: it only requires that the use chain end in a horizontal reduction and that there are no other uses. This means it can also handle unrolled reduction loops.
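
As a rough, hypothetical source-level illustration (not taken from this commit; names are made up): a plain SAD reduction loop like the one below, once vectorized, still ends its use chain in a horizontal reduction of the accumulator with no other uses, so it gets the vector reduction flag whether or not the loop is unrolled.

  unsigned sad16(const unsigned char *a, const unsigned char *b, int n) {
    unsigned sum = 0;                                   /* reduction starts at 0 */
    for (int i = 0; i < n; ++i)                         /* may be vectorized and unrolled */
      sum += a[i] > b[i] ? a[i] - b[i] : b[i] - a[i];   /* absolute difference */
    return sum;                                         /* sole use: the final horizontal reduction */
  }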

If the initial value of the reduction was 0, an unrolled loop would begin with a vector reduction add that has two SAD inputs. Previously we would only transform one side of the add, but for this case we need to transform both sides.
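
For example, a shape like the sad_double_reduction test updated below (sketched here as hypothetical C; the actual test is written in IR) has two absolute-difference sums feeding a single add with no loop-carried phi on either side, so both operands of that add need to become PSADBW:

  unsigned sad_pair(const unsigned char *a, const unsigned char *b,
                    const unsigned char *c, const unsigned char *d) {
    unsigned sum = 0;
    for (int i = 0; i < 16; ++i)                          /* fully vectorized, no loop left */
      sum += (a[i] > b[i] ? a[i] - b[i] : b[i] - a[i])    /* |a - b| */
           + (c[i] > d[i] ? c[i] - d[i] : d[i] - c[i]);   /* |c - d| */
    return sum;
  }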

I've created a lambda to reuse some of the code for both sides, and fixed the variable names to remove references to "phi".

Differential Revision: https://reviews.llvm.org/D50817

Modified:
    llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
    llvm/trunk/test/CodeGen/X86/sad.ll

Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=340478&r1=340477&r2=340478&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Wed Aug 22 16:19:01 2018
@@ -39132,45 +39132,53 @@ static SDValue combineLoopSADPattern(SDN
 
   // We know N is a reduction add, which means one of its operands is a phi.
   // To match SAD, we need the other operand to be a vector select.
-  SDValue SelectOp, Phi;
-  if (Op0.getOpcode() == ISD::VSELECT) {
-    SelectOp = Op0;
-    Phi = Op1;
-  } else if (Op1.getOpcode() == ISD::VSELECT) {
-    SelectOp = Op1;
-    Phi = Op0;
-  } else
+  if (Op0.getOpcode() != ISD::VSELECT)
+    std::swap(Op0, Op1);
+  if (Op0.getOpcode() != ISD::VSELECT)
     return SDValue();
 
+  auto BuildPSADBW = [&](SDValue Op0, SDValue Op1) {
+    // SAD pattern detected. Now build a SAD instruction and an addition for
+    // reduction. Note that the number of elements of the result of SAD is less
+    // than the number of elements of its input. Therefore, we could only update
+    // part of elements in the reduction vector.
+    SDValue Sad = createPSADBW(DAG, Op0, Op1, DL, Subtarget);
+
+    // The output of PSADBW is a vector of i64.
+    // We need to turn the vector of i64 into a vector of i32.
+    // If the reduction vector is at least as wide as the psadbw result, just
+    // bitcast. If it's narrower, truncate - the high i32 of each i64 is zero
+    // anyway.
+    MVT ResVT = MVT::getVectorVT(MVT::i32, Sad.getValueSizeInBits() / 32);
+    if (VT.getSizeInBits() >= ResVT.getSizeInBits())
+      Sad = DAG.getNode(ISD::BITCAST, DL, ResVT, Sad);
+    else
+      Sad = DAG.getNode(ISD::TRUNCATE, DL, VT, Sad);
+
+    if (VT.getSizeInBits() > ResVT.getSizeInBits()) {
+      // Fill the upper elements with zero to match the add width.
+      SDValue Zero = DAG.getConstant(0, DL, VT);
+      Sad = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Zero, Sad,
+                        DAG.getIntPtrConstant(0, DL));
+    }
+
+    return Sad;
+  };
+
   // Check whether we have an abs-diff pattern feeding into the select.
-  if(!detectZextAbsDiff(SelectOp, Op0, Op1))
+  SDValue SadOp0, SadOp1;
+  if (!detectZextAbsDiff(Op0, SadOp0, SadOp1))
     return SDValue();
 
-  // SAD pattern detected. Now build a SAD instruction and an addition for
-  // reduction. Note that the number of elements of the result of SAD is less
-  // than the number of elements of its input. Therefore, we could only update
-  // part of elements in the reduction vector.
-  SDValue Sad = createPSADBW(DAG, Op0, Op1, DL, Subtarget);
-
-  // The output of PSADBW is a vector of i64.
-  // We need to turn the vector of i64 into a vector of i32.
-  // If the reduction vector is at least as wide as the psadbw result, just
-  // bitcast. If it's narrower, truncate - the high i32 of each i64 is zero
-  // anyway.
-  MVT ResVT = MVT::getVectorVT(MVT::i32, Sad.getValueSizeInBits() / 32);
-  if (VT.getSizeInBits() >= ResVT.getSizeInBits())
-    Sad = DAG.getNode(ISD::BITCAST, DL, ResVT, Sad);
-  else
-    Sad = DAG.getNode(ISD::TRUNCATE, DL, VT, Sad);
-
-  if (VT.getSizeInBits() > ResVT.getSizeInBits()) {
-    // Fill the upper elements with zero to match the add width.
-    SDValue Zero = DAG.getConstant(0, DL, VT);
-    Sad = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Zero, Sad,
-                      DAG.getIntPtrConstant(0, DL));
+  Op0 = BuildPSADBW(SadOp0, SadOp1);
+
+  // It's possible we have a sad on the other side too.
+  if (Op1.getOpcode() == ISD::VSELECT &&
+      detectZextAbsDiff(Op1, SadOp0, SadOp1)) {
+    Op1 = BuildPSADBW(SadOp0, SadOp1);
   }
 
-  return DAG.getNode(ISD::ADD, DL, VT, Sad, Phi);
+  return DAG.getNode(ISD::ADD, DL, VT, Op0, Op1);
 }
 
 /// Convert vector increment or decrement to sub/add with an all-ones constant:

Modified: llvm/trunk/test/CodeGen/X86/sad.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/sad.ll?rev=340478&r1=340477&r2=340478&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/sad.ll (original)
+++ llvm/trunk/test/CodeGen/X86/sad.ll Wed Aug 22 16:19:01 2018
@@ -1511,53 +1511,12 @@ define i32 @sad_double_reduction(<16 x i
 ; SSE2-LABEL: sad_double_reduction:
 ; SSE2:       # %bb.0: # %bb
 ; SSE2-NEXT:    movdqu (%rdi), %xmm0
-; SSE2-NEXT:    movdqu (%rsi), %xmm4
-; SSE2-NEXT:    pxor %xmm5, %xmm5
-; SSE2-NEXT:    movdqa %xmm0, %xmm3
-; SSE2-NEXT:    punpckhbw {{.*#+}} xmm3 = xmm3[8],xmm5[8],xmm3[9],xmm5[9],xmm3[10],xmm5[10],xmm3[11],xmm5[11],xmm3[12],xmm5[12],xmm3[13],xmm5[13],xmm3[14],xmm5[14],xmm3[15],xmm5[15]
-; SSE2-NEXT:    movdqa %xmm3, %xmm1
-; SSE2-NEXT:    punpcklwd {{.*#+}} xmm1 = xmm1[0],xmm5[0],xmm1[1],xmm5[1],xmm1[2],xmm5[2],xmm1[3],xmm5[3]
-; SSE2-NEXT:    punpckhwd {{.*#+}} xmm3 = xmm3[4],xmm5[4],xmm3[5],xmm5[5],xmm3[6],xmm5[6],xmm3[7],xmm5[7]
-; SSE2-NEXT:    punpcklbw {{.*#+}} xmm0 = xmm0[0],xmm5[0],xmm0[1],xmm5[1],xmm0[2],xmm5[2],xmm0[3],xmm5[3],xmm0[4],xmm5[4],xmm0[5],xmm5[5],xmm0[6],xmm5[6],xmm0[7],xmm5[7]
-; SSE2-NEXT:    movdqa %xmm0, %xmm2
-; SSE2-NEXT:    punpckhwd {{.*#+}} xmm2 = xmm2[4],xmm5[4],xmm2[5],xmm5[5],xmm2[6],xmm5[6],xmm2[7],xmm5[7]
-; SSE2-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm5[0],xmm0[1],xmm5[1],xmm0[2],xmm5[2],xmm0[3],xmm5[3]
-; SSE2-NEXT:    movdqa %xmm4, %xmm6
-; SSE2-NEXT:    punpckhbw {{.*#+}} xmm6 = xmm6[8],xmm5[8],xmm6[9],xmm5[9],xmm6[10],xmm5[10],xmm6[11],xmm5[11],xmm6[12],xmm5[12],xmm6[13],xmm5[13],xmm6[14],xmm5[14],xmm6[15],xmm5[15]
-; SSE2-NEXT:    movdqa %xmm6, %xmm7
-; SSE2-NEXT:    punpcklwd {{.*#+}} xmm7 = xmm7[0],xmm5[0],xmm7[1],xmm5[1],xmm7[2],xmm5[2],xmm7[3],xmm5[3]
-; SSE2-NEXT:    psubd %xmm7, %xmm1
-; SSE2-NEXT:    punpckhwd {{.*#+}} xmm6 = xmm6[4],xmm5[4],xmm6[5],xmm5[5],xmm6[6],xmm5[6],xmm6[7],xmm5[7]
-; SSE2-NEXT:    psubd %xmm6, %xmm3
-; SSE2-NEXT:    punpcklbw {{.*#+}} xmm4 = xmm4[0],xmm5[0],xmm4[1],xmm5[1],xmm4[2],xmm5[2],xmm4[3],xmm5[3],xmm4[4],xmm5[4],xmm4[5],xmm5[5],xmm4[6],xmm5[6],xmm4[7],xmm5[7]
-; SSE2-NEXT:    movdqa %xmm4, %xmm6
-; SSE2-NEXT:    punpckhwd {{.*#+}} xmm6 = xmm6[4],xmm5[4],xmm6[5],xmm5[5],xmm6[6],xmm5[6],xmm6[7],xmm5[7]
-; SSE2-NEXT:    psubd %xmm6, %xmm2
-; SSE2-NEXT:    punpcklwd {{.*#+}} xmm4 = xmm4[0],xmm5[0],xmm4[1],xmm5[1],xmm4[2],xmm5[2],xmm4[3],xmm5[3]
-; SSE2-NEXT:    psubd %xmm4, %xmm0
-; SSE2-NEXT:    movdqa %xmm1, %xmm4
-; SSE2-NEXT:    psrad $31, %xmm4
-; SSE2-NEXT:    paddd %xmm4, %xmm1
-; SSE2-NEXT:    pxor %xmm4, %xmm1
-; SSE2-NEXT:    movdqa %xmm3, %xmm4
-; SSE2-NEXT:    psrad $31, %xmm4
-; SSE2-NEXT:    paddd %xmm4, %xmm3
-; SSE2-NEXT:    pxor %xmm4, %xmm3
-; SSE2-NEXT:    movdqa %xmm2, %xmm4
-; SSE2-NEXT:    psrad $31, %xmm4
-; SSE2-NEXT:    paddd %xmm4, %xmm2
-; SSE2-NEXT:    pxor %xmm4, %xmm2
-; SSE2-NEXT:    paddd %xmm3, %xmm2
-; SSE2-NEXT:    movdqa %xmm0, %xmm3
-; SSE2-NEXT:    psrad $31, %xmm3
-; SSE2-NEXT:    paddd %xmm3, %xmm0
-; SSE2-NEXT:    pxor %xmm3, %xmm0
-; SSE2-NEXT:    paddd %xmm1, %xmm0
-; SSE2-NEXT:    paddd %xmm2, %xmm0
-; SSE2-NEXT:    movdqu (%rdx), %xmm1
+; SSE2-NEXT:    movdqu (%rsi), %xmm1
+; SSE2-NEXT:    psadbw %xmm0, %xmm1
+; SSE2-NEXT:    movdqu (%rdx), %xmm0
 ; SSE2-NEXT:    movdqu (%rcx), %xmm2
-; SSE2-NEXT:    psadbw %xmm1, %xmm2
-; SSE2-NEXT:    paddd %xmm0, %xmm2
+; SSE2-NEXT:    psadbw %xmm0, %xmm2
+; SSE2-NEXT:    paddd %xmm1, %xmm2
 ; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm2[2,3,0,1]
 ; SSE2-NEXT:    paddd %xmm2, %xmm0
 ; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
@@ -1567,26 +1526,9 @@ define i32 @sad_double_reduction(<16 x i
 ;
 ; AVX1-LABEL: sad_double_reduction:
 ; AVX1:       # %bb.0: # %bb
-; AVX1-NEXT:    vpmovzxbd {{.*#+}} xmm0 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero
-; AVX1-NEXT:    vpmovzxbd {{.*#+}} xmm1 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero
-; AVX1-NEXT:    vpmovzxbd {{.*#+}} xmm2 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero
-; AVX1-NEXT:    vpmovzxbd {{.*#+}} xmm3 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero
-; AVX1-NEXT:    vpmovzxbd {{.*#+}} xmm4 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero
-; AVX1-NEXT:    vpsubd %xmm4, %xmm0, %xmm0
-; AVX1-NEXT:    vpmovzxbd {{.*#+}} xmm4 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero
-; AVX1-NEXT:    vpsubd %xmm4, %xmm1, %xmm1
-; AVX1-NEXT:    vpmovzxbd {{.*#+}} xmm4 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero
-; AVX1-NEXT:    vpsubd %xmm4, %xmm2, %xmm2
-; AVX1-NEXT:    vpmovzxbd {{.*#+}} xmm4 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero
-; AVX1-NEXT:    vpsubd %xmm4, %xmm3, %xmm3
-; AVX1-NEXT:    vpabsd %xmm0, %xmm0
-; AVX1-NEXT:    vpabsd %xmm1, %xmm1
-; AVX1-NEXT:    vpabsd %xmm2, %xmm2
-; AVX1-NEXT:    vpaddd %xmm1, %xmm2, %xmm1
-; AVX1-NEXT:    vpabsd %xmm3, %xmm2
-; AVX1-NEXT:    vpaddd %xmm0, %xmm2, %xmm0
-; AVX1-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; AVX1-NEXT:    vmovdqu (%rdi), %xmm0
 ; AVX1-NEXT:    vmovdqu (%rdx), %xmm1
+; AVX1-NEXT:    vpsadbw (%rsi), %xmm0, %xmm0
 ; AVX1-NEXT:    vpsadbw (%rcx), %xmm1, %xmm1
 ; AVX1-NEXT:    vpaddd %xmm0, %xmm1, %xmm0
 ; AVX1-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
@@ -1597,16 +1539,9 @@ define i32 @sad_double_reduction(<16 x i
 ;
 ; AVX2-LABEL: sad_double_reduction:
 ; AVX2:       # %bb.0: # %bb
-; AVX2-NEXT:    vpmovzxbd {{.*#+}} ymm0 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero,mem[4],zero,zero,zero,mem[5],zero,zero,zero,mem[6],zero,zero,zero,mem[7],zero,zero,zero
-; AVX2-NEXT:    vpmovzxbd {{.*#+}} ymm1 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero,mem[4],zero,zero,zero,mem[5],zero,zero,zero,mem[6],zero,zero,zero,mem[7],zero,zero,zero
-; AVX2-NEXT:    vpmovzxbd {{.*#+}} ymm2 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero,mem[4],zero,zero,zero,mem[5],zero,zero,zero,mem[6],zero,zero,zero,mem[7],zero,zero,zero
-; AVX2-NEXT:    vpsubd %ymm2, %ymm0, %ymm0
-; AVX2-NEXT:    vpmovzxbd {{.*#+}} ymm2 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero,mem[4],zero,zero,zero,mem[5],zero,zero,zero,mem[6],zero,zero,zero,mem[7],zero,zero,zero
-; AVX2-NEXT:    vpsubd %ymm2, %ymm1, %ymm1
-; AVX2-NEXT:    vpabsd %ymm0, %ymm0
-; AVX2-NEXT:    vpabsd %ymm1, %ymm1
-; AVX2-NEXT:    vpaddd %ymm0, %ymm1, %ymm0
+; AVX2-NEXT:    vmovdqu (%rdi), %xmm0
 ; AVX2-NEXT:    vmovdqu (%rdx), %xmm1
+; AVX2-NEXT:    vpsadbw (%rsi), %xmm0, %xmm0
 ; AVX2-NEXT:    vpsadbw (%rcx), %xmm1, %xmm1
 ; AVX2-NEXT:    vpaddd %ymm0, %ymm1, %ymm0
 ; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm1
@@ -1620,11 +1555,9 @@ define i32 @sad_double_reduction(<16 x i
 ;
 ; AVX512-LABEL: sad_double_reduction:
 ; AVX512:       # %bb.0: # %bb
-; AVX512-NEXT:    vpmovzxbd {{.*#+}} zmm0 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero,mem[4],zero,zero,zero,mem[5],zero,zero,zero,mem[6],zero,zero,zero,mem[7],zero,zero,zero,mem[8],zero,zero,zero,mem[9],zero,zero,zero,mem[10],zero,zero,zero,mem[11],zero,zero,zero,mem[12],zero,zero,zero,mem[13],zero,zero,zero,mem[14],zero,zero,zero,mem[15],zero,zero,zero
-; AVX512-NEXT:    vpmovzxbd {{.*#+}} zmm1 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero,mem[4],zero,zero,zero,mem[5],zero,zero,zero,mem[6],zero,zero,zero,mem[7],zero,zero,zero,mem[8],zero,zero,zero,mem[9],zero,zero,zero,mem[10],zero,zero,zero,mem[11],zero,zero,zero,mem[12],zero,zero,zero,mem[13],zero,zero,zero,mem[14],zero,zero,zero,mem[15],zero,zero,zero
-; AVX512-NEXT:    vpsubd %zmm1, %zmm0, %zmm0
-; AVX512-NEXT:    vpabsd %zmm0, %zmm0
+; AVX512-NEXT:    vmovdqu (%rdi), %xmm0
 ; AVX512-NEXT:    vmovdqu (%rdx), %xmm1
+; AVX512-NEXT:    vpsadbw (%rsi), %xmm0, %xmm0
 ; AVX512-NEXT:    vpsadbw (%rcx), %xmm1, %xmm1
 ; AVX512-NEXT:    vpaddd %zmm0, %zmm1, %zmm0
 ; AVX512-NEXT:    vextracti64x4 $1, %zmm0, %ymm1



