[llvm] a283d72 - [x86] prevent crashing while matching pmaddwd

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Sat Mar 27 02:32:41 PDT 2021


Author: Sanjay Patel
Date: 2021-03-27T05:27:14-04:00
New Revision: a283d725836033f5d7626470506160b7bf6d9107

URL: https://github.com/llvm/llvm-project/commit/a283d725836033f5d7626470506160b7bf6d9107
DIFF: https://github.com/llvm/llvm-project/commit/a283d725836033f5d7626470506160b7bf6d9107.diff

LOG: [x86] prevent crashing while matching pmaddwd

This could crash in 2 ways: either one or both of
the input vectors could have a different size than
the vector type of the math ops.

https://llvm.org/PR49716

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/madd.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 74322f68912d..0eaba12330e7 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -49098,10 +49098,17 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1,
     SDValue N01In = N01Elt.getOperand(0);
     SDValue N10In = N10Elt.getOperand(0);
     SDValue N11In = N11Elt.getOperand(0);
+
     // First time we find an input capture it.
     if (!In0) {
       In0 = N00In;
       In1 = N01In;
+
+      // The input vector sizes must match the output.
+      // TODO: Insert cast ops to allow different types.
+      if (In0.getValueSizeInBits() != VT.getSizeInBits() ||
+          In1.getValueSizeInBits() != VT.getSizeInBits())
+        return SDValue();
     }
     // Mul is commutative so the input vectors can be in any order.
     // Canonicalize to make the compares easier.

diff  --git a/llvm/test/CodeGen/X86/madd.ll b/llvm/test/CodeGen/X86/madd.ll
index a024a04fa37f..c0ff6a79ef18 100644
--- a/llvm/test/CodeGen/X86/madd.ll
+++ b/llvm/test/CodeGen/X86/madd.ll
@@ -3046,3 +3046,104 @@ middle.block:
   %11 = extractelement <8 x i32> %bin.rdx34, i32 0
   ret i32 %11
 }
