[llvm] 1e7c1dd - [SDAG] avoid crash from mismatched types in scalar-to-vector fold

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 28 06:16:10 PDT 2022


Author: Sanjay Patel
Date: 2022-10-28T09:14:08-04:00
New Revision: 1e7c1dd67cd63a6b14d5d4bd8e0e195e9a910f7b

URL: https://github.com/llvm/llvm-project/commit/1e7c1dd67cd63a6b14d5d4bd8e0e195e9a910f7b
DIFF: https://github.com/llvm/llvm-project/commit/1e7c1dd67cd63a6b14d5d4bd8e0e195e9a910f7b.diff

LOG: [SDAG] avoid crash from mismatched types in scalar-to-vector fold

This bug was introduced with D136713 / 54eeadcf442df91aed0.

As an enhancement, we could cast operands to the expected type,
but we need to make sure that is done correctly (zext vs. sext).
It's also possible (but seems unlikely) that an operand can have
a type larger than the result type.

Fixes #58661

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/test/CodeGen/X86/vec_shift5.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index beed155ee645..c402c2872afd 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -23518,8 +23518,11 @@ SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) {
   // TODO: Generalize this, so it can be called from visitINSERT_VECTOR_ELT().
   SDValue Scalar = N->getOperand(0);
   unsigned Opcode = Scalar.getOpcode();
+  EVT VecEltVT = VT.getScalarType();
   if (Scalar.hasOneUse() && Scalar->getNumValues() == 1 &&
-      TLI.isBinOp(Opcode) && VT.getScalarType() == Scalar.getValueType() &&
+      TLI.isBinOp(Opcode) && Scalar.getValueType() == VecEltVT &&
+      Scalar.getOperand(0).getValueType() == VecEltVT &&
+      Scalar.getOperand(1).getValueType() == VecEltVT &&
       DAG.isSafeToSpeculativelyExecute(Opcode) && hasOperation(Opcode, VT)) {
     // Match an extract element and get a shuffle mask equivalent.
     SmallVector<int, 8> ShufMask(VT.getVectorNumElements(), -1);
@@ -23564,11 +23567,9 @@ SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) {
     return SDValue();
 
   // If we have an implicit truncate, truncate here if it is legal.
-  if (VT.getScalarType() != Scalar.getValueType() &&
-      Scalar.getValueType().isScalarInteger() &&
-      isTypeLegal(VT.getScalarType())) {
-    SDValue Val =
-        DAG.getNode(ISD::TRUNCATE, SDLoc(Scalar), VT.getScalarType(), Scalar);
+  if (VecEltVT != Scalar.getValueType() &&
+      Scalar.getValueType().isScalarInteger() && isTypeLegal(VecEltVT)) {
+    SDValue Val = DAG.getNode(ISD::TRUNCATE, SDLoc(Scalar), VecEltVT, Scalar);
     return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), VT, Val);
   }
 
@@ -23580,7 +23581,7 @@ SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) {
   EVT SrcVT = SrcVec.getValueType();
   unsigned SrcNumElts = SrcVT.getVectorNumElements();
   unsigned VTNumElts = VT.getVectorNumElements();
-  if (VT.getScalarType() == SrcVT.getScalarType() && VTNumElts <= SrcNumElts) {
+  if (VecEltVT == SrcVT.getScalarType() && VTNumElts <= SrcNumElts) {
     // Create a shuffle equivalent for scalar-to-vector: {ExtIndex, -1, -1, ...}
     SmallVector<int, 8> Mask(SrcNumElts, -1);
     Mask[0] = ExtIndexC->getZExtValue();

diff --git a/llvm/test/CodeGen/X86/vec_shift5.ll b/llvm/test/CodeGen/X86/vec_shift5.ll
index 429cfd83681c..ab16e1a60946 100644
--- a/llvm/test/CodeGen/X86/vec_shift5.ll
+++ b/llvm/test/CodeGen/X86/vec_shift5.ll
@@ -291,6 +291,23 @@ define <4 x i32> @extelt0_twice_sub_pslli_v4i32(<4 x i32> %x, <4 x i32> %y, <4 x
   ret <4 x i32> %r
 }
 
+; This would crash because the scalar shift amount has a different type than the shift result.
+
+define <2 x i8> @PR58661(<2 x i8> %a0) {
+; CHECK-LABEL: PR58661:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    psrlw $8, %xmm0
+; CHECK-NEXT:    movd %xmm0, %eax
+; CHECK-NEXT:    shll $8, %eax
+; CHECK-NEXT:    movd %eax, %xmm0
+; CHECK-NEXT:    ret{{[l|q]}}
+  %shuffle = shufflevector <2 x i8> %a0, <2 x i8> <i8 poison, i8 0>, <2 x i32> <i32 1, i32 3>
+  %x = bitcast <2 x i8> %shuffle to i16
+  %shl = shl nuw i16 %x, 8
+  %y = bitcast i16 %shl to <2 x i8>
+  ret <2 x i8> %y
+}
+
 declare <8 x i16> @llvm.x86.sse2.pslli.w(<8 x i16>, i32)
 declare <8 x i16> @llvm.x86.sse2.psrli.w(<8 x i16>, i32)
 declare <8 x i16> @llvm.x86.sse2.psrai.w(<8 x i16>, i32)


        


More information about the llvm-commits mailing list