[llvm] r356043 - [X86][AVX] lowerShuffleAsBroadcast - improve load folding by avoiding bitcasts

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 13 05:20:39 PDT 2019


Author: rksimon
Date: Wed Mar 13 05:20:39 2019
New Revision: 356043

URL: http://llvm.org/viewvc/llvm-project?rev=356043&view=rev
Log:
[X86][AVX] lowerShuffleAsBroadcast - improve load folding by avoiding bitcasts

AVX1 broadcasts were failing because we were adding bitcasts, which caused the hasOneUse check in MayFoldLoad to return false.

This patch stops introducing bitcasts so early. It also replaces the scaling of the broadcast index through bitcasts (which can't succeed in some cases) with tracking of the bit offset, which is converted back to a broadcast index later on.

Differential Revision: https://reviews.llvm.org/D58888

Modified:
    llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
    llvm/trunk/test/CodeGen/X86/widened-broadcast.ll

Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=356043&r1=356042&r2=356043&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Wed Mar 13 05:20:39 2019
@@ -11930,6 +11930,7 @@ static SDValue lowerShuffleAsBroadcast(c
   // With MOVDDUP (v2f64) we can broadcast from a register or a load, otherwise
   // we can only broadcast from a register with AVX2.
   unsigned NumElts = Mask.size();
+  unsigned NumEltBits = VT.getScalarSizeInBits();
   unsigned Opcode = (VT == MVT::v2f64 && !Subtarget.hasAVX2())
                         ? X86ISD::MOVDDUP
                         : X86ISD::VBROADCAST;
@@ -11953,29 +11954,19 @@ static SDValue lowerShuffleAsBroadcast(c
 
   // Go up the chain of (vector) values to find a scalar load that we can
   // combine with the broadcast.
+  int BitOffset = BroadcastIdx * NumEltBits;
   SDValue V = V1;
   for (;;) {
     switch (V.getOpcode()) {
     case ISD::BITCAST: {
-      // Peek through bitcasts as long as BroadcastIdx can be adjusted.
-      SDValue VSrc = V.getOperand(0);
-      unsigned NumEltBits = V.getScalarValueSizeInBits();
-      unsigned NumSrcBits = VSrc.getScalarValueSizeInBits();
-      if ((NumEltBits % NumSrcBits) == 0)
-        BroadcastIdx *= (NumEltBits / NumSrcBits);
-      else if ((NumSrcBits % NumEltBits) == 0 &&
-               (BroadcastIdx % (NumSrcBits / NumEltBits)) == 0)
-        BroadcastIdx /= (NumSrcBits / NumEltBits);
-      else
-        break;
-      V = VSrc;
+      V = V.getOperand(0);
       continue;
     }
     case ISD::CONCAT_VECTORS: {
-      int OperandSize =
-          V.getOperand(0).getSimpleValueType().getVectorNumElements();
-      V = V.getOperand(BroadcastIdx / OperandSize);
-      BroadcastIdx %= OperandSize;
+      int OpBitWidth = V.getOperand(0).getValueSizeInBits();
+      int OpIdx = BitOffset / OpBitWidth;
+      V = V.getOperand(OpIdx);
+      BitOffset %= OpBitWidth;
       continue;
     }
     case ISD::INSERT_SUBVECTOR: {
@@ -11984,11 +11975,13 @@ static SDValue lowerShuffleAsBroadcast(c
       if (!ConstantIdx)
         break;
 
-      int BeginIdx = (int)ConstantIdx->getZExtValue();
-      int EndIdx =
-          BeginIdx + (int)VInner.getSimpleValueType().getVectorNumElements();
-      if (BroadcastIdx >= BeginIdx && BroadcastIdx < EndIdx) {
-        BroadcastIdx -= BeginIdx;
+      int EltBitWidth = VOuter.getScalarValueSizeInBits();
+      int Idx = (int)ConstantIdx->getZExtValue();
+      int NumSubElts = (int)VInner.getSimpleValueType().getVectorNumElements();
+      int BeginOffset = Idx * EltBitWidth;
+      int EndOffset = BeginOffset + NumSubElts * EltBitWidth;
+      if (BeginOffset <= BitOffset && BitOffset < EndOffset) {
+        BitOffset -= BeginOffset;
         V = VInner;
       } else {
         V = VOuter;
@@ -11998,48 +11991,34 @@ static SDValue lowerShuffleAsBroadcast(c
     }
     break;
   }
+  assert((BitOffset % NumEltBits) == 0 && "Illegal bit-offset");
+  BroadcastIdx = BitOffset / NumEltBits;
 
-  // Ensure the source vector and BroadcastIdx are for a suitable type.
-  if (VT.getScalarSizeInBits() != V.getScalarValueSizeInBits()) {
-    unsigned NumEltBits = VT.getScalarSizeInBits();
-    unsigned NumSrcBits = V.getScalarValueSizeInBits();
-    if ((NumSrcBits % NumEltBits) == 0)
-      BroadcastIdx *= (NumSrcBits / NumEltBits);
-    else if ((NumEltBits % NumSrcBits) == 0 &&
-             (BroadcastIdx % (NumEltBits / NumSrcBits)) == 0)
-      BroadcastIdx /= (NumEltBits / NumSrcBits);
-    else
-      return SDValue();
-
-    unsigned NumSrcElts = V.getValueSizeInBits() / NumEltBits;
-    MVT SrcVT = MVT::getVectorVT(VT.getScalarType(), NumSrcElts);
-    V = DAG.getBitcast(SrcVT, V);
-  }
+  // Do we need to bitcast the source to retrieve the original broadcast index?
+  bool BitCastSrc = V.getScalarValueSizeInBits() != NumEltBits;
 
   // Check if this is a broadcast of a scalar. We special case lowering
   // for scalars so that we can more effectively fold with loads.
-  // First, look through bitcast: if the original value has a larger element
-  // type than the shuffle, the broadcast element is in essence truncated.
-  // Make that explicit to ease folding.
-  if (V.getOpcode() == ISD::BITCAST && VT.isInteger())
+  // If the original value has a larger element type than the shuffle, the
+  // broadcast element is in essence truncated. Make that explicit to ease
+  // folding.
+  if (BitCastSrc && VT.isInteger())
     if (SDValue TruncBroadcast = lowerShuffleAsTruncBroadcast(
-            DL, VT, V.getOperand(0), BroadcastIdx, Subtarget, DAG))
+            DL, VT, V, BroadcastIdx, Subtarget, DAG))
       return TruncBroadcast;
 
   MVT BroadcastVT = VT;
 
-  // Peek through any bitcast (only useful for loads).
-  SDValue BC = peekThroughBitcasts(V);
-
   // Also check the simpler case, where we can directly reuse the scalar.
-  if ((V.getOpcode() == ISD::BUILD_VECTOR && V.hasOneUse()) ||
-      (V.getOpcode() == ISD::SCALAR_TO_VECTOR && BroadcastIdx == 0)) {
+  if (!BitCastSrc &&
+      ((V.getOpcode() == ISD::BUILD_VECTOR && V.hasOneUse()) ||
+       (V.getOpcode() == ISD::SCALAR_TO_VECTOR && BroadcastIdx == 0))) {
     V = V.getOperand(BroadcastIdx);
 
     // If we can't broadcast from a register, check that the input is a load.
     if (!BroadcastFromReg && !isShuffleFoldableLoad(V))
       return SDValue();
-  } else if (MayFoldLoad(BC) && !cast<LoadSDNode>(BC)->isVolatile()) {
+  } else if (MayFoldLoad(V) && !cast<LoadSDNode>(V)->isVolatile()) {
     // 32-bit targets need to load i64 as a f64 and then bitcast the result.
     if (!Subtarget.is64Bit() && VT.getScalarType() == MVT::i64) {
       BroadcastVT = MVT::getVectorVT(MVT::f64, VT.getVectorNumElements());
@@ -12050,10 +12029,11 @@ static SDValue lowerShuffleAsBroadcast(c
 
     // If we are broadcasting a load that is only used by the shuffle
     // then we can reduce the vector load to the broadcasted scalar load.
-    LoadSDNode *Ld = cast<LoadSDNode>(BC);
+    LoadSDNode *Ld = cast<LoadSDNode>(V);
     SDValue BaseAddr = Ld->getOperand(1);
     EVT SVT = BroadcastVT.getScalarType();
     unsigned Offset = BroadcastIdx * SVT.getStoreSize();
+    assert((Offset * 8) == BitOffset && "Unexpected bit-offset");
     SDValue NewAddr = DAG.getMemBasePlusOffset(BaseAddr, Offset, DL);
     V = DAG.getLoad(SVT, DL, Ld->getChain(), NewAddr,
                     DAG.getMachineFunction().getMachineMemOperand(
@@ -12062,7 +12042,7 @@ static SDValue lowerShuffleAsBroadcast(c
   } else if (!BroadcastFromReg) {
     // We can't broadcast from a vector register.
     return SDValue();
-  } else if (BroadcastIdx != 0) {
+  } else if (BitOffset != 0) {
     // We can only broadcast from the zero-element of a vector register,
     // but it can be advantageous to broadcast from the zero-element of a
     // subvector.
@@ -12074,18 +12054,15 @@ static SDValue lowerShuffleAsBroadcast(c
       return SDValue();
 
     // Only broadcast the zero-element of a 128-bit subvector.
-    unsigned EltSize = VT.getScalarSizeInBits();
-    if (((BroadcastIdx * EltSize) % 128) != 0)
+    if ((BitOffset % 128) != 0)
       return SDValue();
 
-    // The shuffle input might have been a bitcast we looked through; look at
-    // the original input vector.  Emit an EXTRACT_SUBVECTOR of that type; we'll
-    // later bitcast it to BroadcastVT.
-    assert(V.getScalarValueSizeInBits() == BroadcastVT.getScalarSizeInBits() &&
-           "Unexpected vector element size");
+    assert((BitOffset % V.getScalarValueSizeInBits()) == 0 &&
+           "Unexpected bit-offset");
     assert((V.getValueSizeInBits() == 256 || V.getValueSizeInBits() == 512) &&
            "Unexpected vector size");
-    V = extract128BitVector(V, BroadcastIdx, DAG, DL);
+    unsigned ExtractIdx = BitOffset / V.getScalarValueSizeInBits();
+    V = extract128BitVector(V, ExtractIdx, DAG, DL);
   }
 
   if (Opcode == X86ISD::MOVDDUP && !V.getValueType().isVector())
@@ -12093,21 +12070,21 @@ static SDValue lowerShuffleAsBroadcast(c
                     DAG.getBitcast(MVT::f64, V));
 
   // Bitcast back to the same scalar type as BroadcastVT.
-  MVT SrcVT = V.getSimpleValueType();
-  if (SrcVT.getScalarType() != BroadcastVT.getScalarType()) {
-    assert(SrcVT.getScalarSizeInBits() == BroadcastVT.getScalarSizeInBits() &&
+  if (V.getValueType().getScalarType() != BroadcastVT.getScalarType()) {
+    assert(NumEltBits == BroadcastVT.getScalarSizeInBits() &&
            "Unexpected vector element size");
-    if (SrcVT.isVector()) {
-      unsigned NumSrcElts = SrcVT.getVectorNumElements();
-      SrcVT = MVT::getVectorVT(BroadcastVT.getScalarType(), NumSrcElts);
+    MVT ExtVT;
+    if (V.getValueType().isVector()) {
+      unsigned NumSrcElts = V.getValueSizeInBits() / NumEltBits;
+      ExtVT = MVT::getVectorVT(BroadcastVT.getScalarType(), NumSrcElts);
     } else {
-      SrcVT = BroadcastVT.getScalarType();
+      ExtVT = BroadcastVT.getScalarType();
     }
-    V = DAG.getBitcast(SrcVT, V);
+    V = DAG.getBitcast(ExtVT, V);
   }
 
   // 32-bit targets need to load i64 as a f64 and then bitcast the result.
-  if (!Subtarget.is64Bit() && SrcVT == MVT::i64) {
+  if (!Subtarget.is64Bit() && V.getValueType() == MVT::i64) {
     V = DAG.getBitcast(MVT::f64, V);
     unsigned NumBroadcastElts = BroadcastVT.getVectorNumElements();
     BroadcastVT = MVT::getVectorVT(MVT::f64, NumBroadcastElts);
@@ -12116,9 +12093,9 @@ static SDValue lowerShuffleAsBroadcast(c
   // We only support broadcasting from 128-bit vectors to minimize the
   // number of patterns we need to deal with in isel. So extract down to
   // 128-bits, removing as many bitcasts as possible.
-  if (SrcVT.getSizeInBits() > 128) {
-    MVT ExtVT = MVT::getVectorVT(SrcVT.getScalarType(),
-                                 128 / SrcVT.getScalarSizeInBits());
+  if (V.getValueSizeInBits() > 128) {
+    MVT ExtVT = V.getSimpleValueType().getScalarType();
+    ExtVT = MVT::getVectorVT(ExtVT, 128 / ExtVT.getScalarSizeInBits());
     V = extract128BitVector(peekThroughBitcasts(V), 0, DAG, DL);
     V = DAG.getBitcast(ExtVT, V);
   }

Modified: llvm/trunk/test/CodeGen/X86/widened-broadcast.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/widened-broadcast.ll?rev=356043&r1=356042&r2=356043&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/widened-broadcast.ll (original)
+++ llvm/trunk/test/CodeGen/X86/widened-broadcast.ll Wed Mar 13 05:20:39 2019
@@ -110,21 +110,10 @@ define <8 x i32> @load_splat_8i32_4i32_0
 ; SSE-NEXT:    movdqa %xmm0, %xmm1
 ; SSE-NEXT:    retq
 ;
-; AVX1-LABEL: load_splat_8i32_4i32_01010101:
-; AVX1:       # %bb.0: # %entry
-; AVX1-NEXT:    vpermilps {{.*#+}} xmm0 = mem[0,1,0,1]
-; AVX1-NEXT:    vinsertf128 $1, %xmm0, %ymm0, %ymm0
-; AVX1-NEXT:    retq
-;
-; AVX2-LABEL: load_splat_8i32_4i32_01010101:
-; AVX2:       # %bb.0: # %entry
-; AVX2-NEXT:    vbroadcastsd (%rdi), %ymm0
-; AVX2-NEXT:    retq
-;
-; AVX512-LABEL: load_splat_8i32_4i32_01010101:
-; AVX512:       # %bb.0: # %entry
-; AVX512-NEXT:    vbroadcastsd (%rdi), %ymm0
-; AVX512-NEXT:    retq
+; AVX-LABEL: load_splat_8i32_4i32_01010101:
+; AVX:       # %bb.0: # %entry
+; AVX-NEXT:    vbroadcastsd (%rdi), %ymm0
+; AVX-NEXT:    retq
 entry:
   %ld = load <4 x i32>, <4 x i32>* %ptr
   %ret = shufflevector <4 x i32> %ld, <4 x i32> undef, <8 x i32> <i32 0, i32 1, i32 0, i32 1, i32 0, i32 1, i32 0, i32 1>
@@ -207,21 +196,10 @@ define <16 x i16> @load_splat_16i16_8i16
 ; SSE-NEXT:    movdqa %xmm0, %xmm1
 ; SSE-NEXT:    retq
 ;
-; AVX1-LABEL: load_splat_16i16_8i16_0101010101010101:
-; AVX1:       # %bb.0: # %entry
-; AVX1-NEXT:    vpermilps {{.*#+}} xmm0 = mem[0,0,0,0]
-; AVX1-NEXT:    vinsertf128 $1, %xmm0, %ymm0, %ymm0
-; AVX1-NEXT:    retq
-;
-; AVX2-LABEL: load_splat_16i16_8i16_0101010101010101:
-; AVX2:       # %bb.0: # %entry
-; AVX2-NEXT:    vbroadcastss (%rdi), %ymm0
-; AVX2-NEXT:    retq
-;
-; AVX512-LABEL: load_splat_16i16_8i16_0101010101010101:
-; AVX512:       # %bb.0: # %entry
-; AVX512-NEXT:    vbroadcastss (%rdi), %ymm0
-; AVX512-NEXT:    retq
+; AVX-LABEL: load_splat_16i16_8i16_0101010101010101:
+; AVX:       # %bb.0: # %entry
+; AVX-NEXT:    vbroadcastss (%rdi), %ymm0
+; AVX-NEXT:    retq
 entry:
   %ld = load <8 x i16>, <8 x i16>* %ptr
   %ret = shufflevector <8 x i16> %ld, <8 x i16> undef, <16 x i32> <i32 0, i32 1, i32 0, i32 1, i32 0, i32 1, i32 0, i32 1, i32 0, i32 1, i32 0, i32 1, i32 0, i32 1, i32 0, i32 1>
@@ -235,21 +213,10 @@ define <16 x i16> @load_splat_16i16_8i16
 ; SSE-NEXT:    movdqa %xmm0, %xmm1
 ; SSE-NEXT:    retq
 ;
-; AVX1-LABEL: load_splat_16i16_8i16_0123012301230123:
-; AVX1:       # %bb.0: # %entry
-; AVX1-NEXT:    vpermilps {{.*#+}} xmm0 = mem[0,1,0,1]
-; AVX1-NEXT:    vinsertf128 $1, %xmm0, %ymm0, %ymm0
-; AVX1-NEXT:    retq
-;
-; AVX2-LABEL: load_splat_16i16_8i16_0123012301230123:
-; AVX2:       # %bb.0: # %entry
-; AVX2-NEXT:    vbroadcastsd (%rdi), %ymm0
-; AVX2-NEXT:    retq
-;
-; AVX512-LABEL: load_splat_16i16_8i16_0123012301230123:
-; AVX512:       # %bb.0: # %entry
-; AVX512-NEXT:    vbroadcastsd (%rdi), %ymm0
-; AVX512-NEXT:    retq
+; AVX-LABEL: load_splat_16i16_8i16_0123012301230123:
+; AVX:       # %bb.0: # %entry
+; AVX-NEXT:    vbroadcastsd (%rdi), %ymm0
+; AVX-NEXT:    retq
 entry:
   %ld = load <8 x i16>, <8 x i16>* %ptr
   %ret = shufflevector <8 x i16> %ld, <8 x i16> undef, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3,i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3>
@@ -407,21 +374,10 @@ define <32 x i8> @load_splat_32i8_16i8_0
 ; SSE-NEXT:    movdqa %xmm0, %xmm1
 ; SSE-NEXT:    retq
 ;
-; AVX1-LABEL: load_splat_32i8_16i8_01230123012301230123012301230123:
-; AVX1:       # %bb.0: # %entry
-; AVX1-NEXT:    vpermilps {{.*#+}} xmm0 = mem[0,0,0,0]
-; AVX1-NEXT:    vinsertf128 $1, %xmm0, %ymm0, %ymm0
-; AVX1-NEXT:    retq
-;
-; AVX2-LABEL: load_splat_32i8_16i8_01230123012301230123012301230123:
-; AVX2:       # %bb.0: # %entry
-; AVX2-NEXT:    vbroadcastss (%rdi), %ymm0
-; AVX2-NEXT:    retq
-;
-; AVX512-LABEL: load_splat_32i8_16i8_01230123012301230123012301230123:
-; AVX512:       # %bb.0: # %entry
-; AVX512-NEXT:    vbroadcastss (%rdi), %ymm0
-; AVX512-NEXT:    retq
+; AVX-LABEL: load_splat_32i8_16i8_01230123012301230123012301230123:
+; AVX:       # %bb.0: # %entry
+; AVX-NEXT:    vbroadcastss (%rdi), %ymm0
+; AVX-NEXT:    retq
 entry:
   %ld = load <16 x i8>, <16 x i8>* %ptr
   %ret = shufflevector <16 x i8> %ld, <16 x i8> undef, <32 x i32> <i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3>
@@ -435,21 +391,10 @@ define <32 x i8> @load_splat_32i8_16i8_0
 ; SSE-NEXT:    movdqa %xmm0, %xmm1
 ; SSE-NEXT:    retq
 ;
-; AVX1-LABEL: load_splat_32i8_16i8_01234567012345670123456701234567:
-; AVX1:       # %bb.0: # %entry
-; AVX1-NEXT:    vpermilps {{.*#+}} xmm0 = mem[0,1,0,1]
-; AVX1-NEXT:    vinsertf128 $1, %xmm0, %ymm0, %ymm0
-; AVX1-NEXT:    retq
-;
-; AVX2-LABEL: load_splat_32i8_16i8_01234567012345670123456701234567:
-; AVX2:       # %bb.0: # %entry
-; AVX2-NEXT:    vbroadcastsd (%rdi), %ymm0
-; AVX2-NEXT:    retq
-;
-; AVX512-LABEL: load_splat_32i8_16i8_01234567012345670123456701234567:
-; AVX512:       # %bb.0: # %entry
-; AVX512-NEXT:    vbroadcastsd (%rdi), %ymm0
-; AVX512-NEXT:    retq
+; AVX-LABEL: load_splat_32i8_16i8_01234567012345670123456701234567:
+; AVX:       # %bb.0: # %entry
+; AVX-NEXT:    vbroadcastsd (%rdi), %ymm0
+; AVX-NEXT:    retq
 entry:
   %ld = load <16 x i8>, <16 x i8>* %ptr
   %ret = shufflevector <16 x i8> %ld, <16 x i8> undef, <32 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>




More information about the llvm-commits mailing list