[llvm] r271033 - [X86] Detect SAD patterns and emit psadbw instructions.

Michael Kuperstein via llvm-commits llvm-commits at lists.llvm.org
Fri May 27 11:53:22 PDT 2016


Author: mkuper
Date: Fri May 27 13:53:22 2016
New Revision: 271033

URL: http://llvm.org/viewvc/llvm-project?rev=271033&view=rev
Log:
[X86] Detect SAD patterns and emit psadbw instructions.

This recommits r267649 with a fix for PR27539.

Differential Revision: http://reviews.llvm.org/D20598

Added:
    llvm/trunk/test/CodeGen/X86/sad.ll
      - copied, changed from r267722, llvm/trunk/test/CodeGen/X86/sad.ll
Modified:
    llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/trunk/lib/Target/X86/X86ISelLowering.cpp

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp?rev=271033&r1=271032&r2=271033&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp Fri May 27 13:53:22 2016
@@ -13773,26 +13773,30 @@ SDValue DAGCombiner::visitSCALAR_TO_VECT
 
 SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) {
   SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
   SDValue N2 = N->getOperand(2);
 
+  if (N0.getValueType() != N1.getValueType())
+    return SDValue();
+
   // If the input vector is a concatenation, and the insert replaces
   // one of the halves, we can optimize into a single concat_vectors.
-  if (N0.getOpcode() == ISD::CONCAT_VECTORS &&
-      N0->getNumOperands() == 2 && N2.getOpcode() == ISD::Constant) {
+  if (N0.getOpcode() == ISD::CONCAT_VECTORS && N0->getNumOperands() == 2 &&
+      N2.getOpcode() == ISD::Constant) {
     APInt InsIdx = cast<ConstantSDNode>(N2)->getAPIntValue();
     EVT VT = N->getValueType(0);
 
     // Lower half: fold (insert_subvector (concat_vectors X, Y), Z) ->
     // (concat_vectors Z, Y)
     if (InsIdx == 0)
-      return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT,
-                         N->getOperand(1), N0.getOperand(1));
+      return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, N1,
+                         N0.getOperand(1));
 
     // Upper half: fold (insert_subvector (concat_vectors X, Y), Z) ->
     // (concat_vectors X, Z)
-    if (InsIdx == VT.getVectorNumElements()/2)
-      return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT,
-                         N0.getOperand(0), N->getOperand(1));
+    if (InsIdx == VT.getVectorNumElements() / 2)
+      return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, N0.getOperand(0),
+                         N1);
   }
 
   return SDValue();

Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=271033&r1=271032&r2=271033&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Fri May 27 13:53:22 2016
@@ -29473,8 +29473,148 @@ static SDValue OptimizeConditionalInDecr
                      DAG.getConstant(0, DL, OtherVal.getValueType()), NewCmp);
 }
 
