[llvm] 3ae0a40 - [X86] combineHorizOpWithShuffle - peek through one use bitcasts when decoding shuffles.

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 9 02:51:18 PDT 2021


Author: Simon Pilgrim
Date: 2021-04-09T10:51:04+01:00
New Revision: 3ae0a405fc94d1b7a0ced15742031e8d71b32d93

URL: https://github.com/llvm/llvm-project/commit/3ae0a405fc94d1b7a0ced15742031e8d71b32d93
DIFF: https://github.com/llvm/llvm-project/commit/3ae0a405fc94d1b7a0ced15742031e8d71b32d93.diff

LOG: [X86] combineHorizOpWithShuffle - peek through one use bitcasts when decoding shuffles.

When the horizop args have only one use, peek through their bitcasts; this allows us to merge shuffles of different widths through the horizop.

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/horizontal-sum.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 6abfc6716c5c..7c0ec182865c 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -43335,13 +43335,16 @@ static SDValue combineHorizOpWithShuffle(SDNode *N, SelectionDAG &DAG,
   unsigned Opcode = N->getOpcode();
   assert(isHorizOp(Opcode) && "Unexpected hadd/hsub/pack opcode");
 
+  SDLoc DL(N);
   EVT VT = N->getValueType(0);
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
   EVT SrcVT = N0.getValueType();
 
-  SDValue BC0 = peekThroughBitcasts(N0);
-  SDValue BC1 = peekThroughBitcasts(N1);
+  SDValue BC0 =
+      N->isOnlyUserOf(N0.getNode()) ? peekThroughOneUseBitcasts(N0) : N0;
+  SDValue BC1 =
+      N->isOnlyUserOf(N1.getNode()) ? peekThroughOneUseBitcasts(N1) : N1;
 
   // Attempt to fold HOP(LOSUBVECTOR(SHUFFLE(X)),HISUBVECTOR(SHUFFLE(X)))
   // to SHUFFLE(HOP(LOSUBVECTOR(X),HISUBVECTOR(X))), this is mainly for
@@ -43366,7 +43369,6 @@ static SDValue combineHorizOpWithShuffle(SDNode *N, SelectionDAG &DAG,
       if (!isAnyZero(ShuffleMask) && ShuffleOps.size() == 1 &&
           ShuffleOps[0].getValueType().is256BitVector() &&
           scaleShuffleElements(ShuffleMask, 4, ScaledMask)) {
-        SDLoc DL(N);
         SDValue Lo, Hi;
         MVT ShufVT = VT.isFloatingPoint() ? MVT::v4f32 : MVT::v4i32;
         std::tie(Lo, Hi) = DAG.SplitVector(ShuffleOps[0], DL);
@@ -43381,7 +43383,6 @@ static SDValue combineHorizOpWithShuffle(SDNode *N, SelectionDAG &DAG,
   }
 
   // Attempt to fold HOP(SHUFFLE(X),SHUFFLE(Y)) -> SHUFFLE(HOP(X,Y)).
-  // TODO: Merge with binary shuffle folds below.
   if (VT.is128BitVector() && SrcVT.getScalarSizeInBits() <= 32) {
     int PostShuffle[4] = {0, 1, 2, 3};
 
@@ -43395,22 +43396,20 @@ static SDValue combineHorizOpWithShuffle(SDNode *N, SelectionDAG &DAG,
 
       resolveTargetShuffleInputsAndMask(ShuffleOps, ShuffleMask);
       if (isAnyZero(ShuffleMask) || ShuffleOps.size() != 1 ||
-          !ShuffleOps[0].getValueType().is128BitVector() ||
-          !N->isOnlyUserOf(V.getNode()) ||
+          !ShuffleOps[0].getValueType().is128BitVector() || !V->hasOneUse() ||
           !scaleShuffleElements(ShuffleMask, 2, ScaledMask))
         return SDValue();
 
       PostShuffle[Offset + 0] = ScaledMask[0] < 0 ? -1 : Offset + ScaledMask[0];
       PostShuffle[Offset + 1] = ScaledMask[1] < 0 ? -1 : Offset + ScaledMask[1];
-      return DAG.getBitcast(V.getValueType(), ShuffleOps[0]);
+      return ShuffleOps[0];
     };
 
-    SDValue Src0 = AdjustOp(N0, 0);
-    SDValue Src1 = AdjustOp(N1, 2);
+    SDValue Src0 = AdjustOp(BC0, 0);
+    SDValue Src1 = AdjustOp(BC1, 2);
     if (Src0 || Src1) {
-      Src0 = Src0 ? Src0 : N0;
-      Src1 = Src1 ? Src1 : N1;
-      SDLoc DL(N);
+      Src0 = DAG.getBitcast(SrcVT, Src0 ? Src0 : BC0);
+      Src1 = DAG.getBitcast(SrcVT, Src1 ? Src1 : BC1);
       MVT ShufVT = VT.isFloatingPoint() ? MVT::v4f32 : MVT::v4i32;
       SDValue Res = DAG.getNode(Opcode, DL, VT, Src0, Src1);
       Res = DAG.getBitcast(ShufVT, Res);
@@ -43420,35 +43419,37 @@ static SDValue combineHorizOpWithShuffle(SDNode *N, SelectionDAG &DAG,
   }
 
   // Attempt to fold HOP(SHUFFLE(X,Y),SHUFFLE(X,Y)) -> SHUFFLE(HOP(X,Y)).
-  // TODO: Relax shuffle scaling to support sub-128-bit subvector shuffles.
   if (VT.is256BitVector() && Subtarget.hasInt256()) {
     SmallVector<int> Mask0, Mask1;
     SmallVector<SDValue> Ops0, Ops1;
-    if (getTargetShuffleInputs(N0, Ops0, Mask0, DAG) && !isAnyZero(Mask0) &&
-        getTargetShuffleInputs(N1, Ops1, Mask1, DAG) && !isAnyZero(Mask1) &&
-        !Ops0.empty() && !Ops1.empty()) {
-      SDValue Op00 = Ops0.front(), Op01 = Ops0.back();
-      SDValue Op10 = Ops1.front(), Op11 = Ops1.back();
-      SmallVector<int, 2> ShuffleMask0, ShuffleMask1;
-      if (Op00.getValueType() == SrcVT && Op01.getValueType() == SrcVT &&
-          Op10.getValueType() == SrcVT && Op11.getValueType() == SrcVT &&
-          scaleShuffleElements(Mask0, 2, ShuffleMask0) &&
-          scaleShuffleElements(Mask1, 2, ShuffleMask1)) {
-        if ((Op00 == Op11) && (Op01 == Op10)) {
-          std::swap(Op10, Op11);
-          ShuffleVectorSDNode::commuteMask(ShuffleMask1);
-        }
-        if ((Op00 == Op10) && (Op01 == Op11)) {
-          SmallVector<int, 4> ShuffleMask;
-          ShuffleMask.append(ShuffleMask0.begin(), ShuffleMask0.end());
-          ShuffleMask.append(ShuffleMask1.begin(), ShuffleMask1.end());
-          SDLoc DL(N);
-          MVT ShufVT = VT.isFloatingPoint() ? MVT::v4f64 : MVT::v4i64;
-          SDValue Res = DAG.getNode(Opcode, DL, VT, Op00, Op01);
-          Res = DAG.getBitcast(ShufVT, Res);
-          Res = DAG.getVectorShuffle(ShufVT, DL, Res, Res, ShuffleMask);
-          return DAG.getBitcast(VT, Res);
-        }
+    SmallVector<int, 2> ScaledMask0, ScaledMask1;
+    if (getTargetShuffleInputs(BC0, Ops0, Mask0, DAG) && !isAnyZero(Mask0) &&
+        getTargetShuffleInputs(BC1, Ops1, Mask1, DAG) && !isAnyZero(Mask1) &&
+        !Ops0.empty() && !Ops1.empty() &&
+        all_of(Ops0,
+               [](SDValue Op) { return Op.getValueType().is256BitVector(); }) &&
+        all_of(Ops1,
+               [](SDValue Op) { return Op.getValueType().is256BitVector(); }) &&
+        scaleShuffleElements(Mask0, 2, ScaledMask0) &&
+        scaleShuffleElements(Mask1, 2, ScaledMask1)) {
+      SDValue Op00 = peekThroughBitcasts(Ops0.front());
+      SDValue Op10 = peekThroughBitcasts(Ops1.front());
+      SDValue Op01 = peekThroughBitcasts(Ops0.back());
+      SDValue Op11 = peekThroughBitcasts(Ops1.back());
+      if ((Op00 == Op11) && (Op01 == Op10)) {
+        std::swap(Op10, Op11);
+        ShuffleVectorSDNode::commuteMask(ScaledMask1);
+      }
+      if ((Op00 == Op10) && (Op01 == Op11)) {
+        SmallVector<int, 4> ShuffleMask;
+        ShuffleMask.append(ScaledMask0.begin(), ScaledMask0.end());
+        ShuffleMask.append(ScaledMask1.begin(), ScaledMask1.end());
+        MVT ShufVT = VT.isFloatingPoint() ? MVT::v4f64 : MVT::v4i64;
+        SDValue Res = DAG.getNode(Opcode, DL, VT, DAG.getBitcast(SrcVT, Op00),
+                                  DAG.getBitcast(SrcVT, Op01));
+        Res = DAG.getBitcast(ShufVT, Res);
+        Res = DAG.getVectorShuffle(ShufVT, DL, Res, Res, ShuffleMask);
+        return DAG.getBitcast(VT, Res);
       }
     }
   }

diff  --git a/llvm/test/CodeGen/X86/horizontal-sum.ll b/llvm/test/CodeGen/X86/horizontal-sum.ll
index d7c367c331c8..8c50dd2e108f 100644
--- a/llvm/test/CodeGen/X86/horizontal-sum.ll
+++ b/llvm/test/CodeGen/X86/horizontal-sum.ll
@@ -196,10 +196,8 @@ define <8 x float> @pair_sum_v8f32_v4f32(<4 x float> %0, <4 x float> %1, <4 x fl
 ; SSSE3-SLOW-NEXT:    addps %xmm1, %xmm3
 ; SSSE3-SLOW-NEXT:    movlhps {{.*#+}} xmm0 = xmm0[0],xmm3[0]
 ; SSSE3-SLOW-NEXT:    haddps %xmm7, %xmm6
-; SSSE3-SLOW-NEXT:    movaps %xmm6, %xmm1
-; SSSE3-SLOW-NEXT:    unpckhpd {{.*#+}} xmm1 = xmm1[1],xmm6[1]
-; SSSE3-SLOW-NEXT:    haddps %xmm1, %xmm6
-; SSSE3-SLOW-NEXT:    shufps {{.*#+}} xmm3 = xmm3[2,3],xmm6[0,2]
+; SSSE3-SLOW-NEXT:    haddps %xmm6, %xmm6
+; SSSE3-SLOW-NEXT:    shufps {{.*#+}} xmm3 = xmm3[2,3],xmm6[0,3]
 ; SSSE3-SLOW-NEXT:    movaps %xmm3, %xmm1
 ; SSSE3-SLOW-NEXT:    retq
 ;
@@ -241,9 +239,7 @@ define <8 x float> @pair_sum_v8f32_v4f32(<4 x float> %0, <4 x float> %1, <4 x fl
 ; AVX1-SLOW-NEXT:    vpermilpd {{.*#+}} xmm1 = xmm1[1,0]
 ; AVX1-SLOW-NEXT:    vinsertf128 $1, %xmm1, %ymm0, %ymm1
 ; AVX1-SLOW-NEXT:    vhaddps %xmm7, %xmm6, %xmm2
-; AVX1-SLOW-NEXT:    vpermilpd {{.*#+}} xmm3 = xmm2[1,1]
-; AVX1-SLOW-NEXT:    vhaddps %xmm3, %xmm2, %xmm2
-; AVX1-SLOW-NEXT:    vpermilps {{.*#+}} xmm2 = xmm2[0,2,2,3]
+; AVX1-SLOW-NEXT:    vhaddps %xmm2, %xmm2, %xmm2
 ; AVX1-SLOW-NEXT:    vinsertf128 $1, %xmm2, %ymm0, %ymm0
 ; AVX1-SLOW-NEXT:    vshufpd {{.*#+}} ymm0 = ymm1[0],ymm0[1],ymm1[2],ymm0[2]
 ; AVX1-SLOW-NEXT:    retq
@@ -291,9 +287,7 @@ define <8 x float> @pair_sum_v8f32_v4f32(<4 x float> %0, <4 x float> %1, <4 x fl
 ; AVX2-SLOW-NEXT:    vpermilpd {{.*#+}} xmm1 = xmm1[1,0]
 ; AVX2-SLOW-NEXT:    vinsertf128 $1, %xmm1, %ymm0, %ymm1
 ; AVX2-SLOW-NEXT:    vhaddps %xmm7, %xmm6, %xmm2
-; AVX2-SLOW-NEXT:    vpermilpd {{.*#+}} xmm3 = xmm2[1,1]
-; AVX2-SLOW-NEXT:    vhaddps %xmm3, %xmm2, %xmm2
-; AVX2-SLOW-NEXT:    vpermilps {{.*#+}} xmm2 = xmm2[0,2,2,3]
+; AVX2-SLOW-NEXT:    vhaddps %xmm2, %xmm2, %xmm2
 ; AVX2-SLOW-NEXT:    vinsertf128 $1, %xmm2, %ymm0, %ymm0
 ; AVX2-SLOW-NEXT:    vshufpd {{.*#+}} ymm0 = ymm1[0],ymm0[1],ymm1[2],ymm0[2]
 ; AVX2-SLOW-NEXT:    retq


        


More information about the llvm-commits mailing list