[llvm] 5dfba56 - [DAG] Fold shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d)))

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 16 07:46:52 PST 2021


Author: Simon Pilgrim
Date: 2021-02-16T15:46:34Z
New Revision: 5dfba562dd247f731528448ee83785b099f93629

URL: https://github.com/llvm/llvm-project/commit/5dfba562dd247f731528448ee83785b099f93629
DIFF: https://github.com/llvm/llvm-project/commit/5dfba562dd247f731528448ee83785b099f93629.diff

LOG: [DAG] Fold shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d)))

Fold shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d))) -> bop(shuffle(shuffle(x,y),shuffle(a,b)),shuffle(shuffle(z,w),shuffle(c,d))), where each nested shuffle is merged into its new outer shuffle whenever possible.

Attempt to fold from a shuffle of a pair of binops to a binop of shuffles, as long as one/both of the binop sources are also shuffles that can be merged with the outer shuffle. This should guarantee that we remove one binop without introducing any additional shuffles.

Technically there's potential for a merged shuffle's lowering to be poorer than the original shuffle, but it could also be better, and I'm not seeing any regressions as long as we keep the 'don't merge splats' rule already present in MergeInnerShuffle.

This expands and generalizes an existing X86 combine and attempts to merge either of each binop's sources (with an on-the-fly commutation of the shuffle mask) - we couldn't do that in the x86 version as it had to stay in a form that DAGCombine's MergeInnerShuffle would still recognise.

Differential Revision: https://reviews.llvm.org/D96345

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/bitcast-and-setcc-256.ll
    llvm/test/CodeGen/X86/promote-cmp.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index e3e79cbcbd9d..586b3b99f84a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -20919,11 +20919,13 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
 
   // Compute the combined shuffle mask for a shuffle with SV0 as the first
   // operand, and SV1 as the second operand.
-  // i.e. Merge SVN(OtherSVN, N1) -> shuffle(SV0, SV1, Mask).
+  // i.e. Merge SVN(OtherSVN, N1) -> shuffle(SV0, SV1, Mask) iff Commute = false
+  //      Merge SVN(N1, OtherSVN) -> shuffle(SV0, SV1, Mask') iff Commute = true
   auto MergeInnerShuffle =
