[llvm] e694e19 - [x86] enhance matching of pmaddwd

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 30 04:28:43 PDT 2021


Author: Sanjay Patel
Date: 2021-03-30T07:28:33-04:00
New Revision: e694e19a793140b989364e5807630b635420533e

URL: https://github.com/llvm/llvm-project/commit/e694e19a793140b989364e5807630b635420533e
DIFF: https://github.com/llvm/llvm-project/commit/e694e19a793140b989364e5807630b635420533e.diff

LOG: [x86] enhance matching of pmaddwd

This was crashing with the example from:
https://llvm.org/PR49716
...and that was avoided with a283d7258360,
but as we can see from the SSE vs. AVX test code diff,
we can try harder to match the pattern.

This matcher code was adapted from another pmadd pattern
match in D49636, but it needs different ops to deal with
size mismatches.
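
For intuition, here is a minimal scalar model of the idea
(illustrative only, not LLVM code; all names are invented):
pmaddwd multiplies adjacent pairs of signed 16-bit lanes and sums
each pair into a 32-bit lane, so when a source vector is wider than
the result needs, only its low lanes contribute, and the low
subvector can be extracted first.

  // Illustrative scalar sketch of pmaddwd fed by a wider-than-needed
  // input; not LLVM code, names are invented.
  #include <array>
  #include <cstdint>
  #include <cstdio>

  // pmaddwd: out[i] = a[2i]*b[2i] + a[2i+1]*b[2i+1], 32-bit result lanes.
  static std::array<int32_t, 4> pmaddwd128(const std::array<int16_t, 8> &a,
                                           const std::array<int16_t, 8> &b) {
    std::array<int32_t, 4> out{};
    for (int i = 0; i < 4; ++i)
      out[i] = int32_t(a[2 * i]) * b[2 * i] +
               int32_t(a[2 * i + 1]) * b[2 * i + 1];
    return out;
  }

  // A 16-lane source feeding a 4 x i32 result only uses its low 8 lanes,
  // mirroring the EXTRACT_SUBVECTOR the patch inserts before forming
  // the VPMADDWD node.
  static std::array<int16_t, 8> lowHalf(const std::array<int16_t, 16> &v) {
    std::array<int16_t, 8> lo{};
    for (int i = 0; i < 8; ++i)
      lo[i] = v[i];
    return lo;
  }

  int main() {
    std::array<int16_t, 16> x{}, y{};
    for (int i = 0; i < 16; ++i) {
      x[i] = int16_t(i + 1);
      y[i] = int16_t(2 * i - 5);
    }
    std::array<int32_t, 4> r = pmaddwd128(lowHalf(x), lowHalf(y));
    for (int32_t v : r)
      std::printf("%d ", v);
    std::printf("\n");
    return 0;
  }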

Differential Revision: https://reviews.llvm.org/D99531

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/madd.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 9cbfae445a54..c434da6132d7 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -49045,10 +49045,10 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1,
       In0 = N00In;
       In1 = N01In;
 
-      // The input vector sizes must match the output.
-      // TODO: Insert cast ops to allow different types.
-      if (In0.getValueSizeInBits() != VT.getSizeInBits() ||
-          In1.getValueSizeInBits() != VT.getSizeInBits())
+      // The input vectors must be at least as wide as the output.
+      // If they are larger than the output, we extract subvector below.
+      if (In0.getValueSizeInBits() < VT.getSizeInBits() ||
+          In1.getValueSizeInBits() < VT.getSizeInBits())
         return SDValue();
     }
     // Mul is commutative so the input vectors can be in any order.
@@ -49063,8 +49063,6 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1,
 
   auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
                          ArrayRef<SDValue> Ops) {
-    // Shrink by adding truncate nodes and let DAGCombine fold with the
-    // sources.
     EVT OpVT = Ops[0].getValueType();
     assert(OpVT.getScalarType() == MVT::i16 &&
            "Unexpected scalar element type");
@@ -49073,6 +49071,19 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1,
                                  OpVT.getVectorNumElements() / 2);
     return DAG.getNode(X86ISD::VPMADDWD, DL, ResVT, Ops[0], Ops[1]);
   };
+
+  // If the output is narrower than an input, extract the low part of the input
+  // vector.
+  EVT OutVT16 = EVT::getVectorVT(*DAG.getContext(), MVT::i16,
+                               VT.getVectorNumElements() * 2);
+  if (OutVT16.bitsLT(In0.getValueType())) {
+    In0 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, OutVT16, In0,
+                      DAG.getIntPtrConstant(0, DL));
+  }
+  if (OutVT16.bitsLT(In1.getValueType())) {
+    In1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, OutVT16, In1,
+                      DAG.getIntPtrConstant(0, DL));
+  }
   return SplitOpsAndApply(DAG, Subtarget, DL, VT, { In0, In1 },
                           PMADDBuilder);
 }

diff --git a/llvm/test/CodeGen/X86/madd.ll b/llvm/test/CodeGen/X86/madd.ll
index c0ff6a79ef18..96ced9f1d2a7 100644
--- a/llvm/test/CodeGen/X86/madd.ll
+++ b/llvm/test/CodeGen/X86/madd.ll
@@ -3052,48 +3052,12 @@ middle.block:
 define <4 x i32> @input_size_mismatch(<16 x i16> %x, <16 x i16>* %p) {
 ; SSE2-LABEL: input_size_mismatch:
 ; SSE2:       # %bb.0:
-; SSE2-NEXT:    movdqa (%rdi), %xmm1
-; SSE2-NEXT:    pshuflw {{.*#+}} xmm2 = xmm0[0,2,2,3,4,5,6,7]
-; SSE2-NEXT:    pshufhw {{.*#+}} xmm2 = xmm2[0,1,2,3,4,6,6,7]
-; SSE2-NEXT:    pshufd {{.*#+}} xmm2 = xmm2[0,2,2,3]
-; SSE2-NEXT:    pshuflw {{.*#+}} xmm0 = xmm0[3,1,2,3,4,5,6,7]
-; SSE2-NEXT:    pshufhw {{.*#+}} xmm0 = xmm0[0,1,2,3,7,5,6,7]
-; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
-; SSE2-NEXT:    pshuflw {{.*#+}} xmm0 = xmm0[1,0,3,2,4,5,6,7]
-; SSE2-NEXT:    pshuflw {{.*#+}} xmm3 = xmm1[0,2,2,3,4,5,6,7]
-; SSE2-NEXT:    pshufhw {{.*#+}} xmm3 = xmm3[0,1,2,3,4,6,6,7]
-; SSE2-NEXT:    pshufd {{.*#+}} xmm3 = xmm3[0,2,2,3]
-; SSE2-NEXT:    pshuflw {{.*#+}} xmm1 = xmm1[3,1,2,3,4,5,6,7]
-; SSE2-NEXT:    pshufhw {{.*#+}} xmm1 = xmm1[0,1,2,3,7,5,6,7]
-; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm1[0,2,2,3]
-; SSE2-NEXT:    pshuflw {{.*#+}} xmm1 = xmm1[1,0,3,2,4,5,6,7]
-; SSE2-NEXT:    movdqa %xmm2, %xmm4
-; SSE2-NEXT:    pmulhw %xmm3, %xmm4
-; SSE2-NEXT:    pmullw %xmm3, %xmm2
-; SSE2-NEXT:    punpcklwd {{.*#+}} xmm2 = xmm2[0],xmm4[0],xmm2[1],xmm4[1],xmm2[2],xmm4[2],xmm2[3],xmm4[3]
-; SSE2-NEXT:    movdqa %xmm0, %xmm3
-; SSE2-NEXT:    pmulhw %xmm1, %xmm3
-; SSE2-NEXT:    pmullw %xmm1, %xmm0
-; SSE2-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm3[0],xmm0[1],xmm3[1],xmm0[2],xmm3[2],xmm0[3],xmm3[3]
-; SSE2-NEXT:    paddd %xmm2, %xmm0
+; SSE2-NEXT:    pmaddwd (%rdi), %xmm0
 ; SSE2-NEXT:    retq
 ;
 ; AVX-LABEL: input_size_mismatch:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vmovdqa {{.*#+}} xmm1 = [0,1,4,5,8,9,12,13,8,9,12,13,12,13,14,15]
-; AVX-NEXT:    vpshufb %xmm1, %xmm0, %xmm2
-; AVX-NEXT:    vmovdqa {{.*#+}} xmm3 = [2,3,6,7,10,11,14,15,14,15,10,11,12,13,14,15]
-; AVX-NEXT:    vpshufb %xmm3, %xmm0, %xmm0
-; AVX-NEXT:    vmovdqa (%rdi), %xmm4
-; AVX-NEXT:    vpshufb %xmm1, %xmm4, %xmm1
-; AVX-NEXT:    vpshufb %xmm3, %xmm4, %xmm3
-; AVX-NEXT:    vpmovsxwd %xmm2, %xmm2
-; AVX-NEXT:    vpmovsxwd %xmm0, %xmm0
-; AVX-NEXT:    vpmovsxwd %xmm1, %xmm1
-; AVX-NEXT:    vpmulld %xmm1, %xmm2, %xmm1
-; AVX-NEXT:    vpmovsxwd %xmm3, %xmm2
-; AVX-NEXT:    vpmulld %xmm2, %xmm0, %xmm0
-; AVX-NEXT:    vpaddd %xmm0, %xmm1, %xmm0
+; AVX-NEXT:    vpmaddwd (%rdi), %xmm0, %xmm0
 ; AVX-NEXT:    vzeroupper
 ; AVX-NEXT:    retq
   %y = load <16 x i16>, <16 x i16>* %p, align 32
@@ -3119,19 +3083,7 @@ define <4 x i32> @output_size_mismatch(<16 x i16> %x, <16 x i16> %y) {
 ;
 ; AVX-LABEL: output_size_mismatch:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vmovdqa {{.*#+}} xmm2 = [0,1,4,5,8,9,12,13,8,9,12,13,12,13,14,15]
-; AVX-NEXT:    vpshufb %xmm2, %xmm0, %xmm3
-; AVX-NEXT:    vmovdqa {{.*#+}} xmm4 = [2,3,6,7,10,11,14,15,14,15,10,11,12,13,14,15]
-; AVX-NEXT:    vpshufb %xmm4, %xmm0, %xmm0
-; AVX-NEXT:    vpshufb %xmm2, %xmm1, %xmm2
-; AVX-NEXT:    vpshufb %xmm4, %xmm1, %xmm1
-; AVX-NEXT:    vpmovsxwd %xmm3, %xmm3
-; AVX-NEXT:    vpmovsxwd %xmm0, %xmm0
-; AVX-NEXT:    vpmovsxwd %xmm2, %xmm2
-; AVX-NEXT:    vpmulld %xmm2, %xmm3, %xmm2
-; AVX-NEXT:    vpmovsxwd %xmm1, %xmm1
-; AVX-NEXT:    vpmulld %xmm1, %xmm0, %xmm0
-; AVX-NEXT:    vpaddd %xmm0, %xmm2, %xmm0
+; AVX-NEXT:    vpmaddwd %xmm1, %xmm0, %xmm0
 ; AVX-NEXT:    vzeroupper
 ; AVX-NEXT:    retq
   %x0 = shufflevector <16 x i16> %x, <16 x i16> undef, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
@@ -3147,3 +3099,61 @@ define <4 x i32> @output_size_mismatch(<16 x i16> %x, <16 x i16> %y) {
   %r = add <4 x i32> %m0, %m1
   ret <4 x i32> %r
 }
+
+define <4 x i32> @output_size_mismatch_high_subvector(<16 x i16> %x, <16 x i16> %y) {
+; SSE2-LABEL: output_size_mismatch_high_subvector:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movdqa %xmm1, %xmm0
+; SSE2-NEXT:    pmaddwd %xmm2, %xmm0
+; SSE2-NEXT:    retq
+;
+; AVX1-LABEL: output_size_mismatch_high_subvector:
+; AVX1:       # %bb.0:
+; AVX1-NEXT:    vmovdqa {{.*#+}} xmm2 = [0,1,4,5,8,9,12,13,8,9,12,13,12,13,14,15]
+; AVX1-NEXT:    vextractf128 $1, %ymm0, %xmm0
+; AVX1-NEXT:    vpshufb %xmm2, %xmm0, %xmm3
+; AVX1-NEXT:    vmovdqa {{.*#+}} xmm4 = [2,3,6,7,10,11,14,15,14,15,10,11,12,13,14,15]
+; AVX1-NEXT:    vpshufb %xmm4, %xmm0, %xmm0
+; AVX1-NEXT:    vpshufb %xmm2, %xmm1, %xmm2
+; AVX1-NEXT:    vpshufb %xmm4, %xmm1, %xmm1
+; AVX1-NEXT:    vpmovsxwd %xmm3, %xmm3
+; AVX1-NEXT:    vpmovsxwd %xmm0, %xmm0
+; AVX1-NEXT:    vpmovsxwd %xmm2, %xmm2
+; AVX1-NEXT:    vpmulld %xmm2, %xmm3, %xmm2
+; AVX1-NEXT:    vpmovsxwd %xmm1, %xmm1
+; AVX1-NEXT:    vpmulld %xmm1, %xmm0, %xmm0
+; AVX1-NEXT:    vpaddd %xmm0, %xmm2, %xmm0
+; AVX1-NEXT:    vzeroupper
+; AVX1-NEXT:    retq
+;
+; AVX256-LABEL: output_size_mismatch_high_subvector:
+; AVX256:       # %bb.0:
+; AVX256-NEXT:    vmovdqa {{.*#+}} xmm2 = [0,1,4,5,8,9,12,13,8,9,12,13,12,13,14,15]
+; AVX256-NEXT:    vextracti128 $1, %ymm0, %xmm0
+; AVX256-NEXT:    vpshufb %xmm2, %xmm0, %xmm3
+; AVX256-NEXT:    vmovdqa {{.*#+}} xmm4 = [2,3,6,7,10,11,14,15,14,15,10,11,12,13,14,15]
+; AVX256-NEXT:    vpshufb %xmm4, %xmm0, %xmm0
+; AVX256-NEXT:    vpshufb %xmm2, %xmm1, %xmm2
+; AVX256-NEXT:    vpshufb %xmm4, %xmm1, %xmm1
+; AVX256-NEXT:    vpmovsxwd %xmm3, %xmm3
+; AVX256-NEXT:    vpmovsxwd %xmm0, %xmm0
+; AVX256-NEXT:    vpmovsxwd %xmm2, %xmm2
+; AVX256-NEXT:    vpmulld %xmm2, %xmm3, %xmm2
+; AVX256-NEXT:    vpmovsxwd %xmm1, %xmm1
+; AVX256-NEXT:    vpmulld %xmm1, %xmm0, %xmm0
+; AVX256-NEXT:    vpaddd %xmm0, %xmm2, %xmm0
+; AVX256-NEXT:    vzeroupper
+; AVX256-NEXT:    retq
+  %x0 = shufflevector <16 x i16> %x, <16 x i16> undef, <4 x i32> <i32 8, i32 10, i32 12, i32 14>
+  %x1 = shufflevector <16 x i16> %x, <16 x i16> undef, <4 x i32> <i32 9, i32 11, i32 13, i32 15>
+  %y0 = shufflevector <16 x i16> %y, <16 x i16> undef, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+  %y1 = shufflevector <16 x i16> %y, <16 x i16> undef, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+  %sx0 = sext <4 x i16> %x0 to <4 x i32>
+  %sx1 = sext <4 x i16> %x1 to <4 x i32>
+  %sy0 = sext <4 x i16> %y0 to <4 x i32>
+  %sy1 = sext <4 x i16> %y1 to <4 x i32>
+  %m0 = mul <4 x i32> %sx0, %sy0
+  %m1 = mul <4 x i32> %sx1, %sy1
+  %r = add <4 x i32> %m0, %m1
+  ret <4 x i32> %r
+}