+static SDValue detectSADPattern(SDNode *N, SelectionDAG &DAG,
+                                const X86Subtarget &Subtarget) {
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+
+  if (!VT.isVector() || !VT.isSimple() ||
+      !(VT.getVectorElementType() == MVT::i32))
+    return SDValue();
+
+  unsigned RegSize = 128;
+  if (Subtarget.hasBWI())
+    RegSize = 512;
+  else if (Subtarget.hasAVX2())
+    RegSize = 256;
+
+  // Handle at most v16i32 for SSE2 / v32i32 for AVX2 / v64i32 for AVX512BW.
+  if (VT.getSizeInBits() / 4 > RegSize)
+    return SDValue();
+
+  // Detect the following pattern:
+  //
+  // 1:    %2 = zext <N x i8> %0 to <N x i32>
+  // 2:    %3 = zext <N x i8> %1 to <N x i32>
+  // 3:    %4 = sub nsw <N x i32> %2, %3
+  // 4:    %5 = icmp sgt <N x i32> %4, [0 x N] or [-1 x N]
+  // 5:    %6 = sub nsw <N x i32> zeroinitializer, %4
+  // 6:    %7 = select <N x i1> %5, <N x i32> %4, <N x i32> %6
+  // 7:    %8 = add nsw <N x i32> %7, %vec.phi
+  //
+  // The last instruction must be a reduction add. Instructions 3-6 form an
+  // ABSDIFF pattern.
+
+  // The two operands of reduction add are from PHI and a select-op as in line 7
+  // above.
+  SDValue SelectOp, Phi;
+  if (Op0.getOpcode() == ISD::VSELECT) {
+    SelectOp = Op0;
+    Phi = Op1;
+  } else if (Op1.getOpcode() == ISD::VSELECT) {
+    SelectOp = Op1;
+    Phi = Op0;
+  } else
+    return SDValue();
+
+  // Check that the condition of the select instruction is greater-than.
+  SDValue SetCC = SelectOp->getOperand(0);
+  if (SetCC.getOpcode() != ISD::SETCC)
+    return SDValue();
+  ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get();
+  if (CC != ISD::SETGT)
+    return SDValue();
+
+  Op0 = SelectOp->getOperand(1);
+  Op1 = SelectOp->getOperand(2);
+
+  // The false value of SelectOp (Op1) must be the negation of its true value
+  // (Op0), which is implemented as 0 - Op0.
+  if (!(Op1.getOpcode() == ISD::SUB &&
+        ISD::isBuildVectorAllZeros(Op1.getOperand(0).getNode()) &&
+        Op1.getOperand(1) == Op0))
+    return SDValue();
+
+  // The first operand of SetCC is the first operand of SelectOp, which is the
+  // difference between two input vectors.
+  if (SetCC.getOperand(0) != Op0)
+    return SDValue();
+
+  // The second operand of > comparison can be either -1 or 0.
+  if (!(ISD::isBuildVectorAllZeros(SetCC.getOperand(1).getNode()) ||
+        ISD::isBuildVectorAllOnes(SetCC.getOperand(1).getNode())))
+    return SDValue();
+
+  // The first operand of SelectOp is the difference between two input vectors.
+  if (Op0.getOpcode() != ISD::SUB)
+    return SDValue();
+
+  Op1 = Op0.getOperand(1);
+  Op0 = Op0.getOperand(0);
+
+  // Check if the operands of the diff are zero-extended from vectors of i8.
+  if (Op0.getOpcode() != ISD::ZERO_EXTEND ||
+      Op0.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
+      Op1.getOpcode() != ISD::ZERO_EXTEND ||
+      Op1.getOperand(0).getValueType().getVectorElementType() != MVT::i8)
+    return SDValue();
+
+  // SAD pattern detected. Now build a SAD instruction and an addition for
+  // reduction. Note that the number of elements of the result of SAD is less
+  // than the number of elements of its input. Therefore, we can only update
+  // part of the elements in the reduction vector.
+
+  // Legalize the type of the inputs of PSADBW.
+  EVT InVT = Op0.getOperand(0).getValueType();
+  if (InVT.getSizeInBits() <= 128)
+    RegSize = 128;
+  else if (InVT.getSizeInBits() <= 256)
+    RegSize = 256;
+
+  unsigned NumConcat = RegSize / InVT.getSizeInBits();
+  SmallVector<SDValue, 16> Ops(NumConcat, DAG.getConstant(0, DL, InVT));
+  Ops[0] = Op0.getOperand(0);
+  MVT ExtendedVT = MVT::getVectorVT(MVT::i8, RegSize / 8);
+  Op0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
+  Ops[0] = Op1.getOperand(0);
+  Op1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
+
+  // The output of PSADBW is a vector of i64.
+  MVT SadVT = MVT::getVectorVT(MVT::i64, RegSize / 64);
+  SDValue Sad = DAG.getNode(X86ISD::PSADBW, DL, SadVT, Op0, Op1);
+
+  // We need to turn the vector of i64 into a vector of i32.
+  // If the reduction vector is at least as wide as the psadbw result, just
+  // bitcast. If it's narrower, truncate - the high i32 of each i64 is zero
+  // anyway.
+  MVT ResVT = MVT::getVectorVT(MVT::i32, RegSize / 32);  
+  if (VT.getSizeInBits() >= ResVT.getSizeInBits())
+    Sad = DAG.getNode(ISD::BITCAST, DL, ResVT, Sad);
+  else 
+    Sad = DAG.getNode(ISD::TRUNCATE, DL, VT, Sad);
+
+  if (VT.getSizeInBits() > ResVT.getSizeInBits()) {
+    // Update only some elements of the reduction vector. This is done by first
+    // extracting a sub-vector from it, updating this sub-vector, and inserting
+    // it back.
+    SDValue SubPhi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ResVT, Phi,
+                                 DAG.getIntPtrConstant(0, DL));
+    SDValue Res = DAG.getNode(ISD::ADD, DL, ResVT, Sad, SubPhi);
+    return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Phi, Res,
+                       DAG.getIntPtrConstant(0, DL));
+  } else
+    return DAG.getNode(ISD::ADD, DL, VT, Sad, Phi);
+}
+
 static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
                           const X86Subtarget &Subtarget) {
+  const SDNodeFlags *Flags = &cast<BinaryWithFlagsSDNode>(N)->Flags;
+  if (Flags->hasVectorReduction()) {
+    if (SDValue Sad = detectSADPattern(N, DAG, Subtarget))
+      return Sad;
+  }
   EVT VT = N->getValueType(0);
   SDValue Op0 = N->getOperand(0);
   SDValue Op1 = N->getOperand(1);

Copied: llvm/trunk/test/CodeGen/X86/sad.ll (from r267722, llvm/trunk/test/CodeGen/X86/sad.ll)
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/sad.ll?p2=llvm/trunk/test/CodeGen/X86/sad.ll&p1=llvm/trunk/test/CodeGen/X86/sad.ll&r1=267722&r2=271033&rev=271033&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/sad.ll (original)
+++ llvm/trunk/test/CodeGen/X86/sad.ll Fri May 27 13:53:22 2016
@@ -971,3 +971,124 @@ middle.block:
   %12 = extractelement <64 x i32> %bin.rdx6, i32 0
   ret i32 %12
 }
