[llvm] r318452 - [X86] Pre-truncate gather/scatter indices that have element sizes larger than 64-bits before Legalize.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 16 12:23:22 PST 2017


Author: ctopper
Date: Thu Nov 16 12:23:22 2017
New Revision: 318452

URL: http://llvm.org/viewvc/llvm-project?rev=318452&view=rev
Log:
[X86] Pre-truncate gather/scatter indices that have element sizes larger than 64-bits before Legalize.

The wider element type would normally cause the legalizer to try to split and scalarize the gather/scatter node, which we cannot handle. Instead, truncate the index early so that the gather/scatter node is insulated from legalization.

This really shouldn't happen in practice since InstCombine will normalize index types to the same size as pointers.

Modified:
    llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
    llvm/trunk/test/CodeGen/X86/masked_gather_scatter.ll

Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=318452&r1=318451&r2=318452&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Thu Nov 16 12:23:22 2017
@@ -35829,8 +35829,25 @@ static SDValue combineSetCC(SDNode *N, S
   return SDValue();
 }
 
-static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG) {
+static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
+                                    TargetLowering::DAGCombinerInfo &DCI) {
   SDLoc DL(N);
+
+  // Pre-shrink oversized index elements to avoid triggering scalarization.
+  if (DCI.isBeforeLegalize()) {
+    SDValue Index = N->getOperand(4);
+    if (Index.getValueType().getScalarSizeInBits() > 64) {
+      EVT IndexVT = EVT::getVectorVT(*DAG.getContext(), MVT::i64,
+                                   Index.getValueType().getVectorNumElements());
+      SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, IndexVT, Index);
+      SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
+      NewOps[4] = Trunc;
+      DAG.UpdateNodeOperands(N, NewOps);
+      DCI.AddToWorklist(N);
+      return SDValue(N, 0);
+    }
+  }
+
   // Gather and Scatter instructions use k-registers for masks. The type of
   // the masks is v*i1. So the mask will be truncated anyway.
   // The SIGN_EXTEND_INREG my be dropped.
@@ -36949,7 +36966,7 @@ SDValue X86TargetLowering::PerformDAGCom
   case X86ISD::FMADDSUB:
   case X86ISD::FMSUBADD:    return combineFMADDSUB(N, DAG, Subtarget);
   case ISD::MGATHER:
-  case ISD::MSCATTER:       return combineGatherScatter(N, DAG);
+  case ISD::MSCATTER:       return combineGatherScatter(N, DAG, DCI);
   case X86ISD::TESTM:       return combineTestM(N, DAG, Subtarget);
   case X86ISD::PCMPEQ:
   case X86ISD::PCMPGT:      return combineVectorCompare(N, DAG, Subtarget);

Modified: llvm/trunk/test/CodeGen/X86/masked_gather_scatter.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/masked_gather_scatter.ll?rev=318452&r1=318451&r2=318452&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/masked_gather_scatter.ll (original)
+++ llvm/trunk/test/CodeGen/X86/masked_gather_scatter.ll Thu Nov 16 12:23:22 2017
@@ -2476,3 +2476,66 @@ define <1 x i32> @v1_gather(<1 x i32*> %
   ret <1 x i32>%res
 }
 declare <1 x i32> @llvm.masked.gather.v1i32.v1p0i32(<1 x i32*>, i32, <1 x i1>, <1 x i32>)
