[llvm] d29ccbe - [X86][AVX] Attempt to fold a scaled index into a gather/scatter scale immediate (PR13310)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 28 06:08:03 PDT 2021


Author: Simon Pilgrim
Date: 2021-10-28T14:07:17+01:00
New Revision: d29ccbecd093c881c599fd41db5d68dae744f91f

URL: https://github.com/llvm/llvm-project/commit/d29ccbecd093c881c599fd41db5d68dae744f91f
DIFF: https://github.com/llvm/llvm-project/commit/d29ccbecd093c881c599fd41db5d68dae744f91f.diff

LOG: [X86][AVX] Attempt to fold a scaled index into a gather/scatter scale immediate (PR13310)

If the index operand of a gather/scatter intrinsic is being scaled (by a self-addition or a shift-left by an immediate), then we may be able to fold that scaling into the intrinsic's scale immediate instead, provided the combined scale is still a power of two no greater than 8.

Fixes PR13310.

Differential Revision: https://reviews.llvm.org/D108539

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/masked_gather_scatter.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 15eec7a69726..e922cb356dfe 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -50227,9 +50227,40 @@ static SDValue combineMOVMSK(SDNode *N, SelectionDAG &DAG,
 }
 
 static SDValue combineX86GatherScatter(SDNode *N, SelectionDAG &DAG,
-                                       TargetLowering::DAGCombinerInfo &DCI) {
+                                       TargetLowering::DAGCombinerInfo &DCI,
+                                       const X86Subtarget &Subtarget) {
+  auto *MemOp = cast<X86MaskedGatherScatterSDNode>(N);
+  SDValue Index = MemOp->getIndex();
+  SDValue Scale = MemOp->getScale();
+  SDValue Mask = MemOp->getMask();
+
+  // Attempt to fold an index scale into the scale value directly.
+  // TODO: Move this into X86DAGToDAGISel::matchVectorAddressRecursively?
+  if ((Index.getOpcode() == X86ISD::VSHLI ||
+       (Index.getOpcode() == ISD::ADD &&
+        Index.getOperand(0) == Index.getOperand(1))) &&
+      isa<ConstantSDNode>(Scale)) {
+    unsigned ShiftAmt =
+        Index.getOpcode() == ISD::ADD ? 1 : Index.getConstantOperandVal(1);
+    uint64_t ScaleAmt = cast<ConstantSDNode>(Scale)->getZExtValue();
+    uint64_t NewScaleAmt = ScaleAmt * (1ULL << ShiftAmt);
+    if (isPowerOf2_64(NewScaleAmt) && NewScaleAmt <= 8) {
+      SDValue NewIndex = Index.getOperand(0);
+      SDValue NewScale =
+          DAG.getTargetConstant(NewScaleAmt, SDLoc(N), Scale.getValueType());
+      if (N->getOpcode() == X86ISD::MGATHER)
+        return getAVX2GatherNode(N->getOpcode(), SDValue(N, 0), DAG,
+                                 MemOp->getOperand(1), Mask,
+                                 MemOp->getBasePtr(), NewIndex, NewScale,
+                                 MemOp->getChain(), Subtarget);
+      if (N->getOpcode() == X86ISD::MSCATTER)
+        return getScatterNode(N->getOpcode(), SDValue(N, 0), DAG,
+                              MemOp->getOperand(1), Mask, MemOp->getBasePtr(),
+                              NewIndex, NewScale, MemOp->getChain(), Subtarget);
+    }
+  }
+
   // With vector masks we only demand the upper bit of the mask.
-  SDValue Mask = cast<X86MaskedGatherScatterSDNode>(N)->getMask();
   if (Mask.getScalarValueSizeInBits() != 1) {
     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
     APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));
@@ -52886,7 +52917,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case X86ISD::FMSUBADD:    return combineFMADDSUB(N, DAG, DCI);
   case X86ISD::MOVMSK:      return combineMOVMSK(N, DAG, DCI, Subtarget);
   case X86ISD::MGATHER:
-  case X86ISD::MSCATTER:    return combineX86GatherScatter(N, DAG, DCI);
+  case X86ISD::MSCATTER:
+    return combineX86GatherScatter(N, DAG, DCI, Subtarget);
   case ISD::MGATHER:
   case ISD::MSCATTER:       return combineGatherScatter(N, DAG, DCI);
   case X86ISD::PCMPEQ:

diff  --git a/llvm/test/CodeGen/X86/masked_gather_scatter.ll b/llvm/test/CodeGen/X86/masked_gather_scatter.ll
index d6c3f8625ffe..fbe02af64e3d 100644
--- a/llvm/test/CodeGen/X86/masked_gather_scatter.ll
+++ b/llvm/test/CodeGen/X86/masked_gather_scatter.ll
@@ -808,20 +808,19 @@ define <16 x float> @test14(float* %base, i32 %ind, <16 x float*> %vec) {
 ; KNL_64-NEXT:    vmovd %esi, %xmm0
 ; KNL_64-NEXT:    vpbroadcastd %xmm0, %ymm0
 ; KNL_64-NEXT:    vpmovsxdq %ymm0, %zmm0
-; KNL_64-NEXT:    vpsllq $2, %zmm0, %zmm0
 ; KNL_64-NEXT:    kxnorw %k0, %k0, %k1
 ; KNL_64-NEXT:    vxorps %xmm1, %xmm1, %xmm1
-; KNL_64-NEXT:    vgatherqps (%rax,%zmm0), %ymm1 {%k1}
+; KNL_64-NEXT:    vgatherqps (%rax,%zmm0,4), %ymm1 {%k1}
 ; KNL_64-NEXT:    vinsertf64x4 $1, %ymm1, %zmm1, %zmm0
 ; KNL_64-NEXT:    retq
 ;
 ; KNL_32-LABEL: test14:
 ; KNL_32:       # %bb.0:
 ; KNL_32-NEXT:    vmovd %xmm0, %eax
-; KNL_32-NEXT:    vpslld $2, {{[0-9]+}}(%esp){1to16}, %zmm1
+; KNL_32-NEXT:    vbroadcastss {{[0-9]+}}(%esp), %zmm1
 ; KNL_32-NEXT:    kxnorw %k0, %k0, %k1
 ; KNL_32-NEXT:    vpxor %xmm0, %xmm0, %xmm0
-; KNL_32-NEXT:    vgatherdps (%eax,%zmm1), %zmm0 {%k1}
+; KNL_32-NEXT:    vgatherdps (%eax,%zmm1,4), %zmm0 {%k1}
 ; KNL_32-NEXT:    retl
 ;
 ; SKX-LABEL: test14:
@@ -829,20 +828,19 @@ define <16 x float> @test14(float* %base, i32 %ind, <16 x float*> %vec) {
 ; SKX-NEXT:    vmovq %xmm0, %rax
 ; SKX-NEXT:    vpbroadcastd %esi, %ymm0
 ; SKX-NEXT:    vpmovsxdq %ymm0, %zmm0
-; SKX-NEXT:    vpsllq $2, %zmm0, %zmm0
 ; SKX-NEXT:    kxnorw %k0, %k0, %k1
 ; SKX-NEXT:    vxorps %xmm1, %xmm1, %xmm1
-; SKX-NEXT:    vgatherqps (%rax,%zmm0), %ymm1 {%k1}
+; SKX-NEXT:    vgatherqps (%rax,%zmm0,4), %ymm1 {%k1}
 ; SKX-NEXT:    vinsertf64x4 $1, %ymm1, %zmm1, %zmm0
 ; SKX-NEXT:    retq
 ;
 ; SKX_32-LABEL: test14:
 ; SKX_32:       # %bb.0:
 ; SKX_32-NEXT:    vmovd %xmm0, %eax
-; SKX_32-NEXT:    vpslld $2, {{[0-9]+}}(%esp){1to16}, %zmm1
+; SKX_32-NEXT:    vbroadcastss {{[0-9]+}}(%esp), %zmm1
 ; SKX_32-NEXT:    kxnorw %k0, %k0, %k1
 ; SKX_32-NEXT:    vpxor %xmm0, %xmm0, %xmm0
-; SKX_32-NEXT:    vgatherdps (%eax,%zmm1), %zmm0 {%k1}
+; SKX_32-NEXT:    vgatherdps (%eax,%zmm1,4), %zmm0 {%k1}
 ; SKX_32-NEXT:    retl
 
   %broadcast.splatinsert = insertelement <16 x float*> %vec, float* %base, i32 1
@@ -4988,38 +4986,38 @@ define void @splat_ptr_scatter(i32* %ptr, <4 x i1> %mask, <4 x i32> %val) {
 
 ;
 ; PR13310
-; FIXME: Failure to fold scaled-index into gather/scatter scale operand.
+; Tests folding a scaled index into the gather/scatter scale operand.
 ;
 
 define <8 x float> @scaleidx_x86gather(float* %base, <8 x i32> %index, <8 x i32> %imask) nounwind {
 ; KNL_64-LABEL: scaleidx_x86gather:
 ; KNL_64:       # %bb.0:
-; KNL_64-NEXT:    vpslld $2, %ymm0, %ymm2
-; KNL_64-NEXT:    vpxor %xmm0, %xmm0, %xmm0
-; KNL_64-NEXT:    vgatherdps %ymm1, (%rdi,%ymm2), %ymm0
+; KNL_64-NEXT:    vxorps %xmm2, %xmm2, %xmm2
+; KNL_64-NEXT:    vgatherdps %ymm1, (%rdi,%ymm0,4), %ymm2
+; KNL_64-NEXT:    vmovaps %ymm2, %ymm0
 ; KNL_64-NEXT:    retq
 ;
 ; KNL_32-LABEL: scaleidx_x86gather:
 ; KNL_32:       # %bb.0:
 ; KNL_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
-; KNL_32-NEXT:    vpslld $2, %ymm0, %ymm2
-; KNL_32-NEXT:    vpxor %xmm0, %xmm0, %xmm0
-; KNL_32-NEXT:    vgatherdps %ymm1, (%eax,%ymm2), %ymm0
+; KNL_32-NEXT:    vxorps %xmm2, %xmm2, %xmm2
+; KNL_32-NEXT:    vgatherdps %ymm1, (%eax,%ymm0,4), %ymm2
+; KNL_32-NEXT:    vmovaps %ymm2, %ymm0
 ; KNL_32-NEXT:    retl
 ;
 ; SKX-LABEL: scaleidx_x86gather:
 ; SKX:       # %bb.0:
-; SKX-NEXT:    vpslld $2, %ymm0, %ymm2
-; SKX-NEXT:    vpxor %xmm0, %xmm0, %xmm0
-; SKX-NEXT:    vgatherdps %ymm1, (%rdi,%ymm2), %ymm0
+; SKX-NEXT:    vxorps %xmm2, %xmm2, %xmm2
+; SKX-NEXT:    vgatherdps %ymm1, (%rdi,%ymm0,4), %ymm2
+; SKX-NEXT:    vmovaps %ymm2, %ymm0
 ; SKX-NEXT:    retq
 ;
 ; SKX_32-LABEL: scaleidx_x86gather:
 ; SKX_32:       # %bb.0:
 ; SKX_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
-; SKX_32-NEXT:    vpslld $2, %ymm0, %ymm2
-; SKX_32-NEXT:    vpxor %xmm0, %xmm0, %xmm0
-; SKX_32-NEXT:    vgatherdps %ymm1, (%eax,%ymm2), %ymm0
+; SKX_32-NEXT:    vxorps %xmm2, %xmm2, %xmm2
+; SKX_32-NEXT:    vgatherdps %ymm1, (%eax,%ymm0,4), %ymm2
+; SKX_32-NEXT:    vmovaps %ymm2, %ymm0
 ; SKX_32-NEXT:    retl
   %ptr = bitcast float* %base to i8*
   %mask = bitcast <8 x i32> %imask to <8 x float>
@@ -5070,8 +5068,7 @@ define void @scaleidx_x86scatter(<16 x float> %value, float* %base, <16 x i32> %
 ; KNL_64-LABEL: scaleidx_x86scatter:
 ; KNL_64:       # %bb.0:
 ; KNL_64-NEXT:    kmovw %esi, %k1
-; KNL_64-NEXT:    vpaddd %zmm1, %zmm1, %zmm1
-; KNL_64-NEXT:    vscatterdps %zmm0, (%rdi,%zmm1,2) {%k1}
+; KNL_64-NEXT:    vscatterdps %zmm0, (%rdi,%zmm1,4) {%k1}
 ; KNL_64-NEXT:    vzeroupper
 ; KNL_64-NEXT:    retq
 ;
@@ -5079,16 +5076,14 @@ define void @scaleidx_x86scatter(<16 x float> %value, float* %base, <16 x i32> %
 ; KNL_32:       # %bb.0:
 ; KNL_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
 ; KNL_32-NEXT:    kmovw {{[0-9]+}}(%esp), %k1
-; KNL_32-NEXT:    vpaddd %zmm1, %zmm1, %zmm1
-; KNL_32-NEXT:    vscatterdps %zmm0, (%eax,%zmm1,2) {%k1}
+; KNL_32-NEXT:    vscatterdps %zmm0, (%eax,%zmm1,4) {%k1}
 ; KNL_32-NEXT:    vzeroupper
 ; KNL_32-NEXT:    retl
 ;
 ; SKX-LABEL: scaleidx_x86scatter:
 ; SKX:       # %bb.0:
 ; SKX-NEXT:    kmovw %esi, %k1
-; SKX-NEXT:    vpaddd %zmm1, %zmm1, %zmm1
-; SKX-NEXT:    vscatterdps %zmm0, (%rdi,%zmm1,2) {%k1}
+; SKX-NEXT:    vscatterdps %zmm0, (%rdi,%zmm1,4) {%k1}
 ; SKX-NEXT:    vzeroupper
 ; SKX-NEXT:    retq
 ;
@@ -5096,8 +5091,7 @@ define void @scaleidx_x86scatter(<16 x float> %value, float* %base, <16 x i32> %
 ; SKX_32:       # %bb.0:
 ; SKX_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
 ; SKX_32-NEXT:    kmovw {{[0-9]+}}(%esp), %k1
-; SKX_32-NEXT:    vpaddd %zmm1, %zmm1, %zmm1
-; SKX_32-NEXT:    vscatterdps %zmm0, (%eax,%zmm1,2) {%k1}
+; SKX_32-NEXT:    vscatterdps %zmm0, (%eax,%zmm1,4) {%k1}
 ; SKX_32-NEXT:    vzeroupper
 ; SKX_32-NEXT:    retl
   %ptr = bitcast float* %base to i8*
@@ -5135,18 +5129,16 @@ define void @scaleidx_scatter(<8 x float> %value, float* %base, <8 x i32> %index
 ;
 ; SKX-LABEL: scaleidx_scatter:
 ; SKX:       # %bb.0:
-; SKX-NEXT:    vpaddd %ymm1, %ymm1, %ymm1
 ; SKX-NEXT:    kmovw %esi, %k1
-; SKX-NEXT:    vscatterdps %ymm0, (%rdi,%ymm1,4) {%k1}
+; SKX-NEXT:    vscatterdps %ymm0, (%rdi,%ymm1,8) {%k1}
 ; SKX-NEXT:    vzeroupper
 ; SKX-NEXT:    retq
 ;
 ; SKX_32-LABEL: scaleidx_scatter:
 ; SKX_32:       # %bb.0:
 ; SKX_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
-; SKX_32-NEXT:    vpaddd %ymm1, %ymm1, %ymm1
 ; SKX_32-NEXT:    kmovb {{[0-9]+}}(%esp), %k1
-; SKX_32-NEXT:    vscatterdps %ymm0, (%eax,%ymm1,4) {%k1}
+; SKX_32-NEXT:    vscatterdps %ymm0, (%eax,%ymm1,8) {%k1}
 ; SKX_32-NEXT:    vzeroupper
 ; SKX_32-NEXT:    retl
   %scaledindex = mul <8 x i32> %index, <i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2>


        


More information about the llvm-commits mailing list