[llvm] a283d72 - [x86] prevent crashing while matching pmaddwd
Sanjay Patel via llvm-commits
llvm-commits at lists.llvm.org
Sat Mar 27 02:32:41 PDT 2021
Author: Sanjay Patel
Date: 2021-03-27T05:27:14-04:00
New Revision: a283d725836033f5d7626470506160b7bf6d9107
URL: https://github.com/llvm/llvm-project/commit/a283d725836033f5d7626470506160b7bf6d9107
DIFF: https://github.com/llvm/llvm-project/commit/a283d725836033f5d7626470506160b7bf6d9107.diff
LOG: [x86] prevent crashing while matching pmaddwd
This could crash in 2 ways: either one or both of
the input vectors could be a different size than
the math ops.
https://llvm.org/PR49716
Added:
Modified:
llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/madd.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 74322f68912d..0eaba12330e7 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -49098,10 +49098,17 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1,
SDValue N01In = N01Elt.getOperand(0);
SDValue N10In = N10Elt.getOperand(0);
SDValue N11In = N11Elt.getOperand(0);
+
// First time we find an input capture it.
if (!In0) {
In0 = N00In;
In1 = N01In;
+
+ // The input vector sizes must match the output.
+      // TODO: Insert cast ops to allow different types.
+ if (In0.getValueSizeInBits() != VT.getSizeInBits() ||
+ In1.getValueSizeInBits() != VT.getSizeInBits())
+ return SDValue();
}
// Mul is commutative so the input vectors can be in any order.
// Canonicalize to make the compares easier.
diff --git a/llvm/test/CodeGen/X86/madd.ll b/llvm/test/CodeGen/X86/madd.ll
index a024a04fa37f..c0ff6a79ef18 100644
--- a/llvm/test/CodeGen/X86/madd.ll
+++ b/llvm/test/CodeGen/X86/madd.ll
@@ -3046,3 +3046,104 @@ middle.block:
%11 = extractelement <8 x i32> %bin.rdx34, i32 0
ret i32 %11
}
+
+; PR49716 - https://llvm.org/PR49716
+
+define <4 x i32> @input_size_mismatch(<16 x i16> %x, <16 x i16>* %p) {
+; SSE2-LABEL: input_size_mismatch:
+; SSE2: # %bb.0:
+; SSE2-NEXT: movdqa (%rdi), %xmm1
+; SSE2-NEXT: pshuflw {{.*#+}} xmm2 = xmm0[0,2,2,3,4,5,6,7]
+; SSE2-NEXT: pshufhw {{.*#+}} xmm2 = xmm2[0,1,2,3,4,6,6,7]
+; SSE2-NEXT: pshufd {{.*#+}} xmm2 = xmm2[0,2,2,3]
+; SSE2-NEXT: pshuflw {{.*#+}} xmm0 = xmm0[3,1,2,3,4,5,6,7]
+; SSE2-NEXT: pshufhw {{.*#+}} xmm0 = xmm0[0,1,2,3,7,5,6,7]
+; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
+; SSE2-NEXT: pshuflw {{.*#+}} xmm0 = xmm0[1,0,3,2,4,5,6,7]
+; SSE2-NEXT: pshuflw {{.*#+}} xmm3 = xmm1[0,2,2,3,4,5,6,7]
+; SSE2-NEXT: pshufhw {{.*#+}} xmm3 = xmm3[0,1,2,3,4,6,6,7]
+; SSE2-NEXT: pshufd {{.*#+}} xmm3 = xmm3[0,2,2,3]
+; SSE2-NEXT: pshuflw {{.*#+}} xmm1 = xmm1[3,1,2,3,4,5,6,7]
+; SSE2-NEXT: pshufhw {{.*#+}} xmm1 = xmm1[0,1,2,3,7,5,6,7]
+; SSE2-NEXT: pshufd {{.*#+}} xmm1 = xmm1[0,2,2,3]
+; SSE2-NEXT: pshuflw {{.*#+}} xmm1 = xmm1[1,0,3,2,4,5,6,7]
+; SSE2-NEXT: movdqa %xmm2, %xmm4
+; SSE2-NEXT: pmulhw %xmm3, %xmm4
+; SSE2-NEXT: pmullw %xmm3, %xmm2
+; SSE2-NEXT: punpcklwd {{.*#+}} xmm2 = xmm2[0],xmm4[0],xmm2[1],xmm4[1],xmm2[2],xmm4[2],xmm2[3],xmm4[3]
+; SSE2-NEXT: movdqa %xmm0, %xmm3
+; SSE2-NEXT: pmulhw %xmm1, %xmm3
+; SSE2-NEXT: pmullw %xmm1, %xmm0
+; SSE2-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm3[0],xmm0[1],xmm3[1],xmm0[2],xmm3[2],xmm0[3],xmm3[3]
+; SSE2-NEXT: paddd %xmm2, %xmm0
+; SSE2-NEXT: retq
+;
+; AVX-LABEL: input_size_mismatch:
+; AVX: # %bb.0:
+; AVX-NEXT: vmovdqa {{.*#+}} xmm1 = [0,1,4,5,8,9,12,13,8,9,12,13,12,13,14,15]
+; AVX-NEXT: vpshufb %xmm1, %xmm0, %xmm2
+; AVX-NEXT: vmovdqa {{.*#+}} xmm3 = [2,3,6,7,10,11,14,15,14,15,10,11,12,13,14,15]
+; AVX-NEXT: vpshufb %xmm3, %xmm0, %xmm0
+; AVX-NEXT: vmovdqa (%rdi), %xmm4
+; AVX-NEXT: vpshufb %xmm1, %xmm4, %xmm1
+; AVX-NEXT: vpshufb %xmm3, %xmm4, %xmm3
+; AVX-NEXT: vpmovsxwd %xmm2, %xmm2
+; AVX-NEXT: vpmovsxwd %xmm0, %xmm0
+; AVX-NEXT: vpmovsxwd %xmm1, %xmm1
+; AVX-NEXT: vpmulld %xmm1, %xmm2, %xmm1
+; AVX-NEXT: vpmovsxwd %xmm3, %xmm2
+; AVX-NEXT: vpmulld %xmm2, %xmm0, %xmm0
+; AVX-NEXT: vpaddd %xmm0, %xmm1, %xmm0
+; AVX-NEXT: vzeroupper
+; AVX-NEXT: retq
+ %y = load <16 x i16>, <16 x i16>* %p, align 32
+ %x0 = shufflevector <16 x i16> %x, <16 x i16> undef, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+ %x1 = shufflevector <16 x i16> %x, <16 x i16> undef, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+ %y0 = shufflevector <16 x i16> %y, <16 x i16> undef, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+ %y1 = shufflevector <16 x i16> %y, <16 x i16> undef, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+ %sx0 = sext <4 x i16> %x0 to <4 x i32>
+ %sx1 = sext <4 x i16> %x1 to <4 x i32>
+ %sy0 = sext <4 x i16> %y0 to <4 x i32>
+ %sy1 = sext <4 x i16> %y1 to <4 x i32>
+ %m0 = mul <4 x i32> %sx0, %sy0
+ %m1 = mul <4 x i32> %sx1, %sy1
+ %r = add <4 x i32> %m0, %m1
+ ret <4 x i32> %r
+}
+
+define <4 x i32> @output_size_mismatch(<16 x i16> %x, <16 x i16> %y) {
+; SSE2-LABEL: output_size_mismatch:
+; SSE2: # %bb.0:
+; SSE2-NEXT: pmaddwd %xmm2, %xmm0
+; SSE2-NEXT: retq
+;
+; AVX-LABEL: output_size_mismatch:
+; AVX: # %bb.0:
+; AVX-NEXT: vmovdqa {{.*#+}} xmm2 = [0,1,4,5,8,9,12,13,8,9,12,13,12,13,14,15]
+; AVX-NEXT: vpshufb %xmm2, %xmm0, %xmm3
+; AVX-NEXT: vmovdqa {{.*#+}} xmm4 = [2,3,6,7,10,11,14,15,14,15,10,11,12,13,14,15]
+; AVX-NEXT: vpshufb %xmm4, %xmm0, %xmm0
+; AVX-NEXT: vpshufb %xmm2, %xmm1, %xmm2
+; AVX-NEXT: vpshufb %xmm4, %xmm1, %xmm1
+; AVX-NEXT: vpmovsxwd %xmm3, %xmm3
+; AVX-NEXT: vpmovsxwd %xmm0, %xmm0
+; AVX-NEXT: vpmovsxwd %xmm2, %xmm2
+; AVX-NEXT: vpmulld %xmm2, %xmm3, %xmm2
+; AVX-NEXT: vpmovsxwd %xmm1, %xmm1
+; AVX-NEXT: vpmulld %xmm1, %xmm0, %xmm0
+; AVX-NEXT: vpaddd %xmm0, %xmm2, %xmm0
+; AVX-NEXT: vzeroupper
+; AVX-NEXT: retq
+ %x0 = shufflevector <16 x i16> %x, <16 x i16> undef, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+ %x1 = shufflevector <16 x i16> %x, <16 x i16> undef, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+ %y0 = shufflevector <16 x i16> %y, <16 x i16> undef, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+ %y1 = shufflevector <16 x i16> %y, <16 x i16> undef, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+ %sx0 = sext <4 x i16> %x0 to <4 x i32>
+ %sx1 = sext <4 x i16> %x1 to <4 x i32>
+ %sy0 = sext <4 x i16> %y0 to <4 x i32>
+ %sy1 = sext <4 x i16> %y1 to <4 x i32>
+ %m0 = mul <4 x i32> %sx0, %sy0
+ %m1 = mul <4 x i32> %sx1, %sy1
+ %r = add <4 x i32> %m0, %m1
+ ret <4 x i32> %r
+}
More information about the llvm-commits
mailing list