[llvm] r338097 - [X86] Add matching for another pattern of PMADDWD.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 26 21:29:11 PDT 2018


Author: ctopper
Date: Thu Jul 26 21:29:10 2018
New Revision: 338097

URL: http://llvm.org/viewvc/llvm-project?rev=338097&view=rev
Log:
[X86] Add matching for another pattern of PMADDWD.

Summary:
This is the pattern you get from the loop vectorizer for something like this:

int16_t A[1024];
int16_t B[1024];
int32_t C[512];

void pmaddwd() {
  for (int i = 0; i != 512; ++i)
    C[i] = (A[2*i]*B[2*i]) + (A[2*i+1]*B[2*i+1]);
}

In this case we will have (add (mul (build_vector), (build_vector)), (mul (build_vector), (build_vector))). This is different from the pattern we currently match, which has the build_vectors between an add and a single multiply. I'm not sure what C code would produce that pattern.
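
For reference, a minimal C sketch (not part of the patch) of what PMADDWD computes: each 32-bit result element is the sum of two adjacent sign-extended 16-bit products, which is exactly the per-element expression in the loop above. The helper names here are illustrative only.

#include <emmintrin.h>
#include <stdint.h>

/* Scalar reference: r[i] = a[2*i]*b[2*i] + a[2*i+1]*b[2*i+1], widened to i32. */
static void pmaddwd_ref(const int16_t a[8], const int16_t b[8], int32_t r[4]) {
  for (int i = 0; i != 4; ++i)
    r[i] = (int32_t)a[2 * i] * b[2 * i] + (int32_t)a[2 * i + 1] * b[2 * i + 1];
}

/* The SSE2 intrinsic that lowers to a single PMADDWD. */
static __m128i pmaddwd_sse2(__m128i a, __m128i b) {
  return _mm_madd_epi16(a, b);
}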

Reviewers: RKSimon, spatel, zvi

Reviewed By: zvi

Subscribers: llvm-commits

Differential Revision: https://reviews.llvm.org/D49636

Modified:
    llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
    llvm/trunk/test/CodeGen/X86/madd.ll

Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=338097&r1=338096&r2=338097&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Thu Jul 26 21:29:10 2018
@@ -38816,6 +38816,127 @@ static SDValue matchPMADDWD(SelectionDAG
                           PMADDBuilder);
 }
 
+// Attempt to turn this pattern into PMADDWD.
+// (add (mul (sext (build_vector)), (sext (build_vector))),
+//      (mul (sext (build_vector)), (sext (build_vector))))
+static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1,
+                              const SDLoc &DL, EVT VT,
+                              const X86Subtarget &Subtarget) {
+  if (!Subtarget.hasSSE2())
+    return SDValue();
+
+  if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL)
+    return SDValue();
+
+  if (!VT.isVector() || VT.getVectorElementType() != MVT::i32 ||
+      VT.getVectorNumElements() < 4 ||
+      !isPowerOf2_32(VT.getVectorNumElements()))
+    return SDValue();
+
+  SDValue N00 = N0.getOperand(0);
+  SDValue N01 = N0.getOperand(1);
+  SDValue N10 = N1.getOperand(0);
+  SDValue N11 = N1.getOperand(1);
+
+  // All inputs need to be sign extends.
+  // TODO: Support ZERO_EXTEND from known positive?
+  if (N00.getOpcode() != ISD::SIGN_EXTEND ||
+      N01.getOpcode() != ISD::SIGN_EXTEND ||
+      N10.getOpcode() != ISD::SIGN_EXTEND ||
+      N11.getOpcode() != ISD::SIGN_EXTEND)
+    return SDValue();
+
+  // Peek through the extends.
+  N00 = N00.getOperand(0);
+  N01 = N01.getOperand(0);
+  N10 = N10.getOperand(0);
+  N11 = N11.getOperand(0);
+
+  // Must be extending from vXi16.
+  EVT InVT = N00.getValueType();
+  if (InVT.getVectorElementType() != MVT::i16 || N01.getValueType() != InVT ||
+      N10.getValueType() != InVT || N11.getValueType() != InVT)
+    return SDValue();
+
+  // All inputs should be build_vectors.
+  if (N00.getOpcode() != ISD::BUILD_VECTOR ||
+      N01.getOpcode() != ISD::BUILD_VECTOR ||
+      N10.getOpcode() != ISD::BUILD_VECTOR ||
+      N11.getOpcode() != ISD::BUILD_VECTOR)
+    return SDValue();
+
+  // For each element, we need to ensure we have an odd element from one vector
+  // multiplied by the odd element of another vector and the even element from
+  // one of the same vectors being multiplied by the even element from the
+  // other vector. So we need to make sure that for each element i, this
+  // computation is performed:
+  //  A[2 * i] * B[2 * i] + A[2 * i + 1] * B[2 * i + 1]
+  SDValue In0, In1;
+  for (unsigned i = 0; i != N00.getNumOperands(); ++i) {
+    SDValue N00Elt = N00.getOperand(i);
+    SDValue N01Elt = N01.getOperand(i);
+    SDValue N10Elt = N10.getOperand(i);
+    SDValue N11Elt = N11.getOperand(i);
+    // TODO: Be more tolerant to undefs.
+    if (N00Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+        N01Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+        N10Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+        N11Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
+      return SDValue();
+    auto *ConstN00Elt = dyn_cast<ConstantSDNode>(N00Elt.getOperand(1));
+    auto *ConstN01Elt = dyn_cast<ConstantSDNode>(N01Elt.getOperand(1));
+    auto *ConstN10Elt = dyn_cast<ConstantSDNode>(N10Elt.getOperand(1));
+    auto *ConstN11Elt = dyn_cast<ConstantSDNode>(N11Elt.getOperand(1));
+    if (!ConstN00Elt || !ConstN01Elt || !ConstN10Elt || !ConstN11Elt)
+      return SDValue();
+    unsigned IdxN00 = ConstN00Elt->getZExtValue();
+    unsigned IdxN01 = ConstN01Elt->getZExtValue();
+    unsigned IdxN10 = ConstN10Elt->getZExtValue();
+    unsigned IdxN11 = ConstN11Elt->getZExtValue();
+    // Add is commutative so indices can be reordered.
+    if (IdxN00 > IdxN10) {
+      std::swap(IdxN00, IdxN10);
+      std::swap(IdxN01, IdxN11);
+    }
+    // N0 indices must be the even element. N1 indices must be the next odd element.
+    if (IdxN00 != 2 * i || IdxN10 != 2 * i + 1 ||
+        IdxN01 != 2 * i || IdxN11 != 2 * i + 1)
+      return SDValue();
+    SDValue N00In = N00Elt.getOperand(0);
+    SDValue N01In = N01Elt.getOperand(0);
+    SDValue N10In = N10Elt.getOperand(0);
+    SDValue N11In = N11Elt.getOperand(0);
+    // First time we find an input capture it.
+    if (!In0) {
+      In0 = N00In;
+      In1 = N01In;
+    }
+    // Mul is commutative so the input vectors can be in any order.
+    // Canonicalize to make the compares easier.
+    if (In0 != N00In)
+      std::swap(N00In, N01In);
+    if (In0 != N10In)
+      std::swap(N10In, N11In);
+    if (In0 != N00In || In1 != N01In || In0 != N10In || In1 != N11In)
+      return SDValue();
+  }
+
+  auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+                         ArrayRef<SDValue> Ops) {
+    // Shrink by adding truncate nodes and let DAGCombine fold with the
+    // sources.
+    EVT InVT = Ops[0].getValueType();
+    assert(InVT.getScalarType() == MVT::i16 &&
+           "Unexpected scalar element type");
+    assert(InVT == Ops[1].getValueType() && "Operands' types mismatch");
+    EVT ResVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32,
+                                 InVT.getVectorNumElements() / 2);
+    return DAG.getNode(X86ISD::VPMADDWD, DL, ResVT, Ops[0], Ops[1]);
+  };
+  return SplitOpsAndApply(DAG, Subtarget, DL, VT, { In0, In1 },
+                          PMADDBuilder);
+}
+
 static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
                           const X86Subtarget &Subtarget) {
   const SDNodeFlags Flags = N->getFlags();
@@ -38831,6 +38952,8 @@ static SDValue combineAdd(SDNode *N, Sel
 
   if (SDValue MAdd = matchPMADDWD(DAG, Op0, Op1, SDLoc(N), VT, Subtarget))
     return MAdd;
+  if (SDValue MAdd = matchPMADDWD_2(DAG, Op0, Op1, SDLoc(N), VT, Subtarget))
+    return MAdd;
 
   // Try to synthesize horizontal adds from adds of shuffles.
   if ((VT == MVT::v8i16 || VT == MVT::v4i32 || VT == MVT::v16i16 ||

Modified: llvm/trunk/test/CodeGen/X86/madd.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/madd.ll?rev=338097&r1=338096&r2=338097&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/madd.ll (original)
+++ llvm/trunk/test/CodeGen/X86/madd.ll Thu Jul 26 21:29:10 2018
@@ -2299,3 +2299,373 @@ define <32 x i32> @jumbled_indices32(<64
   %a = add <32 x i32> %sa, %sb
   ret <32 x i32> %a
 }
+
+; NOTE: We're testing with loads because ABI lowering creates a concat_vectors that extract_vector_elt creation can see through.
+; This would require the combine to recreate the concat_vectors.
+define <4 x i32> @pmaddwd_128(<8 x i16>* %Aptr, <8 x i16>* %Bptr) {
+; SSE2-LABEL: pmaddwd_128:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movdqa (%rdi), %xmm0
+; SSE2-NEXT:    pmaddwd (%rsi), %xmm0
+; SSE2-NEXT:    retq
+;
+; AVX-LABEL: pmaddwd_128:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vmovdqa (%rdi), %xmm0
+; AVX-NEXT:    vpmaddwd (%rsi), %xmm0, %xmm0
+; AVX-NEXT:    retq
+  %A = load <8 x i16>, <8 x i16>* %Aptr
+  %B = load <8 x i16>, <8 x i16>* %Bptr
+  %A_even = shufflevector <8 x i16> %A, <8 x i16> undef, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+  %A_odd = shufflevector <8 x i16> %A, <8 x i16> undef, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+  %B_even = shufflevector <8 x i16> %B, <8 x i16> undef, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+  %B_odd = shufflevector <8 x i16> %B, <8 x i16> undef, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+  %A_even_ext = sext <4 x i16> %A_even to <4 x i32>
+  %B_even_ext = sext <4 x i16> %B_even to <4 x i32>
+  %A_odd_ext = sext <4 x i16> %A_odd to <4 x i32>
+  %B_odd_ext = sext <4 x i16> %B_odd to <4 x i32>
+  %even_mul = mul <4 x i32> %A_even_ext, %B_even_ext
+  %odd_mul = mul <4 x i32> %A_odd_ext, %B_odd_ext
+  %add = add <4 x i32> %even_mul, %odd_mul
+  ret <4 x i32> %add
+}
+
+define <8 x i32> @pmaddwd_256(<16 x i16>* %Aptr, <16 x i16>* %Bptr) {
+; SSE2-LABEL: pmaddwd_256:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movdqa (%rdi), %xmm0
+; SSE2-NEXT:    movdqa 16(%rdi), %xmm1
+; SSE2-NEXT:    pmaddwd (%rsi), %xmm0
+; SSE2-NEXT:    pmaddwd 16(%rsi), %xmm1
+; SSE2-NEXT:    retq
+;
+; AVX1-LABEL: pmaddwd_256:
+; AVX1:       # %bb.0:
+; AVX1-NEXT:    vmovdqa (%rdi), %ymm0
+; AVX1-NEXT:    vmovdqa (%rsi), %ymm1
+; AVX1-NEXT:    vextractf128 $1, %ymm1, %xmm2
+; AVX1-NEXT:    vextractf128 $1, %ymm0, %xmm3
+; AVX1-NEXT:    vpmaddwd %xmm2, %xmm3, %xmm2
+; AVX1-NEXT:    vpmaddwd %xmm1, %xmm0, %xmm0
+; AVX1-NEXT:    vinsertf128 $1, %xmm2, %ymm0, %ymm0
+; AVX1-NEXT:    retq
+;
+; AVX256-LABEL: pmaddwd_256:
+; AVX256:       # %bb.0:
+; AVX256-NEXT:    vmovdqa (%rdi), %ymm0
+; AVX256-NEXT:    vpmaddwd (%rsi), %ymm0, %ymm0
+; AVX256-NEXT:    retq
+  %A = load <16 x i16>, <16 x i16>* %Aptr
+  %B = load <16 x i16>, <16 x i16>* %Bptr
+  %A_even = shufflevector <16 x i16> %A, <16 x i16> undef, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+  %A_odd = shufflevector <16 x i16> %A, <16 x i16> undef, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+  %B_even = shufflevector <16 x i16> %B, <16 x i16> undef, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+  %B_odd = shufflevector <16 x i16> %B, <16 x i16> undef, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+  %A_even_ext = sext <8 x i16> %A_even to <8 x i32>
+  %B_even_ext = sext <8 x i16> %B_even to <8 x i32>
+  %A_odd_ext = sext <8 x i16> %A_odd to <8 x i32>
+  %B_odd_ext = sext <8 x i16> %B_odd to <8 x i32>
+  %even_mul = mul <8 x i32> %A_even_ext, %B_even_ext
+  %odd_mul = mul <8 x i32> %A_odd_ext, %B_odd_ext
+  %add = add <8 x i32> %even_mul, %odd_mul
+  ret <8 x i32> %add
+}
+
+define <16 x i32> @pmaddwd_512(<32 x i16>* %Aptr, <32 x i16>* %Bptr) {
+; SSE2-LABEL: pmaddwd_512:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movdqa (%rdi), %xmm0
+; SSE2-NEXT:    movdqa 16(%rdi), %xmm1
+; SSE2-NEXT:    movdqa 32(%rdi), %xmm2
+; SSE2-NEXT:    movdqa 48(%rdi), %xmm3
+; SSE2-NEXT:    pmaddwd (%rsi), %xmm0
+; SSE2-NEXT:    pmaddwd 16(%rsi), %xmm1
+; SSE2-NEXT:    pmaddwd 32(%rsi), %xmm2
+; SSE2-NEXT:    pmaddwd 48(%rsi), %xmm3
+; SSE2-NEXT:    retq
+;
+; AVX1-LABEL: pmaddwd_512:
+; AVX1:       # %bb.0:
+; AVX1-NEXT:    vmovdqa (%rdi), %ymm0
+; AVX1-NEXT:    vmovdqa 32(%rdi), %ymm1
+; AVX1-NEXT:    vmovdqa (%rsi), %ymm2
+; AVX1-NEXT:    vmovdqa 32(%rsi), %ymm3
+; AVX1-NEXT:    vextractf128 $1, %ymm2, %xmm4
+; AVX1-NEXT:    vextractf128 $1, %ymm0, %xmm5
+; AVX1-NEXT:    vpmaddwd %xmm4, %xmm5, %xmm4
+; AVX1-NEXT:    vpmaddwd %xmm2, %xmm0, %xmm0
+; AVX1-NEXT:    vinsertf128 $1, %xmm4, %ymm0, %ymm0
+; AVX1-NEXT:    vextractf128 $1, %ymm3, %xmm2
+; AVX1-NEXT:    vextractf128 $1, %ymm1, %xmm4
+; AVX1-NEXT:    vpmaddwd %xmm2, %xmm4, %xmm2
+; AVX1-NEXT:    vpmaddwd %xmm3, %xmm1, %xmm1
+; AVX1-NEXT:    vinsertf128 $1, %xmm2, %ymm1, %ymm1
+; AVX1-NEXT:    retq
+;
+; AVX2-LABEL: pmaddwd_512:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vmovdqa (%rdi), %ymm0
+; AVX2-NEXT:    vmovdqa 32(%rdi), %ymm1
+; AVX2-NEXT:    vpmaddwd (%rsi), %ymm0, %ymm0
+; AVX2-NEXT:    vpmaddwd 32(%rsi), %ymm1, %ymm1
+; AVX2-NEXT:    retq
+;
+; AVX512F-LABEL: pmaddwd_512:
+; AVX512F:       # %bb.0:
+; AVX512F-NEXT:    vmovdqa (%rdi), %ymm0
+; AVX512F-NEXT:    vmovdqa 32(%rdi), %ymm1
+; AVX512F-NEXT:    vpmaddwd 32(%rsi), %ymm1, %ymm1
+; AVX512F-NEXT:    vpmaddwd (%rsi), %ymm0, %ymm0
+; AVX512F-NEXT:    vinserti64x4 $1, %ymm1, %zmm0, %zmm0
+; AVX512F-NEXT:    retq
+;
+; AVX512BW-LABEL: pmaddwd_512:
+; AVX512BW:       # %bb.0:
+; AVX512BW-NEXT:    vmovdqa64 (%rdi), %zmm0
+; AVX512BW-NEXT:    vpmaddwd (%rsi), %zmm0, %zmm0
+; AVX512BW-NEXT:    retq
+  %A = load <32 x i16>, <32 x i16>* %Aptr
+  %B = load <32 x i16>, <32 x i16>* %Bptr
+  %A_even = shufflevector <32 x i16> %A, <32 x i16> undef, <16 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14, i32 16, i32 18, i32 20, i32 22, i32 24, i32 26, i32 28, i32 30>
+  %A_odd = shufflevector <32 x i16> %A, <32 x i16> undef, <16 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15, i32 17, i32 19, i32 21, i32 23, i32 25, i32 27, i32 29, i32 31>
+  %B_even = shufflevector <32 x i16> %B, <32 x i16> undef, <16 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14, i32 16, i32 18, i32 20, i32 22, i32 24, i32 26, i32 28, i32 30>
+  %B_odd = shufflevector <32 x i16> %B, <32 x i16> undef, <16 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15, i32 17, i32 19, i32 21, i32 23, i32 25, i32 27, i32 29, i32 31>
+  %A_even_ext = sext <16 x i16> %A_even to <16 x i32>
+  %B_even_ext = sext <16 x i16> %B_even to <16 x i32>
+  %A_odd_ext = sext <16 x i16> %A_odd to <16 x i32>
+  %B_odd_ext = sext <16 x i16> %B_odd to <16 x i32>
+  %even_mul = mul <16 x i32> %A_even_ext, %B_even_ext
+  %odd_mul = mul <16 x i32> %A_odd_ext, %B_odd_ext
+  %add = add <16 x i32> %even_mul, %odd_mul
+  ret <16 x i32> %add
+}
+
+define <32 x i32> @pmaddwd_1024(<64 x i16>* %Aptr, <64 x i16>* %Bptr) {
+; SSE2-LABEL: pmaddwd_1024:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movdqa 112(%rsi), %xmm0
+; SSE2-NEXT:    movdqa 96(%rsi), %xmm1
+; SSE2-NEXT:    movdqa 80(%rsi), %xmm2
+; SSE2-NEXT:    movdqa 64(%rsi), %xmm3
+; SSE2-NEXT:    movdqa (%rsi), %xmm4
+; SSE2-NEXT:    movdqa 16(%rsi), %xmm5
+; SSE2-NEXT:    movdqa 32(%rsi), %xmm6
+; SSE2-NEXT:    movdqa 48(%rsi), %xmm7
+; SSE2-NEXT:    pmaddwd (%rdx), %xmm4
+; SSE2-NEXT:    pmaddwd 16(%rdx), %xmm5
+; SSE2-NEXT:    pmaddwd 32(%rdx), %xmm6
+; SSE2-NEXT:    pmaddwd 48(%rdx), %xmm7
+; SSE2-NEXT:    pmaddwd 64(%rdx), %xmm3
+; SSE2-NEXT:    pmaddwd 80(%rdx), %xmm2
+; SSE2-NEXT:    pmaddwd 96(%rdx), %xmm1
+; SSE2-NEXT:    pmaddwd 112(%rdx), %xmm0
+; SSE2-NEXT:    movdqa %xmm0, 112(%rdi)
+; SSE2-NEXT:    movdqa %xmm1, 96(%rdi)
+; SSE2-NEXT:    movdqa %xmm2, 80(%rdi)
+; SSE2-NEXT:    movdqa %xmm3, 64(%rdi)
+; SSE2-NEXT:    movdqa %xmm7, 48(%rdi)
+; SSE2-NEXT:    movdqa %xmm6, 32(%rdi)
+; SSE2-NEXT:    movdqa %xmm5, 16(%rdi)
+; SSE2-NEXT:    movdqa %xmm4, (%rdi)
+; SSE2-NEXT:    movq %rdi, %rax
+; SSE2-NEXT:    retq
+;
+; AVX1-LABEL: pmaddwd_1024:
+; AVX1:       # %bb.0:
+; AVX1-NEXT:    vmovdqa (%rdi), %ymm0
+; AVX1-NEXT:    vmovdqa 32(%rdi), %ymm1
+; AVX1-NEXT:    vmovdqa 64(%rdi), %ymm2
+; AVX1-NEXT:    vmovdqa 96(%rdi), %ymm8
+; AVX1-NEXT:    vmovdqa (%rsi), %ymm4
+; AVX1-NEXT:    vmovdqa 32(%rsi), %ymm5
+; AVX1-NEXT:    vmovdqa 64(%rsi), %ymm6
+; AVX1-NEXT:    vmovdqa 96(%rsi), %ymm9
+; AVX1-NEXT:    vextractf128 $1, %ymm4, %xmm3
+; AVX1-NEXT:    vextractf128 $1, %ymm0, %xmm7
+; AVX1-NEXT:    vpmaddwd %xmm3, %xmm7, %xmm3
+; AVX1-NEXT:    vpmaddwd %xmm4, %xmm0, %xmm0
+; AVX1-NEXT:    vinsertf128 $1, %xmm3, %ymm0, %ymm0
+; AVX1-NEXT:    vextractf128 $1, %ymm5, %xmm3
+; AVX1-NEXT:    vextractf128 $1, %ymm1, %xmm4
+; AVX1-NEXT:    vpmaddwd %xmm3, %xmm4, %xmm3
+; AVX1-NEXT:    vpmaddwd %xmm5, %xmm1, %xmm1
+; AVX1-NEXT:    vinsertf128 $1, %xmm3, %ymm1, %ymm1
+; AVX1-NEXT:    vextractf128 $1, %ymm6, %xmm3
+; AVX1-NEXT:    vextractf128 $1, %ymm2, %xmm4
+; AVX1-NEXT:    vpmaddwd %xmm3, %xmm4, %xmm3
+; AVX1-NEXT:    vpmaddwd %xmm6, %xmm2, %xmm2
+; AVX1-NEXT:    vinsertf128 $1, %xmm3, %ymm2, %ymm2
+; AVX1-NEXT:    vextractf128 $1, %ymm9, %xmm3
+; AVX1-NEXT:    vextractf128 $1, %ymm8, %xmm4
+; AVX1-NEXT:    vpmaddwd %xmm3, %xmm4, %xmm3
+; AVX1-NEXT:    vpmaddwd %xmm9, %xmm8, %xmm4
+; AVX1-NEXT:    vinsertf128 $1, %xmm3, %ymm4, %ymm3
+; AVX1-NEXT:    retq
+;
+; AVX2-LABEL: pmaddwd_1024:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vmovdqa (%rdi), %ymm0
+; AVX2-NEXT:    vmovdqa 32(%rdi), %ymm1
+; AVX2-NEXT:    vmovdqa 64(%rdi), %ymm2
+; AVX2-NEXT:    vmovdqa 96(%rdi), %ymm3
+; AVX2-NEXT:    vpmaddwd (%rsi), %ymm0, %ymm0
+; AVX2-NEXT:    vpmaddwd 32(%rsi), %ymm1, %ymm1
+; AVX2-NEXT:    vpmaddwd 64(%rsi), %ymm2, %ymm2
+; AVX2-NEXT:    vpmaddwd 96(%rsi), %ymm3, %ymm3
+; AVX2-NEXT:    retq
+;
+; AVX512F-LABEL: pmaddwd_1024:
+; AVX512F:       # %bb.0:
+; AVX512F-NEXT:    vmovdqa (%rdi), %ymm0
+; AVX512F-NEXT:    vmovdqa 32(%rdi), %ymm1
+; AVX512F-NEXT:    vmovdqa 64(%rdi), %ymm2
+; AVX512F-NEXT:    vmovdqa 96(%rdi), %ymm3
+; AVX512F-NEXT:    vpmaddwd 32(%rsi), %ymm1, %ymm1
+; AVX512F-NEXT:    vpmaddwd (%rsi), %ymm0, %ymm0
+; AVX512F-NEXT:    vinserti64x4 $1, %ymm1, %zmm0, %zmm0
+; AVX512F-NEXT:    vpmaddwd 96(%rsi), %ymm3, %ymm1
+; AVX512F-NEXT:    vpmaddwd 64(%rsi), %ymm2, %ymm2
+; AVX512F-NEXT:    vinserti64x4 $1, %ymm1, %zmm2, %zmm1
+; AVX512F-NEXT:    retq
+;
+; AVX512BW-LABEL: pmaddwd_1024:
+; AVX512BW:       # %bb.0:
+; AVX512BW-NEXT:    vmovdqa64 (%rdi), %zmm0
+; AVX512BW-NEXT:    vmovdqa64 64(%rdi), %zmm1
+; AVX512BW-NEXT:    vpmaddwd (%rsi), %zmm0, %zmm0
+; AVX512BW-NEXT:    vpmaddwd 64(%rsi), %zmm1, %zmm1
+; AVX512BW-NEXT:    retq
+  %A = load <64 x i16>, <64 x i16>* %Aptr
+  %B = load <64 x i16>, <64 x i16>* %Bptr
+  %A_even = shufflevector <64 x i16> %A, <64 x i16> undef, <32 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14, i32 16, i32 18, i32 20, i32 22, i32 24, i32 26, i32 28, i32 30, i32 32, i32 34, i32 36, i32 38, i32 40, i32 42, i32 44, i32 46, i32 48, i32 50, i32 52, i32 54, i32 56, i32 58, i32 60, i32 62>
+  %A_odd = shufflevector <64 x i16> %A, <64 x i16> undef, <32 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15, i32 17, i32 19, i32 21, i32 23, i32 25, i32 27, i32 29, i32 31, i32 33, i32 35, i32 37, i32 39, i32 41, i32 43, i32 45, i32 47, i32 49, i32 51, i32 53, i32 55, i32 57, i32 59, i32 61, i32 63>
+  %B_even = shufflevector <64 x i16> %B, <64 x i16> undef, <32 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14, i32 16, i32 18, i32 20, i32 22, i32 24, i32 26, i32 28, i32 30, i32 32, i32 34, i32 36, i32 38, i32 40, i32 42, i32 44, i32 46, i32 48, i32 50, i32 52, i32 54, i32 56, i32 58, i32 60, i32 62>
+  %B_odd = shufflevector <64 x i16> %B, <64 x i16> undef, <32 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15, i32 17, i32 19, i32 21, i32 23, i32 25, i32 27, i32 29, i32 31, i32 33, i32 35, i32 37, i32 39, i32 41, i32 43, i32 45, i32 47, i32 49, i32 51, i32 53, i32 55, i32 57, i32 59, i32 61, i32 63>
+  %A_even_ext = sext <32 x i16> %A_even to <32 x i32>
+  %B_even_ext = sext <32 x i16> %B_even to <32 x i32>
+  %A_odd_ext = sext <32 x i16> %A_odd to <32 x i32>
+  %B_odd_ext = sext <32 x i16> %B_odd to <32 x i32>
+  %even_mul = mul <32 x i32> %A_even_ext, %B_even_ext
+  %odd_mul = mul <32 x i32> %A_odd_ext, %B_odd_ext
+  %add = add <32 x i32> %even_mul, %odd_mul
+  ret <32 x i32> %add
+}
+
+define <4 x i32> @pmaddwd_commuted_mul(<8 x i16>* %Aptr, <8 x i16>* %Bptr) {
+; SSE2-LABEL: pmaddwd_commuted_mul:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movdqa (%rdi), %xmm0
+; SSE2-NEXT:    pmaddwd (%rsi), %xmm0
+; SSE2-NEXT:    retq
+;
+; AVX-LABEL: pmaddwd_commuted_mul:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vmovdqa (%rdi), %xmm0
+; AVX-NEXT:    vpmaddwd (%rsi), %xmm0, %xmm0
+; AVX-NEXT:    retq
+  %A = load <8 x i16>, <8 x i16>* %Aptr
+  %B = load <8 x i16>, <8 x i16>* %Bptr
+  %A_even = shufflevector <8 x i16> %A, <8 x i16> undef, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+  %A_odd = shufflevector <8 x i16> %A, <8 x i16> undef, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+  %B_even = shufflevector <8 x i16> %B, <8 x i16> undef, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+  %B_odd = shufflevector <8 x i16> %B, <8 x i16> undef, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+  %A_even_ext = sext <4 x i16> %A_even to <4 x i32>
+  %B_even_ext = sext <4 x i16> %B_even to <4 x i32>
+  %A_odd_ext = sext <4 x i16> %A_odd to <4 x i32>
+  %B_odd_ext = sext <4 x i16> %B_odd to <4 x i32>
+  %even_mul = mul <4 x i32> %A_even_ext, %B_even_ext
+  %odd_mul = mul <4 x i32> %B_odd_ext, %A_odd_ext ; Different order than previous mul
+  %add = add <4 x i32> %even_mul, %odd_mul
+  ret <4 x i32> %add
+}
+
+define <4 x i32> @pmaddwd_swapped_indices(<8 x i16>* %Aptr, <8 x i16>* %Bptr) {
+; SSE2-LABEL: pmaddwd_swapped_indices:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movdqa (%rdi), %xmm0
+; SSE2-NEXT:    pmaddwd (%rsi), %xmm0
+; SSE2-NEXT:    retq
+;
+; AVX-LABEL: pmaddwd_swapped_indices:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vmovdqa (%rdi), %xmm0
+; AVX-NEXT:    vpmaddwd (%rsi), %xmm0, %xmm0
+; AVX-NEXT:    retq
+  %A = load <8 x i16>, <8 x i16>* %Aptr
+  %B = load <8 x i16>, <8 x i16>* %Bptr
+  %A_even = shufflevector <8 x i16> %A, <8 x i16> undef, <4 x i32> <i32 1, i32 2, i32 5, i32 6> ; indices aren't all even
+  %A_odd = shufflevector <8 x i16> %A, <8 x i16> undef, <4 x i32> <i32 0, i32 3, i32 4, i32 7> ; indices aren't all odd
+  %B_even = shufflevector <8 x i16> %B, <8 x i16> undef, <4 x i32> <i32 1, i32 2, i32 5, i32 6> ; same indices as A
+  %B_odd = shufflevector <8 x i16> %B, <8 x i16> undef, <4 x i32> <i32 0, i32 3, i32 4, i32 7> ; same indices as A
+  %A_even_ext = sext <4 x i16> %A_even to <4 x i32>
+  %B_even_ext = sext <4 x i16> %B_even to <4 x i32>
+  %A_odd_ext = sext <4 x i16> %A_odd to <4 x i32>
+  %B_odd_ext = sext <4 x i16> %B_odd to <4 x i32>
+  %even_mul = mul <4 x i32> %A_even_ext, %B_even_ext
+  %odd_mul = mul <4 x i32> %A_odd_ext, %B_odd_ext
+  %add = add <4 x i32> %even_mul, %odd_mul
+  ret <4 x i32> %add
+}
+
+; Negative test where indices aren't paired properly
+define <4 x i32> @pmaddwd_bad_indices(<8 x i16>* %Aptr, <8 x i16>* %Bptr) {
+; SSE2-LABEL: pmaddwd_bad_indices:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movdqa (%rdi), %xmm0
+; SSE2-NEXT:    movdqa (%rsi), %xmm1
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm2 = xmm1[0,2,2,3,4,5,6,7]
+; SSE2-NEXT:    pshufhw {{.*#+}} xmm2 = xmm2[0,1,2,3,4,6,6,7]
+; SSE2-NEXT:    pshufd {{.*#+}} xmm2 = xmm2[0,2,2,3]
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm3 = xmm0[2,1,2,3,4,5,6,7]
+; SSE2-NEXT:    pshufhw {{.*#+}} xmm3 = xmm3[0,1,2,3,6,5,6,7]
+; SSE2-NEXT:    pshufd {{.*#+}} xmm3 = xmm3[0,2,2,3]
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm3 = xmm3[1,0,3,2,4,5,6,7]
+; SSE2-NEXT:    movdqa %xmm3, %xmm4
+; SSE2-NEXT:    pmulhw %xmm2, %xmm4
+; SSE2-NEXT:    pmullw %xmm2, %xmm3
+; SSE2-NEXT:    punpcklwd {{.*#+}} xmm3 = xmm3[0],xmm4[0],xmm3[1],xmm4[1],xmm3[2],xmm4[2],xmm3[3],xmm4[3]
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm0 = xmm0[0,3,2,3,4,5,6,7]
+; SSE2-NEXT:    pshufhw {{.*#+}} xmm0 = xmm0[0,1,2,3,4,7,6,7]
+; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm1 = xmm1[3,1,2,3,4,5,6,7]
+; SSE2-NEXT:    pshufhw {{.*#+}} xmm1 = xmm1[0,1,2,3,7,5,6,7]
+; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm1[0,2,2,3]
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm1 = xmm1[1,0,3,2,4,5,6,7]
+; SSE2-NEXT:    movdqa %xmm0, %xmm2
+; SSE2-NEXT:    pmulhw %xmm1, %xmm2
+; SSE2-NEXT:    pmullw %xmm1, %xmm0
+; SSE2-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1],xmm0[2],xmm2[2],xmm0[3],xmm2[3]
+; SSE2-NEXT:    paddd %xmm3, %xmm0
+; SSE2-NEXT:    retq
+;
+; AVX-LABEL: pmaddwd_bad_indices:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vmovdqa (%rdi), %xmm0
+; AVX-NEXT:    vmovdqa (%rsi), %xmm1
+; AVX-NEXT:    vpshufb {{.*#+}} xmm2 = xmm0[2,3,4,5,10,11,12,13,12,13,10,11,12,13,14,15]
+; AVX-NEXT:    vpmovsxwd %xmm2, %xmm2
+; AVX-NEXT:    vpshufb {{.*#+}} xmm3 = xmm1[0,1,4,5,8,9,12,13,8,9,12,13,12,13,14,15]
+; AVX-NEXT:    vpmovsxwd %xmm3, %xmm3
+; AVX-NEXT:    vpmulld %xmm3, %xmm2, %xmm2
+; AVX-NEXT:    vpshufb {{.*#+}} xmm0 = xmm0[0,1,6,7,8,9,14,15,8,9,14,15,12,13,14,15]
+; AVX-NEXT:    vpmovsxwd %xmm0, %xmm0
+; AVX-NEXT:    vpshufb {{.*#+}} xmm1 = xmm1[2,3,6,7,10,11,14,15,14,15,10,11,12,13,14,15]
+; AVX-NEXT:    vpmovsxwd %xmm1, %xmm1
+; AVX-NEXT:    vpmulld %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    vpaddd %xmm0, %xmm2, %xmm0
+; AVX-NEXT:    retq
+  %A = load <8 x i16>, <8 x i16>* %Aptr
+  %B = load <8 x i16>, <8 x i16>* %Bptr
+  %A_even = shufflevector <8 x i16> %A, <8 x i16> undef, <4 x i32> <i32 1, i32 2, i32 5, i32 6>
+  %A_odd = shufflevector <8 x i16> %A, <8 x i16> undef, <4 x i32> <i32 0, i32 3, i32 4, i32 7>
+  %B_even = shufflevector <8 x i16> %B, <8 x i16> undef, <4 x i32> <i32 0, i32 2, i32 4, i32 6> ; different indices than A
+  %B_odd = shufflevector <8 x i16> %B, <8 x i16> undef, <4 x i32> <i32 1, i32 3, i32 5, i32 7> ; different indices than A
+  %A_even_ext = sext <4 x i16> %A_even to <4 x i32>
+  %B_even_ext = sext <4 x i16> %B_even to <4 x i32>
+  %A_odd_ext = sext <4 x i16> %A_odd to <4 x i32>
+  %B_odd_ext = sext <4 x i16> %B_odd to <4 x i32>
+  %even_mul = mul <4 x i32> %A_even_ext, %B_even_ext
+  %odd_mul = mul <4 x i32> %A_odd_ext, %B_odd_ext
+  %add = add <4 x i32> %even_mul, %odd_mul
+  ret <4 x i32> %add
+}