+
+define i32 @sad_2i8() nounwind {
+; SSE2-LABEL: sad_2i8:
+; SSE2:       # BB#0: # %entry
+; SSE2-NEXT:    pxor %xmm0, %xmm0
+; SSE2-NEXT:    movq $-1024, %rax # imm = 0xFFFFFFFFFFFFFC00
+; SSE2-NEXT:    movl $65535, %ecx # imm = 0xFFFF
+; SSE2-NEXT:    movd %ecx, %xmm1
+; SSE2-NEXT:    .p2align 4, 0x90
+; SSE2-NEXT:  .LBB3_1: # %vector.body
+; SSE2-NEXT:    # =>This Inner Loop Header: Depth=1
+; SSE2-NEXT:    movd {{.*#+}} xmm2 = mem[0],zero,zero,zero
+; SSE2-NEXT:    movd {{.*#+}} xmm3 = mem[0],zero,zero,zero
+; SSE2-NEXT:    pand %xmm1, %xmm3
+; SSE2-NEXT:    pand %xmm1, %xmm2
+; SSE2-NEXT:    psadbw %xmm3, %xmm2
+; SSE2-NEXT:    paddq %xmm2, %xmm0
+; SSE2-NEXT:    addq $4, %rax
+; SSE2-NEXT:    jne .LBB3_1
+; SSE2-NEXT:  # BB#2: # %middle.block
+; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
+; SSE2-NEXT:    paddq %xmm0, %xmm1
+; SSE2-NEXT:    movd %xmm1, %eax
+; SSE2-NEXT:    retq
+;
+; AVX2-LABEL: sad_2i8:
+; AVX2:       # BB#0: # %entry
+; AVX2-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; AVX2-NEXT:    movq $-1024, %rax # imm = 0xFFFFFFFFFFFFFC00
+; AVX2-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; AVX2-NEXT:    .p2align 4, 0x90
+; AVX2-NEXT:  .LBB3_1: # %vector.body
+; AVX2-NEXT:    # =>This Inner Loop Header: Depth=1
+; AVX2-NEXT:    vmovd {{.*#+}} xmm2 = mem[0],zero,zero,zero
+; AVX2-NEXT:    vmovd {{.*#+}} xmm3 = mem[0],zero,zero,zero
+; AVX2-NEXT:    vpblendw {{.*#+}} xmm3 = xmm3[0],xmm0[1,2,3,4,5,6,7]
+; AVX2-NEXT:    vpblendw {{.*#+}} xmm2 = xmm2[0],xmm0[1,2,3,4,5,6,7]
+; AVX2-NEXT:    vpsadbw %xmm3, %xmm2, %xmm2
+; AVX2-NEXT:    vpaddq %xmm1, %xmm2, %xmm1
+; AVX2-NEXT:    addq $4, %rax
+; AVX2-NEXT:    jne .LBB3_1
+; AVX2-NEXT:  # BB#2: # %middle.block
+; AVX2-NEXT:    vpshufd {{.*#+}} xmm0 = xmm1[2,3,0,1]
+; AVX2-NEXT:    vpaddq %xmm0, %xmm1, %xmm0
+; AVX2-NEXT:    vmovd %xmm0, %eax
+; AVX2-NEXT:    retq
+;
+; AVX512F-LABEL: sad_2i8:
+; AVX512F:       # BB#0: # %entry
+; AVX512F-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; AVX512F-NEXT:    movq $-1024, %rax # imm = 0xFFFFFFFFFFFFFC00
+; AVX512F-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; AVX512F-NEXT:    .p2align 4, 0x90
+; AVX512F-NEXT:  .LBB3_1: # %vector.body
+; AVX512F-NEXT:    # =>This Inner Loop Header: Depth=1
+; AVX512F-NEXT:    vmovd {{.*#+}} xmm2 = mem[0],zero,zero,zero
+; AVX512F-NEXT:    vmovd {{.*#+}} xmm3 = mem[0],zero,zero,zero
+; AVX512F-NEXT:    vpblendw {{.*#+}} xmm3 = xmm3[0],xmm0[1,2,3,4,5,6,7]
+; AVX512F-NEXT:    vpblendw {{.*#+}} xmm2 = xmm2[0],xmm0[1,2,3,4,5,6,7]
+; AVX512F-NEXT:    vpsadbw %xmm3, %xmm2, %xmm2
+; AVX512F-NEXT:    vpaddq %xmm1, %xmm2, %xmm1
+; AVX512F-NEXT:    addq $4, %rax
+; AVX512F-NEXT:    jne .LBB3_1
+; AVX512F-NEXT:  # BB#2: # %middle.block
+; AVX512F-NEXT:    vpshufd {{.*#+}} xmm0 = xmm1[2,3,0,1]
+; AVX512F-NEXT:    vpaddq %xmm0, %xmm1, %xmm0
+; AVX512F-NEXT:    vmovd %xmm0, %eax
+; AVX512F-NEXT:    retq
+;
+; AVX512BW-LABEL: sad_2i8:
+; AVX512BW:       # BB#0: # %entry
+; AVX512BW-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; AVX512BW-NEXT:    movq $-1024, %rax # imm = 0xFFFFFFFFFFFFFC00
+; AVX512BW-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; AVX512BW-NEXT:    .p2align 4, 0x90
+; AVX512BW-NEXT:  .LBB3_1: # %vector.body
+; AVX512BW-NEXT:    # =>This Inner Loop Header: Depth=1
+; AVX512BW-NEXT:    vmovd {{.*#+}} xmm2 = mem[0],zero,zero,zero
+; AVX512BW-NEXT:    vmovd {{.*#+}} xmm3 = mem[0],zero,zero,zero
+; AVX512BW-NEXT:    vpblendw {{.*#+}} xmm3 = xmm3[0],xmm0[1,2,3,4,5,6,7]
+; AVX512BW-NEXT:    vpblendw {{.*#+}} xmm2 = xmm2[0],xmm0[1,2,3,4,5,6,7]
+; AVX512BW-NEXT:    vpsadbw %xmm3, %xmm2, %xmm2
+; AVX512BW-NEXT:    vpaddq %xmm1, %xmm2, %xmm1
+; AVX512BW-NEXT:    addq $4, %rax
+; AVX512BW-NEXT:    jne .LBB3_1
+; AVX512BW-NEXT:  # BB#2: # %middle.block
+; AVX512BW-NEXT:    vpshufd {{.*#+}} xmm0 = xmm1[2,3,0,1]
+; AVX512BW-NEXT:    vpaddq %xmm0, %xmm1, %xmm0
+; AVX512BW-NEXT:    vmovd %xmm0, %eax
+; AVX512BW-NEXT:    retq
+entry:
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %vec.phi = phi <2 x i32> [ zeroinitializer, %entry ], [ %10, %vector.body ]
+  %0 = getelementptr inbounds [1024 x i8], [1024 x i8]* @a, i64 0, i64 %index
+  %1 = bitcast i8* %0 to <2 x i8>*
+  %wide.load = load <2 x i8>, <2 x i8>* %1, align 4
+  %2 = zext <2 x i8> %wide.load to <2 x i32>
+  %3 = getelementptr inbounds [1024 x i8], [1024 x i8]* @b, i64 0, i64 %index
+  %4 = bitcast i8* %3 to <2 x i8>*
+  %wide.load1 = load <2 x i8>, <2 x i8>* %4, align 4
+  %5 = zext <2 x i8> %wide.load1 to <2 x i32>
+  %6 = sub nsw <2 x i32> %2, %5
+  %7 = icmp sgt <2 x i32> %6, <i32 -1, i32 -1>
+  %8 = sub nsw <2 x i32> zeroinitializer, %6
+  %9 = select <2 x i1> %7, <2 x i32> %6, <2 x i32> %8
+  %10 = add nsw <2 x i32> %9, %vec.phi
+  %index.next = add i64 %index, 4
+  %11 = icmp eq i64 %index.next, 1024
+  br i1 %11, label %middle.block, label %vector.body
+
+middle.block:
+  %.lcssa = phi <2 x i32> [ %10, %vector.body ]
+  %rdx.shuf = shufflevector <2 x i32> %.lcssa, <2 x i32> undef, <2 x i32> <i32 1, i32 undef>
+  %bin.rdx = add <2 x i32> %.lcssa, %rdx.shuf
+  %12 = extractelement <2 x i32> %bin.rdx, i32 0
+  ret i32 %12
+}
+




More information about the llvm-commits mailing list