-      [NumElts, &VT](ShuffleVectorSDNode *SVN, ShuffleVectorSDNode *OtherSVN,
-                     SDValue N1, const TargetLowering &TLI, SDValue &SV0,
-                     SDValue &SV1, SmallVectorImpl<int> &Mask) -> bool {
+      [NumElts, &VT](bool Commute, ShuffleVectorSDNode *SVN,
+                     ShuffleVectorSDNode *OtherSVN, SDValue N1,
+                     const TargetLowering &TLI, SDValue &SV0, SDValue &SV1,
+                     SmallVectorImpl<int> &Mask) -> bool {
     // Don't try to fold splats; they're likely to simplify somehow, or they
     // might be free.
     if (OtherSVN->isSplat())
@@ -20940,6 +20942,9 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
         continue;
       }
 
+      if (Commute)
+        Idx = (Idx < (int)NumElts) ? (Idx + NumElts) : (Idx - NumElts);
+
       SDValue CurrentVec;
       if (Idx < (int)NumElts) {
         // This shuffle index refers to the inner shuffle N0. Lookup the inner
@@ -21045,7 +21050,7 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
 
     SDValue SV0, SV1;
     SmallVector<int, 4> Mask;
-    if (MergeInnerShuffle(SVN, OtherSV, N1, TLI, SV0, SV1, Mask)) {
+    if (MergeInnerShuffle(false, SVN, OtherSV, N1, TLI, SV0, SV1, Mask)) {
       // Check if all indices in Mask are Undef. In case, propagate Undef.
       if (llvm::all_of(Mask, [](int M) { return M < 0; }))
         return DAG.getUNDEF(VT);
@@ -21055,6 +21060,77 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
     }
   }
 
+  // Merge shuffles through binops if we are able to merge it with at least one
+  // other shuffle.
+  // shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d)))
+  if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
+    unsigned SrcOpcode = N0.getOpcode();
+    if (SrcOpcode == N1.getOpcode() && TLI.isBinOp(SrcOpcode) &&
+        N->isOnlyUserOf(N0.getNode()) && N->isOnlyUserOf(N1.getNode())) {
+      SDValue Op00 = N0.getOperand(0);
+      SDValue Op10 = N1.getOperand(0);
+      SDValue Op01 = N0.getOperand(1);
+      SDValue Op11 = N1.getOperand(1);
+      // TODO: We might be able to relax the VT check but we don't currently
+      // have any isBinOp() that has different result/ops VTs so play safe until
+      // we have test coverage.
+      if (Op00.getValueType() == VT && Op10.getValueType() == VT &&
+          Op01.getValueType() == VT && Op11.getValueType() == VT &&
+          (Op00.getOpcode() == ISD::VECTOR_SHUFFLE ||
+           Op10.getOpcode() == ISD::VECTOR_SHUFFLE ||
+           Op01.getOpcode() == ISD::VECTOR_SHUFFLE ||
+           Op11.getOpcode() == ISD::VECTOR_SHUFFLE)) {
+        auto CanMergeInnerShuffle = [&](SDValue &SV0, SDValue &SV1,
+                                        SmallVectorImpl<int> &Mask, bool LeftOp,
+                                        bool Commute) {
+          SDValue InnerN = Commute ? N1 : N0;
+          SDValue Op0 = LeftOp ? Op00 : Op01;
+          SDValue Op1 = LeftOp ? Op10 : Op11;
+          if (Commute)
+            std::swap(Op0, Op1);
+          return Op0.getOpcode() == ISD::VECTOR_SHUFFLE &&
+                 InnerN->isOnlyUserOf(Op0.getNode()) &&
+                 MergeInnerShuffle(Commute, SVN, cast<ShuffleVectorSDNode>(Op0),
+                                   Op1, TLI, SV0, SV1, Mask) &&
+                 llvm::none_of(Mask, [](int M) { return M < 0; });
+        };
+
+        // Ensure we don't increase the number of shuffles - we must merge a
+        // shuffle from at least one of the LHS and RHS ops.
+        bool MergedLeft = false;
+        SDValue LeftSV0, LeftSV1;
+        SmallVector<int, 4> LeftMask;
+        if (CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, false) ||
+            CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, true)) {
+          MergedLeft = true;
+        } else {
+          LeftMask.assign(SVN->getMask().begin(), SVN->getMask().end());
+          LeftSV0 = Op00, LeftSV1 = Op10;
+        }
+
+        bool MergedRight = false;
+        SDValue RightSV0, RightSV1;
+        SmallVector<int, 4> RightMask;
+        if (CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, false) ||
+            CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, true)) {
+          MergedRight = true;
+        } else {
+          RightMask.assign(SVN->getMask().begin(), SVN->getMask().end());
+          RightSV0 = Op01, RightSV1 = Op11;
+        }
+
+        if (MergedLeft || MergedRight) {
+          SDLoc DL(N);
+          SDValue LHS =
+              DAG.getVectorShuffle(VT, DL, LeftSV0, LeftSV1, LeftMask);
+          SDValue RHS =
+              DAG.getVectorShuffle(VT, DL, RightSV0, RightSV1, RightMask);
+          return DAG.getNode(SrcOpcode, DL, VT, LHS, RHS);
+        }
+      }
+    }
+  }
+
   if (SDValue V = foldShuffleOfConcatUndefs(SVN, DAG))
     return V;
 

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 2e5fb3579951..501274dc471d 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -38029,34 +38029,6 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG,
 
     if (SDValue HAddSub = foldShuffleOfHorizOp(N, DAG))
       return HAddSub;