+
+; PR49716 - https://llvm.org/PR49716
+
+define <4 x i32> @input_size_mismatch(<16 x i16> %x, <16 x i16>* %p) {
+; SSE2-LABEL: input_size_mismatch:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movdqa (%rdi), %xmm1
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm2 = xmm0[0,2,2,3,4,5,6,7]
+; SSE2-NEXT:    pshufhw {{.*#+}} xmm2 = xmm2[0,1,2,3,4,6,6,7]
+; SSE2-NEXT:    pshufd {{.*#+}} xmm2 = xmm2[0,2,2,3]
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm0 = xmm0[3,1,2,3,4,5,6,7]
+; SSE2-NEXT:    pshufhw {{.*#+}} xmm0 = xmm0[0,1,2,3,7,5,6,7]
+; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm0 = xmm0[1,0,3,2,4,5,6,7]
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm3 = xmm1[0,2,2,3,4,5,6,7]
+; SSE2-NEXT:    pshufhw {{.*#+}} xmm3 = xmm3[0,1,2,3,4,6,6,7]
+; SSE2-NEXT:    pshufd {{.*#+}} xmm3 = xmm3[0,2,2,3]
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm1 = xmm1[3,1,2,3,4,5,6,7]
+; SSE2-NEXT:    pshufhw {{.*#+}} xmm1 = xmm1[0,1,2,3,7,5,6,7]
+; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm1[0,2,2,3]
+; SSE2-NEXT:    pshuflw {{.*#+}} xmm1 = xmm1[1,0,3,2,4,5,6,7]
+; SSE2-NEXT:    movdqa %xmm2, %xmm4
+; SSE2-NEXT:    pmulhw %xmm3, %xmm4
+; SSE2-NEXT:    pmullw %xmm3, %xmm2
+; SSE2-NEXT:    punpcklwd {{.*#+}} xmm2 = xmm2[0],xmm4[0],xmm2[1],xmm4[1],xmm2[2],xmm4[2],xmm2[3],xmm4[3]
+; SSE2-NEXT:    movdqa %xmm0, %xmm3
+; SSE2-NEXT:    pmulhw %xmm1, %xmm3
+; SSE2-NEXT:    pmullw %xmm1, %xmm0
+; SSE2-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm3[0],xmm0[1],xmm3[1],xmm0[2],xmm3[2],xmm0[3],xmm3[3]
+; SSE2-NEXT:    paddd %xmm2, %xmm0
+; SSE2-NEXT:    retq
+;
+; AVX-LABEL: input_size_mismatch:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vmovdqa {{.*#+}} xmm1 = [0,1,4,5,8,9,12,13,8,9,12,13,12,13,14,15]
+; AVX-NEXT:    vpshufb %xmm1, %xmm0, %xmm2
+; AVX-NEXT:    vmovdqa {{.*#+}} xmm3 = [2,3,6,7,10,11,14,15,14,15,10,11,12,13,14,15]
+; AVX-NEXT:    vpshufb %xmm3, %xmm0, %xmm0
+; AVX-NEXT:    vmovdqa (%rdi), %xmm4
+; AVX-NEXT:    vpshufb %xmm1, %xmm4, %xmm1
+; AVX-NEXT:    vpshufb %xmm3, %xmm4, %xmm3
+; AVX-NEXT:    vpmovsxwd %xmm2, %xmm2
+; AVX-NEXT:    vpmovsxwd %xmm0, %xmm0
+; AVX-NEXT:    vpmovsxwd %xmm1, %xmm1
+; AVX-NEXT:    vpmulld %xmm1, %xmm2, %xmm1
+; AVX-NEXT:    vpmovsxwd %xmm3, %xmm2
+; AVX-NEXT:    vpmulld %xmm2, %xmm0, %xmm0
+; AVX-NEXT:    vpaddd %xmm0, %xmm1, %xmm0
+; AVX-NEXT:    vzeroupper
+; AVX-NEXT:    retq
+  %y = load <16 x i16>, <16 x i16>* %p, align 32 ; 256-bit operand, but the mul/add below are only <4 x i32> (128 bits)
+  %x0 = shufflevector <16 x i16> %x, <16 x i16> undef, <4 x i32> <i32 0, i32 2, i32 4, i32 6> ; even elements 0,2,4,6 of %x
+  %x1 = shufflevector <16 x i16> %x, <16 x i16> undef, <4 x i32> <i32 1, i32 3, i32 5, i32 7> ; odd elements 1,3,5,7 of %x
+  %y0 = shufflevector <16 x i16> %y, <16 x i16> undef, <4 x i32> <i32 0, i32 2, i32 4, i32 6> ; even elements 0,2,4,6 of %y
+  %y1 = shufflevector <16 x i16> %y, <16 x i16> undef, <4 x i32> <i32 1, i32 3, i32 5, i32 7> ; odd elements 1,3,5,7 of %y
+  %sx0 = sext <4 x i16> %x0 to <4 x i32>
+  %sx1 = sext <4 x i16> %x1 to <4 x i32>
+  %sy0 = sext <4 x i16> %y0 to <4 x i32>
+  %sy1 = sext <4 x i16> %y1 to <4 x i32>
+  %m0 = mul <4 x i32> %sx0, %sy0
+  %m1 = mul <4 x i32> %sx1, %sy1
+  %r = add <4 x i32> %m0, %m1 ; r[i] = x[2i]*y[2i] + x[2i+1]*y[2i+1]; the SSE2 checks above show no pmaddwd is formed here
+  ret <4 x i32> %r
+}
+
+define <4 x i32> @output_size_mismatch(<16 x i16> %x, <16 x i16> %y) {
+; SSE2-LABEL: output_size_mismatch:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    pmaddwd %xmm2, %xmm0
+; SSE2-NEXT:    retq
+;
+; AVX-LABEL: output_size_mismatch:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vmovdqa {{.*#+}} xmm2 = [0,1,4,5,8,9,12,13,8,9,12,13,12,13,14,15]
+; AVX-NEXT:    vpshufb %xmm2, %xmm0, %xmm3
+; AVX-NEXT:    vmovdqa {{.*#+}} xmm4 = [2,3,6,7,10,11,14,15,14,15,10,11,12,13,14,15]
+; AVX-NEXT:    vpshufb %xmm4, %xmm0, %xmm0
+; AVX-NEXT:    vpshufb %xmm2, %xmm1, %xmm2
+; AVX-NEXT:    vpshufb %xmm4, %xmm1, %xmm1
+; AVX-NEXT:    vpmovsxwd %xmm3, %xmm3
+; AVX-NEXT:    vpmovsxwd %xmm0, %xmm0
+; AVX-NEXT:    vpmovsxwd %xmm2, %xmm2
+; AVX-NEXT:    vpmulld %xmm2, %xmm3, %xmm2
+; AVX-NEXT:    vpmovsxwd %xmm1, %xmm1
+; AVX-NEXT:    vpmulld %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    vpaddd %xmm0, %xmm2, %xmm0
+; AVX-NEXT:    vzeroupper
+; AVX-NEXT:    retq
+  %x0 = shufflevector <16 x i16> %x, <16 x i16> undef, <4 x i32> <i32 0, i32 2, i32 4, i32 6> ; even elements 0,2,4,6 of %x (only elements 0-7 of either input are used)
+  %x1 = shufflevector <16 x i16> %x, <16 x i16> undef, <4 x i32> <i32 1, i32 3, i32 5, i32 7> ; odd elements 1,3,5,7 of %x
+  %y0 = shufflevector <16 x i16> %y, <16 x i16> undef, <4 x i32> <i32 0, i32 2, i32 4, i32 6> ; even elements 0,2,4,6 of %y
+  %y1 = shufflevector <16 x i16> %y, <16 x i16> undef, <4 x i32> <i32 1, i32 3, i32 5, i32 7> ; odd elements 1,3,5,7 of %y
+  %sx0 = sext <4 x i16> %x0 to <4 x i32>
+  %sx1 = sext <4 x i16> %x1 to <4 x i32>
+  %sy0 = sext <4 x i16> %y0 to <4 x i32>
+  %sy1 = sext <4 x i16> %y1 to <4 x i32>
+  %m0 = mul <4 x i32> %sx0, %sy0
+  %m1 = mul <4 x i32> %sx1, %sy1
+  %r = add <4 x i32> %m0, %m1 ; r[i] = x[2i]*y[2i] + x[2i+1]*y[2i+1]; <16 x i16> inputs vs <4 x i32> result, yet SSE2 still forms pmaddwd (checks above)
+  ret <4 x i32> %r
+}


        


More information about the llvm-commits mailing list