[llvm] r367195 - [X86] In combineLoopMAddPattern and combineLoopSADPattern, preserve the vector reduction flag on the final add. Handle unrolled loops by letting DAG combine revisit.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Sun Jul 28 11:45:43 PDT 2019


Author: ctopper
Date: Sun Jul 28 11:45:42 2019
New Revision: 367195

URL: http://llvm.org/viewvc/llvm-project?rev=367195&view=rev
Log:
[X86] In combineLoopMAddPattern and combineLoopSADPattern, preserve the vector reduction flag on the final add. Handle unrolled loops by letting DAG combine revisit.

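This reverts r340478 and r340631 and replaces them with a simpler
approach: each combine now transforms only one operand of the reduction
add and keeps the vector reduction flag on the resulting ADD, so that DAG
combine can revisit that node and handle the other operand.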

Modified:
    llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
    llvm/trunk/test/CodeGen/X86/madd.ll
    llvm/trunk/test/CodeGen/X86/sad.ll

Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=367195&r1=367194&r2=367195&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Sun Jul 28 11:45:42 2019
@@ -43151,8 +43151,17 @@ static SDValue combineLoopMAddPattern(SD
   if (!Subtarget.hasSSE2())
     return SDValue();
 
-  SDValue Op0 = N->getOperand(0);
-  SDValue Op1 = N->getOperand(1);
+  SDValue MulOp = N->getOperand(0);
+  SDValue OtherOp = N->getOperand(1);
+
+  if (MulOp.getOpcode() != ISD::MUL)
+    std::swap(MulOp, OtherOp);
+  if (MulOp.getOpcode() != ISD::MUL)
+    return SDValue();
+
+  ShrinkMode Mode;
+  if (!canReduceVMulWidth(MulOp.getNode(), DAG, Mode) || Mode == MULU16)
+    return SDValue();
 
   EVT VT = N->getValueType(0);
 
@@ -43161,49 +43170,33 @@ static SDValue combineLoopMAddPattern(SD
   if (!VT.isVector() || VT.getVectorNumElements() < 8)
     return SDValue();
 
-  if (Op0.getOpcode() != ISD::MUL)
-    std::swap(Op0, Op1);
-  if (Op0.getOpcode() != ISD::MUL)
-    return SDValue();
-
-  ShrinkMode Mode;
-  if (!canReduceVMulWidth(Op0.getNode(), DAG, Mode) || Mode == MULU16)
-    return SDValue();
-
   SDLoc DL(N);
   EVT ReducedVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16,
                                    VT.getVectorNumElements());
   EVT MAddVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32,
                                 VT.getVectorNumElements() / 2);
 
+  // Shrink the operands of mul.
+  SDValue N0 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, MulOp->getOperand(0));
+  SDValue N1 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, MulOp->getOperand(1));
+
   // Madd vector size is half of the original vector size
   auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
                            ArrayRef<SDValue> Ops) {
     MVT OpVT = MVT::getVectorVT(MVT::i32, Ops[0].getValueSizeInBits() / 32);
     return DAG.getNode(X86ISD::VPMADDWD, DL, OpVT, Ops);
   };
-
-  auto BuildPMADDWD = [&](SDValue Mul) {
-    // Shrink the operands of mul.
-    SDValue N0 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, Mul.getOperand(0));
-    SDValue N1 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, Mul.getOperand(1));
-
-    SDValue Madd = SplitOpsAndApply(DAG, Subtarget, DL, MAddVT, { N0, N1 },
-                                   PMADDWDBuilder);
-    // Fill the rest of the output with 0
-    return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Madd,
-                       DAG.getConstant(0, DL, MAddVT));
-  };
-
-  Op0 = BuildPMADDWD(Op0);
-
-  // It's possible that Op1 is also a mul we can reduce.
-  if (Op1.getOpcode() == ISD::MUL &&
-      canReduceVMulWidth(Op1.getNode(), DAG, Mode) && Mode != MULU16) {
-    Op1 = BuildPMADDWD(Op1);
-  }
-
-  return DAG.getNode(ISD::ADD, DL, VT, Op0, Op1);
+  SDValue Madd = SplitOpsAndApply(DAG, Subtarget, DL, MAddVT, { N0, N1 },
+                                  PMADDWDBuilder);
+  // Fill the rest of the output with 0
+  SDValue Zero = DAG.getConstant(0, DL, Madd.getSimpleValueType());
+  SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Madd, Zero);
+
+  // Preserve the reduction flag on the ADD. We may need to revisit for the
+  // other operand.
+  SDNodeFlags Flags;
+  Flags.setVectorReduction(true);
+  return DAG.getNode(ISD::ADD, DL, VT, Concat, OtherOp, Flags);
 }
 
 static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG,
@@ -43213,8 +43206,6 @@ static SDValue combineLoopSADPattern(SDN
 
   SDLoc DL(N);
   EVT VT = N->getValueType(0);
-  SDValue Op0 = N->getOperand(0);
-  SDValue Op1 = N->getOperand(1);
 
   // TODO: There's nothing special about i32, any integer type above i16 should
   // work just as well.
@@ -43234,55 +43225,49 @@ static SDValue combineLoopSADPattern(SDN
   if (VT.getSizeInBits() / 4 > RegSize)
     return SDValue();
 
-  // We know N is a reduction add, which means one of its operands is a phi.
-  // To match SAD, we need the other operand to be a ABS.
-  if (Op0.getOpcode() != ISD::ABS)
-    std::swap(Op0, Op1);
-  if (Op0.getOpcode() != ISD::ABS)
-    return SDValue();
-
-  auto BuildPSADBW = [&](SDValue Op0, SDValue Op1) {
-    // SAD pattern detected. Now build a SAD instruction and an addition for
-    // reduction. Note that the number of elements of the result of SAD is less
-    // than the number of elements of its input. Therefore, we could only update
-    // part of elements in the reduction vector.
-    SDValue Sad = createPSADBW(DAG, Op0, Op1, DL, Subtarget);
-
-    // The output of PSADBW is a vector of i64.
-    // We need to turn the vector of i64 into a vector of i32.
-    // If the reduction vector is at least as wide as the psadbw result, just
-    // bitcast. If it's narrower, truncate - the high i32 of each i64 is zero
-    // anyway.
-    MVT ResVT = MVT::getVectorVT(MVT::i32, Sad.getValueSizeInBits() / 32);
-    if (VT.getSizeInBits() >= ResVT.getSizeInBits())
-      Sad = DAG.getNode(ISD::BITCAST, DL, ResVT, Sad);
-    else
-      Sad = DAG.getNode(ISD::TRUNCATE, DL, VT, Sad);
-
-    if (VT.getSizeInBits() > ResVT.getSizeInBits()) {
-      // Fill the upper elements with zero to match the add width.
-      SDValue Zero = DAG.getConstant(0, DL, VT);
-      Sad = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Zero, Sad,
-                        DAG.getIntPtrConstant(0, DL));
-    }
-
-    return Sad;
-  };
+  // We know N is a reduction add. To match SAD, we need one of the operands to
+  // be an ABS.
+  SDValue AbsOp = N->getOperand(0);
+  SDValue OtherOp = N->getOperand(1);
+  if (AbsOp.getOpcode() != ISD::ABS)
+    std::swap(AbsOp, OtherOp);
+  if (AbsOp.getOpcode() != ISD::ABS)
+    return SDValue();
 
   // Check whether we have an abs-diff pattern feeding into the select.
   SDValue SadOp0, SadOp1;
-  if (!detectZextAbsDiff(Op0, SadOp0, SadOp1))
+  if(!detectZextAbsDiff(AbsOp, SadOp0, SadOp1))
     return SDValue();
 
-  Op0 = BuildPSADBW(SadOp0, SadOp1);
-
-  // It's possible we have a sad on the other side too.
-  if (Op1.getOpcode() == ISD::ABS &&
-      detectZextAbsDiff(Op1, SadOp0, SadOp1)) {
-    Op1 = BuildPSADBW(SadOp0, SadOp1);
+  // SAD pattern detected. Now build a SAD instruction and an addition for
+  // reduction. Note that the number of elements of the result of SAD is less
+  // than the number of elements of its input. Therefore, we could only update
+  // part of elements in the reduction vector.
+  SDValue Sad = createPSADBW(DAG, SadOp0, SadOp1, DL, Subtarget);
+
+  // The output of PSADBW is a vector of i64.
+  // We need to turn the vector of i64 into a vector of i32.
+  // If the reduction vector is at least as wide as the psadbw result, just
+  // bitcast. If it's narrower, truncate - the high i32 of each i64 is zero
+  // anyway.
+  MVT ResVT = MVT::getVectorVT(MVT::i32, Sad.getValueSizeInBits() / 32);
+  if (VT.getSizeInBits() >= ResVT.getSizeInBits())
+    Sad = DAG.getNode(ISD::BITCAST, DL, ResVT, Sad);
+  else
+    Sad = DAG.getNode(ISD::TRUNCATE, DL, VT, Sad);
+
+  if (VT.getSizeInBits() > ResVT.getSizeInBits()) {
+    // Fill the upper elements with zero to match the add width.
+    SDValue Zero = DAG.getConstant(0, DL, VT);
+    Sad = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Zero, Sad,
+                      DAG.getIntPtrConstant(0, DL));
   }
 
-  return DAG.getNode(ISD::ADD, DL, VT, Op0, Op1);
+  // Preserve the reduction flag on the ADD. We may need to revisit for the
+  // other operand.
+  SDNodeFlags Flags;
+  Flags.setVectorReduction(true);
+  return DAG.getNode(ISD::ADD, DL, VT, Sad, OtherOp, Flags);
 }
 
 /// Convert vector increment or decrement to sub/add with an all-ones constant:

Modified: llvm/trunk/test/CodeGen/X86/madd.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/madd.ll?rev=367195&r1=367194&r2=367195&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/madd.ll (original)
+++ llvm/trunk/test/CodeGen/X86/madd.ll Sun Jul 28 11:45:42 2019
@@ -2677,9 +2677,9 @@ define i32 @madd_double_reduction(<8 x i
 ; AVX:       # %bb.0:
 ; AVX-NEXT:    vmovdqu (%rdi), %xmm0
 ; AVX-NEXT:    vmovdqu (%rdx), %xmm1
-; AVX-NEXT:    vpmaddwd (%rsi), %xmm0, %xmm0
 ; AVX-NEXT:    vpmaddwd (%rcx), %xmm1, %xmm1
-; AVX-NEXT:    vpaddd %xmm0, %xmm1, %xmm0
+; AVX-NEXT:    vpmaddwd (%rsi), %xmm0, %xmm0
+; AVX-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
 ; AVX-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
 ; AVX-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
 ; AVX-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]

Modified: llvm/trunk/test/CodeGen/X86/sad.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/sad.ll?rev=367195&r1=367194&r2=367195&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/sad.ll (original)
+++ llvm/trunk/test/CodeGen/X86/sad.ll Sun Jul 28 11:45:42 2019
@@ -1403,18 +1403,18 @@ define i32 @sad_unroll_nonzero_initial(<
 ; SSE2-NEXT:    movdqu (%rdi), %xmm0
 ; SSE2-NEXT:    movdqu (%rsi), %xmm1
 ; SSE2-NEXT:    psadbw %xmm0, %xmm1
-; SSE2-NEXT:    movdqu (%rdx), %xmm0
-; SSE2-NEXT:    movdqu (%rcx), %xmm2
-; SSE2-NEXT:    psadbw %xmm0, %xmm2
 ; SSE2-NEXT:    movl $1, %eax
 ; SSE2-NEXT:    movd %eax, %xmm0
-; SSE2-NEXT:    paddd %xmm2, %xmm0
-; SSE2-NEXT:    paddd %xmm1, %xmm0
-; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
+; SSE2-NEXT:    movdqu (%rdx), %xmm2
+; SSE2-NEXT:    movdqu (%rcx), %xmm3
+; SSE2-NEXT:    psadbw %xmm2, %xmm3
+; SSE2-NEXT:    paddd %xmm0, %xmm3
+; SSE2-NEXT:    paddd %xmm1, %xmm3
+; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm3[2,3,0,1]
+; SSE2-NEXT:    paddd %xmm3, %xmm0
+; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
 ; SSE2-NEXT:    paddd %xmm0, %xmm1
-; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm1[1,1,2,3]
-; SSE2-NEXT:    paddd %xmm1, %xmm0
-; SSE2-NEXT:    movd %xmm0, %eax
+; SSE2-NEXT:    movd %xmm1, %eax
 ; SSE2-NEXT:    retq
 ;
 ; AVX1-LABEL: sad_unroll_nonzero_initial:
@@ -1442,7 +1442,7 @@ define i32 @sad_unroll_nonzero_initial(<
 ; AVX2-NEXT:    vmovd %eax, %xmm1
 ; AVX2-NEXT:    vmovdqu (%rdx), %xmm2
 ; AVX2-NEXT:    vpsadbw (%rcx), %xmm2, %xmm2
-; AVX2-NEXT:    vpaddd %xmm1, %xmm2, %xmm1
+; AVX2-NEXT:    vpaddd %xmm2, %xmm1, %xmm1
 ; AVX2-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
 ; AVX2-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
 ; AVX2-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
@@ -1455,10 +1455,10 @@ define i32 @sad_unroll_nonzero_initial(<
 ; AVX512:       # %bb.0: # %bb
 ; AVX512-NEXT:    vmovdqu (%rdi), %xmm0
 ; AVX512-NEXT:    vpsadbw (%rsi), %xmm0, %xmm0
-; AVX512-NEXT:    vmovdqu (%rdx), %xmm1
-; AVX512-NEXT:    vpsadbw (%rcx), %xmm1, %xmm1
 ; AVX512-NEXT:    movl $1, %eax
-; AVX512-NEXT:    vmovd %eax, %xmm2
+; AVX512-NEXT:    vmovd %eax, %xmm1
+; AVX512-NEXT:    vmovdqu (%rdx), %xmm2
+; AVX512-NEXT:    vpsadbw (%rcx), %xmm2, %xmm2
 ; AVX512-NEXT:    vpaddd %zmm2, %zmm1, %zmm1
 ; AVX512-NEXT:    vpaddd %zmm1, %zmm0, %zmm0
 ; AVX512-NEXT:    vextracti64x4 $1, %zmm0, %ymm1
@@ -1526,9 +1526,9 @@ define i32 @sad_double_reduction(<16 x i
 ; AVX:       # %bb.0: # %bb
 ; AVX-NEXT:    vmovdqu (%rdi), %xmm0
 ; AVX-NEXT:    vmovdqu (%rdx), %xmm1
-; AVX-NEXT:    vpsadbw (%rsi), %xmm0, %xmm0
 ; AVX-NEXT:    vpsadbw (%rcx), %xmm1, %xmm1
-; AVX-NEXT:    vpaddd %xmm0, %xmm1, %xmm0
+; AVX-NEXT:    vpsadbw (%rsi), %xmm0, %xmm0
+; AVX-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
 ; AVX-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
 ; AVX-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
 ; AVX-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]




More information about the llvm-commits mailing list