-
-    // Merge shuffles through binops if its likely we'll be able to merge it
-    // with other shuffles (as long as they aren't splats).
-    // shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d)))
-    // TODO: We might be able to move this to DAGCombiner::visitVECTOR_SHUFFLE.
-    if (auto *SVN = dyn_cast<ShuffleVectorSDNode>(N)) {
-      unsigned SrcOpcode = N->getOperand(0).getOpcode();
-      if (SrcOpcode == N->getOperand(1).getOpcode() && TLI.isBinOp(SrcOpcode) &&
-          N->isOnlyUserOf(N->getOperand(0).getNode()) &&
-          N->isOnlyUserOf(N->getOperand(1).getNode())) {
-        SDValue Op00 = N->getOperand(0).getOperand(0);
-        SDValue Op10 = N->getOperand(1).getOperand(0);
-        SDValue Op01 = N->getOperand(0).getOperand(1);
-        SDValue Op11 = N->getOperand(1).getOperand(1);
-        auto *SVN00 = dyn_cast<ShuffleVectorSDNode>(Op00);
-        auto *SVN10 = dyn_cast<ShuffleVectorSDNode>(Op10);
-        auto *SVN01 = dyn_cast<ShuffleVectorSDNode>(Op01);
-        auto *SVN11 = dyn_cast<ShuffleVectorSDNode>(Op11);
-        if (((SVN00 && !SVN00->isSplat()) || (SVN10 && !SVN10->isSplat())) &&
-            ((SVN01 && !SVN01->isSplat()) || (SVN11 && !SVN11->isSplat()))) {
-          SDLoc DL(N);
-          ArrayRef<int> Mask = SVN->getMask();
-          SDValue LHS = DAG.getVectorShuffle(VT, DL, Op00, Op10, Mask);
-          SDValue RHS = DAG.getVectorShuffle(VT, DL, Op01, Op11, Mask);
-          return DAG.getNode(SrcOpcode, DL, VT, LHS, RHS);
-        }
-      }
-    }
   }
 
   // Attempt to combine into a vector load/broadcast.