+
+; Make sure we don't crash when the index element type is larger than i64 and we need to widen the result
+; This experienced a bad interaction when we widened and then tried to split.
+define <2 x float> @large_index(float* %base, <2 x i128> %ind, <2 x i1> %mask, <2 x float> %src0) {
+; KNL_64-LABEL: large_index:
+; KNL_64:       # BB#0:
+; KNL_64-NEXT:    # kill: %XMM1<def> %XMM1<kill> %YMM1<def>
+; KNL_64-NEXT:    vinsertps {{.*#+}} xmm0 = xmm0[0,2],zero,zero
+; KNL_64-NEXT:    vmovaps %xmm0, %xmm0
+; KNL_64-NEXT:    vmovq %rcx, %xmm2
+; KNL_64-NEXT:    vmovq %rsi, %xmm3
+; KNL_64-NEXT:    vpunpcklqdq {{.*#+}} xmm2 = xmm3[0],xmm2[0]
+; KNL_64-NEXT:    vpslld $31, %ymm0, %ymm0
+; KNL_64-NEXT:    vptestmd %zmm0, %zmm0, %k1
+; KNL_64-NEXT:    vgatherqps (%rdi,%zmm2,4), %ymm1 {%k1}
+; KNL_64-NEXT:    vmovaps %xmm1, %xmm0
+; KNL_64-NEXT:    vzeroupper
+; KNL_64-NEXT:    retq
+;
+; KNL_32-LABEL: large_index:
+; KNL_32:       # BB#0:
+; KNL_32-NEXT:    # kill: %XMM1<def> %XMM1<kill> %YMM1<def>
+; KNL_32-NEXT:    vinsertps {{.*#+}} xmm0 = xmm0[0,2],zero,zero
+; KNL_32-NEXT:    vmovaps %xmm0, %xmm0
+; KNL_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; KNL_32-NEXT:    vmovd {{.*#+}} xmm2 = mem[0],zero,zero,zero
+; KNL_32-NEXT:    vpinsrd $1, {{[0-9]+}}(%esp), %xmm2, %xmm2
+; KNL_32-NEXT:    vpinsrd $2, {{[0-9]+}}(%esp), %xmm2, %xmm2
+; KNL_32-NEXT:    vpinsrd $3, {{[0-9]+}}(%esp), %xmm2, %xmm2
+; KNL_32-NEXT:    vpslld $31, %ymm0, %ymm0
+; KNL_32-NEXT:    vptestmd %zmm0, %zmm0, %k1
+; KNL_32-NEXT:    vgatherqps (%eax,%zmm2,4), %ymm1 {%k1}
+; KNL_32-NEXT:    vmovaps %xmm1, %xmm0
+; KNL_32-NEXT:    vzeroupper
+; KNL_32-NEXT:    retl
+;
+; SKX-LABEL: large_index:
+; SKX:       # BB#0:
+; SKX-NEXT:    vpsllq $63, %xmm0, %xmm0
+; SKX-NEXT:    vptestmq %xmm0, %xmm0, %k1
+; SKX-NEXT:    vmovq %rcx, %xmm0
+; SKX-NEXT:    vmovq %rsi, %xmm2
+; SKX-NEXT:    vpunpcklqdq {{.*#+}} xmm0 = xmm2[0],xmm0[0]
+; SKX-NEXT:    vgatherqps (%rdi,%xmm0,4), %xmm1 {%k1}
+; SKX-NEXT:    vmovaps %xmm1, %xmm0
+; SKX-NEXT:    retq
+;
+; SKX_32-LABEL: large_index:
+; SKX_32:       # BB#0:
+; SKX_32-NEXT:    vpsllq $63, %xmm0, %xmm0
+; SKX_32-NEXT:    vptestmq %xmm0, %xmm0, %k1
+; SKX_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; SKX_32-NEXT:    vmovd {{.*#+}} xmm0 = mem[0],zero,zero,zero
+; SKX_32-NEXT:    vpinsrd $1, {{[0-9]+}}(%esp), %xmm0, %xmm0
+; SKX_32-NEXT:    vpinsrd $2, {{[0-9]+}}(%esp), %xmm0, %xmm0
+; SKX_32-NEXT:    vpinsrd $3, {{[0-9]+}}(%esp), %xmm0, %xmm0
+; SKX_32-NEXT:    vgatherqps (%eax,%xmm0,4), %xmm1 {%k1}
+; SKX_32-NEXT:    vmovaps %xmm1, %xmm0
+; SKX_32-NEXT:    retl
+  %gep.random = getelementptr float, float* %base, <2 x i128> %ind
+  %res = call <2 x float> @llvm.masked.gather.v2f32.v2p0f32(<2 x float*> %gep.random, i32 4, <2 x i1> %mask, <2 x float> %src0)
+  ret <2 x float>%res
+}




More information about the llvm-commits mailing list