[llvm] 1e7c1dd - [SDAG] avoid crash from mismatched types in scalar-to-vector fold
Sanjay Patel via llvm-commits
llvm-commits at lists.llvm.org
Fri Oct 28 06:16:10 PDT 2022
Author: Sanjay Patel
Date: 2022-10-28T09:14:08-04:00
New Revision: 1e7c1dd67cd63a6b14d5d4bd8e0e195e9a910f7b
URL: https://github.com/llvm/llvm-project/commit/1e7c1dd67cd63a6b14d5d4bd8e0e195e9a910f7b
DIFF: https://github.com/llvm/llvm-project/commit/1e7c1dd67cd63a6b14d5d4bd8e0e195e9a910f7b.diff
LOG: [SDAG] avoid crash from mismatched types in scalar-to-vector fold
This bug was introduced with D136713 / 54eeadcf442df91aed0 .
As an enhancement, we could cast operands to the expected type,
but we need to make sure that is done correctly (zext vs. sext).
It's also possible (but seems unlikely) that an operand can have
a type larger than the result type.
Fixes #58661
Added:
Modified:
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/test/CodeGen/X86/vec_shift5.ll
Removed:
################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index beed155ee645..c402c2872afd 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -23518,8 +23518,11 @@ SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) {
// TODO: Generalize this, so it can be called from visitINSERT_VECTOR_ELT().
SDValue Scalar = N->getOperand(0);
unsigned Opcode = Scalar.getOpcode();
+ EVT VecEltVT = VT.getScalarType();
if (Scalar.hasOneUse() && Scalar->getNumValues() == 1 &&
- TLI.isBinOp(Opcode) && VT.getScalarType() == Scalar.getValueType() &&
+ TLI.isBinOp(Opcode) && Scalar.getValueType() == VecEltVT &&
+ Scalar.getOperand(0).getValueType() == VecEltVT &&
+ Scalar.getOperand(1).getValueType() == VecEltVT &&
DAG.isSafeToSpeculativelyExecute(Opcode) && hasOperation(Opcode, VT)) {
// Match an extract element and get a shuffle mask equivalent.
SmallVector<int, 8> ShufMask(VT.getVectorNumElements(), -1);
@@ -23564,11 +23567,9 @@ SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) {
return SDValue();
// If we have an implicit truncate, truncate here if it is legal.
- if (VT.getScalarType() != Scalar.getValueType() &&
- Scalar.getValueType().isScalarInteger() &&
- isTypeLegal(VT.getScalarType())) {
- SDValue Val =
- DAG.getNode(ISD::TRUNCATE, SDLoc(Scalar), VT.getScalarType(), Scalar);
+ if (VecEltVT != Scalar.getValueType() &&
+ Scalar.getValueType().isScalarInteger() && isTypeLegal(VecEltVT)) {
+ SDValue Val = DAG.getNode(ISD::TRUNCATE, SDLoc(Scalar), VecEltVT, Scalar);
return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), VT, Val);
}
@@ -23580,7 +23581,7 @@ SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) {
EVT SrcVT = SrcVec.getValueType();
unsigned SrcNumElts = SrcVT.getVectorNumElements();
unsigned VTNumElts = VT.getVectorNumElements();
- if (VT.getScalarType() == SrcVT.getScalarType() && VTNumElts <= SrcNumElts) {
+ if (VecEltVT == SrcVT.getScalarType() && VTNumElts <= SrcNumElts) {
// Create a shuffle equivalent for scalar-to-vector: {ExtIndex, -1, -1, ...}
SmallVector<int, 8> Mask(SrcNumElts, -1);
Mask[0] = ExtIndexC->getZExtValue();
diff --git a/llvm/test/CodeGen/X86/vec_shift5.ll b/llvm/test/CodeGen/X86/vec_shift5.ll
index 429cfd83681c..ab16e1a60946 100644
--- a/llvm/test/CodeGen/X86/vec_shift5.ll
+++ b/llvm/test/CodeGen/X86/vec_shift5.ll
@@ -291,6 +291,23 @@ define <4 x i32> @extelt0_twice_sub_pslli_v4i32(<4 x i32> %x, <4 x i32> %y, <4 x
ret <4 x i32> %r
}
+; This would crash because the scalar shift amount has a different type than the shift result.
+
+define <2 x i8> @PR58661(<2 x i8> %a0) {
+; CHECK-LABEL: PR58661:
+; CHECK: # %bb.0:
+; CHECK-NEXT: psrlw $8, %xmm0
+; CHECK-NEXT: movd %xmm0, %eax
+; CHECK-NEXT: shll $8, %eax
+; CHECK-NEXT: movd %eax, %xmm0
+; CHECK-NEXT: ret{{[l|q]}}
+ %shuffle = shufflevector <2 x i8> %a0, <2 x i8> <i8 poison, i8 0>, <2 x i32> <i32 1, i32 3>
+ %x = bitcast <2 x i8> %shuffle to i16
+ %shl = shl nuw i16 %x, 8
+ %y = bitcast i16 %shl to <2 x i8>
+ ret <2 x i8> %y
+}
+
declare <8 x i16> @llvm.x86.sse2.pslli.w(<8 x i16>, i32)
declare <8 x i16> @llvm.x86.sse2.psrli.w(<8 x i16>, i32)
declare <8 x i16> @llvm.x86.sse2.psrai.w(<8 x i16>, i32)
More information about the llvm-commits
mailing list