diff --git a/llvm/test/CodeGen/X86/bitcast-and-setcc-256.ll b/llvm/test/CodeGen/X86/bitcast-and-setcc-256.ll
index aa8b123b713b..225ebaceeae6 100644
--- a/llvm/test/CodeGen/X86/bitcast-and-setcc-256.ll
+++ b/llvm/test/CodeGen/X86/bitcast-and-setcc-256.ll
@@ -9,47 +9,41 @@
 define i4 @v4i64(<4 x i64> %a, <4 x i64> %b, <4 x i64> %c, <4 x i64> %d) {
 ; SSE2-SSSE3-LABEL: v4i64:
 ; SSE2-SSSE3:       # %bb.0:
-; SSE2-SSSE3-NEXT:    movdqa {{.*#+}} xmm8 = [2147483648,2147483648]
-; SSE2-SSSE3-NEXT:    pxor %xmm8, %xmm3
-; SSE2-SSSE3-NEXT:    pxor %xmm8, %xmm1
-; SSE2-SSSE3-NEXT:    movdqa %xmm1, %xmm9
-; SSE2-SSSE3-NEXT:    pcmpgtd %xmm3, %xmm9
+; SSE2-SSSE3-NEXT:    movdqa {{.*#+}} xmm9 = [2147483648,2147483648]
+; SSE2-SSSE3-NEXT:    pxor %xmm9, %xmm3
+; SSE2-SSSE3-NEXT:    pxor %xmm9, %xmm1
+; SSE2-SSSE3-NEXT:    movdqa %xmm1, %xmm10
+; SSE2-SSSE3-NEXT:    pcmpgtd %xmm3, %xmm10
+; SSE2-SSSE3-NEXT:    pxor %xmm9, %xmm2
+; SSE2-SSSE3-NEXT:    pxor %xmm9, %xmm0
+; SSE2-SSSE3-NEXT:    movdqa %xmm0, %xmm8
+; SSE2-SSSE3-NEXT:    pcmpgtd %xmm2, %xmm8
+; SSE2-SSSE3-NEXT:    movdqa %xmm8, %xmm11
+; SSE2-SSSE3-NEXT:    shufps {{.*#+}} xmm11 = xmm11[0,2],xmm10[0,2]
 ; SSE2-SSSE3-NEXT:    pcmpeqd %xmm3, %xmm1
-; SSE2-SSSE3-NEXT:    pshufd {{.*#+}} xmm1 = xmm1[1,1,3,3]
-; SSE2-SSSE3-NEXT:    pand %xmm9, %xmm1
-; SSE2-SSSE3-NEXT:    pshufd {{.*#+}} xmm3 = xmm9[1,1,3,3]
-; SSE2-SSSE3-NEXT:    por %xmm1, %xmm3
-; SSE2-SSSE3-NEXT:    pxor %xmm8, %xmm2
-; SSE2-SSSE3-NEXT:    pxor %xmm8, %xmm0
-; SSE2-SSSE3-NEXT:    movdqa %xmm0, %xmm1
-; SSE2-SSSE3-NEXT:    pcmpgtd %xmm2, %xmm1
 ; SSE2-SSSE3-NEXT:    pcmpeqd %xmm2, %xmm0
-; SSE2-SSSE3-NEXT:    pshufd {{.*#+}} xmm2 = xmm0[1,1,3,3]
-; SSE2-SSSE3-NEXT:    pand %xmm1, %xmm2
-; SSE2-SSSE3-NEXT:    pshufd {{.*#+}} xmm0 = xmm1[1,1,3,3]
-; SSE2-SSSE3-NEXT:    por %xmm2, %xmm0
-; SSE2-SSSE3-NEXT:    shufps {{.*#+}} xmm0 = xmm0[0,2],xmm3[0,2]
-; SSE2-SSSE3-NEXT:    pxor %xmm8, %xmm7
-; SSE2-SSSE3-NEXT:    pxor %xmm8, %xmm5
-; SSE2-SSSE3-NEXT:    movdqa %xmm5, %xmm1
-; SSE2-SSSE3-NEXT:    pcmpgtd %xmm7, %xmm1
+; SSE2-SSSE3-NEXT:    shufps {{.*#+}} xmm0 = xmm0[1,3],xmm1[1,3]
+; SSE2-SSSE3-NEXT:    andps %xmm11, %xmm0
+; SSE2-SSSE3-NEXT:    shufps {{.*#+}} xmm8 = xmm8[1,3],xmm10[1,3]
+; SSE2-SSSE3-NEXT:    orps %xmm0, %xmm8
+; SSE2-SSSE3-NEXT:    pxor %xmm9, %xmm7
+; SSE2-SSSE3-NEXT:    pxor %xmm9, %xmm5
+; SSE2-SSSE3-NEXT:    movdqa %xmm5, %xmm0
+; SSE2-SSSE3-NEXT:    pcmpgtd %xmm7, %xmm0
+; SSE2-SSSE3-NEXT:    pxor %xmm9, %xmm6
+; SSE2-SSSE3-NEXT:    pxor %xmm9, %xmm4
+; SSE2-SSSE3-NEXT:    movdqa %xmm4, %xmm1
+; SSE2-SSSE3-NEXT:    pcmpgtd %xmm6, %xmm1
+; SSE2-SSSE3-NEXT:    movdqa %xmm1, %xmm2
+; SSE2-SSSE3-NEXT:    shufps {{.*#+}} xmm2 = xmm2[0,2],xmm0[0,2]
 ; SSE2-SSSE3-NEXT:    pcmpeqd %xmm7, %xmm5
-; SSE2-SSSE3-NEXT:    pshufd {{.*#+}} xmm2 = xmm5[1,1,3,3]
-; SSE2-SSSE3-NEXT:    pand %xmm1, %xmm2
-; SSE2-SSSE3-NEXT:    pshufd {{.*#+}} xmm1 = xmm1[1,1,3,3]
-; SSE2-SSSE3-NEXT:    por %xmm2, %xmm1
-; SSE2-SSSE3-NEXT:    pxor %xmm8, %xmm6
-; SSE2-SSSE3-NEXT:    pxor %xmm8, %xmm4
-; SSE2-SSSE3-NEXT:    movdqa %xmm4, %xmm2
-; SSE2-SSSE3-NEXT:    pcmpgtd %xmm6, %xmm2
 ; SSE2-SSSE3-NEXT:    pcmpeqd %xmm6, %xmm4
-; SSE2-SSSE3-NEXT:    pshufd {{.*#+}} xmm3 = xmm4[1,1,3,3]
-; SSE2-SSSE3-NEXT:    pand %xmm2, %xmm3
-; SSE2-SSSE3-NEXT:    pshufd {{.*#+}} xmm2 = xmm2[1,1,3,3]
-; SSE2-SSSE3-NEXT:    por %xmm3, %xmm2
-; SSE2-SSSE3-NEXT:    shufps {{.*#+}} xmm2 = xmm2[0,2],xmm1[0,2]
-; SSE2-SSSE3-NEXT:    andps %xmm0, %xmm2
-; SSE2-SSSE3-NEXT:    movmskps %xmm2, %eax
+; SSE2-SSSE3-NEXT:    shufps {{.*#+}} xmm4 = xmm4[1,3],xmm5[1,3]
+; SSE2-SSSE3-NEXT:    andps %xmm2, %xmm4
+; SSE2-SSSE3-NEXT:    shufps {{.*#+}} xmm1 = xmm1[1,3],xmm0[1,3]
+; SSE2-SSSE3-NEXT:    orps %xmm4, %xmm1
+; SSE2-SSSE3-NEXT:    andps %xmm8, %xmm1
+; SSE2-SSSE3-NEXT:    movmskps %xmm1, %eax
 ; SSE2-SSSE3-NEXT:    # kill: def $al killed $al killed $eax
 ; SSE2-SSSE3-NEXT:    retq
 ;

diff --git a/llvm/test/CodeGen/X86/promote-cmp.ll b/llvm/test/CodeGen/X86/promote-cmp.ll
index ecf475e01eb3..c59f808a3029 100644
--- a/llvm/test/CodeGen/X86/promote-cmp.ll
+++ b/llvm/test/CodeGen/X86/promote-cmp.ll
@@ -8,36 +8,34 @@ define <4 x i64> @PR45808(<4 x i64> %0, <4 x i64> %1) {
 ; SSE2-LABEL: PR45808:
 ; SSE2:       # %bb.0:
 ; SSE2-NEXT:    movdqa {{.*#+}} xmm4 = [2147483648,2147483648]
-; SSE2-NEXT:    movdqa %xmm3, %xmm5
-; SSE2-NEXT:    pxor %xmm4, %xmm5
+; SSE2-NEXT:    movdqa %xmm3, %xmm9
+; SSE2-NEXT:    pxor %xmm4, %xmm9
 ; SSE2-NEXT:    movdqa %xmm1, %xmm6
 ; SSE2-NEXT:    pxor %xmm4, %xmm6
-; SSE2-NEXT:    movdqa %xmm6, %xmm7
-; SSE2-NEXT:    pcmpgtd %xmm5, %xmm7
-; SSE2-NEXT:    pcmpeqd %xmm5, %xmm6
-; SSE2-NEXT:    pshufd {{.*#+}} xmm5 = xmm6[1,1,3,3]
-; SSE2-NEXT:    pand %xmm7, %xmm5
-; SSE2-NEXT:    pshufd {{.*#+}} xmm6 = xmm7[1,1,3,3]
-; SSE2-NEXT:    por %xmm5, %xmm6
-; SSE2-NEXT:    movdqa %xmm2, %xmm5
-; SSE2-NEXT:    pxor %xmm4, %xmm5
+; SSE2-NEXT:    movdqa %xmm6, %xmm8
+; SSE2-NEXT:    pcmpgtd %xmm9, %xmm8
+; SSE2-NEXT:    movdqa %xmm2, %xmm7
+; SSE2-NEXT:    pxor %xmm4, %xmm7
 ; SSE2-NEXT:    pxor %xmm0, %xmm4
-; SSE2-NEXT:    movdqa %xmm4, %xmm7
-; SSE2-NEXT:    pcmpgtd %xmm5, %xmm7
-; SSE2-NEXT:    pcmpeqd %xmm5, %xmm4
+; SSE2-NEXT:    movdqa %xmm4, %xmm5
+; SSE2-NEXT:    pcmpgtd %xmm7, %xmm5
+; SSE2-NEXT:    movdqa %xmm5, %xmm10
+; SSE2-NEXT:    shufps {{.*#+}} xmm10 = xmm10[0,2],xmm8[0,2]
+; SSE2-NEXT:    pcmpeqd %xmm9, %xmm6
+; SSE2-NEXT:    pcmpeqd %xmm7, %xmm4
+; SSE2-NEXT:    shufps {{.*#+}} xmm4 = xmm4[1,3],xmm6[1,3]
+; SSE2-NEXT:    andps %xmm10, %xmm4
+; SSE2-NEXT:    shufps {{.*#+}} xmm5 = xmm5[1,3],xmm8[1,3]
+; SSE2-NEXT:    orps %xmm4, %xmm5
+; SSE2-NEXT:    pshufd {{.*#+}} xmm4 = xmm5[2,1,3,3]
+; SSE2-NEXT:    pxor {{.*}}(%rip), %xmm5
+; SSE2-NEXT:    psllq $63, %xmm4
+; SSE2-NEXT:    psrad $31, %xmm4
 ; SSE2-NEXT:    pshufd {{.*#+}} xmm4 = xmm4[1,1,3,3]
-; SSE2-NEXT:    pand %xmm7, %xmm4
-; SSE2-NEXT:    pshufd {{.*#+}} xmm5 = xmm7[1,1,3,3]
-; SSE2-NEXT:    por %xmm4, %xmm5
-; SSE2-NEXT:    pshufd {{.*#+}} xmm4 = xmm5[0,2,2,3]
-; SSE2-NEXT:    pxor {{.*}}(%rip), %xmm4
-; SSE2-NEXT:    psllq $63, %xmm6
-; SSE2-NEXT:    psrad $31, %xmm6
-; SSE2-NEXT:    pshufd {{.*#+}} xmm5 = xmm6[1,1,3,3]
-; SSE2-NEXT:    pand %xmm5, %xmm1
-; SSE2-NEXT:    pandn %xmm3, %xmm5
-; SSE2-NEXT:    por %xmm5, %xmm1
-; SSE2-NEXT:    pshufd {{.*#+}} xmm3 = xmm4[0,1,1,3]
+; SSE2-NEXT:    pand %xmm4, %xmm1
+; SSE2-NEXT:    pandn %xmm3, %xmm4
+; SSE2-NEXT:    por %xmm4, %xmm1
+; SSE2-NEXT:    pshufd {{.*#+}} xmm3 = xmm5[0,1,1,3]
 ; SSE2-NEXT:    psllq $63, %xmm3
 ; SSE2-NEXT:    psrad $31, %xmm3
 ; SSE2-NEXT:    pshufd {{.*#+}} xmm3 = xmm3[1,1,3,3]


        


More information about the llvm-commits mailing list