[llvm] 1337821 - [DAGCombiner][X86] Fold a CONCAT_VECTORS of SHUFFLE_VECTOR and its operand into wider SHUFFLE_VECTOR

Roman Lebedev via llvm-commits llvm-commits at lists.llvm.org
Sun Jan 1 12:19:32 PST 2023


Author: Roman Lebedev
Date: 2023-01-01T23:18:42+03:00
New Revision: 1337821f11902e219fe7720494879799219f2fc5

URL: https://github.com/llvm/llvm-project/commit/1337821f11902e219fe7720494879799219f2fc5
DIFF: https://github.com/llvm/llvm-project/commit/1337821f11902e219fe7720494879799219f2fc5.diff

LOG: [DAGCombiner][X86] Fold a CONCAT_VECTORS of SHUFFLE_VECTOR and its operand into wider SHUFFLE_VECTOR

This was showing up as a source of *many* regressions
with more aggressive ZERO_EXTEND_VECTOR_INREG recognition.

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/test/CodeGen/X86/vector-shuffle-concatenation.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index ee2918e419404..9235042c54a4a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -21891,6 +21891,109 @@ static SDValue combineConcatVectorOfCasts(SDNode *N, SelectionDAG &DAG) {
   return DAG.getNode(CastOpcode, DL, VT, NewConcat);
 }
 
+// Fold a 2-operand CONCAT_VECTORS with no UNDEF operands, where one operand is
+// a unary SHUFFLE_VECTOR used only by N and each operand is that shuffle or its
+// input, into one wider SHUFFLE_VECTOR. Returns an empty SDValue otherwise.
+static SDValue combineConcatVectorOfShuffleAndItsOperands(
+    SDNode *N, SelectionDAG &DAG, const TargetLowering &TLI, bool LegalTypes,
+    bool LegalOperations) {
+  EVT VT = N->getValueType(0);
+  EVT OpVT = N->getOperand(0).getValueType();
+  if (VT.isScalableVector())
+    return SDValue();
+
+  // For now, only allow simple 2-operand concatenations.
+  if (N->getNumOperands() != 2)
+    return SDValue();
+
+  // Don't create illegal types/shuffles when not allowed to.
+  if ((LegalTypes && !TLI.isTypeLegal(VT)) ||
+      (LegalOperations &&
+       !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, VT)))
+    return SDValue();
+
+  // Analyze all of the operands of the CONCAT_VECTORS. Out of all of them,
+  // we want to find one that is: (1) a SHUFFLE_VECTOR (2) only used by us,
+  // and (3) all operands of CONCAT_VECTORS must be either that SHUFFLE_VECTOR,
+  // or one of the operands of that SHUFFLE_VECTOR (but not UNDEF!).
+  // (4) and for now, the SHUFFLE_VECTOR must be unary.
+  ShuffleVectorSDNode *SVN = nullptr;
+  for (SDValue Op : N->ops()) {
+    if (auto *CurSVN = dyn_cast<ShuffleVectorSDNode>(Op);
+        CurSVN && CurSVN->getOperand(1).isUndef() && N->isOnlyUserOf(CurSVN) &&
+        all_of(N->ops(), [CurSVN](SDValue Op) {
+          // FIXME: can we allow UNDEF operands?
+          return !Op.isUndef() &&
+                 (Op.getNode() == CurSVN || is_contained(CurSVN->ops(), Op));
+        })) {
+      SVN = CurSVN;
+      break;
+    }
+  }
+  if (!SVN)
+    return SDValue();
+
+  // We pad the shuffle operands to VT, so any index picking from the second
+  // operand would need adjusting; the shuffle is unary, so the mask is reused.
+  SmallVector<int, 16> AdjustedMask;
+  AdjustedMask.reserve(SVN->getMask().size());
+  assert(SVN->getOperand(1).isUndef() && "Expected unary shuffle!");
+  append_range(AdjustedMask, SVN->getMask());
+
+  // Identity masks picking the original (unpadded) part of each padded operand.
+  SmallVector<int, 32> IdentityMask(2 * OpVT.getVectorNumElements());
+  MutableArrayRef<int> FirstShufOpIdentityMask =
+      MutableArrayRef<int>(IdentityMask)
+          .take_front(OpVT.getVectorNumElements());
+  MutableArrayRef<int> SecondShufOpIdentityMask =
+      MutableArrayRef<int>(IdentityMask).take_back(OpVT.getVectorNumElements());
+  std::iota(FirstShufOpIdentityMask.begin(), FirstShufOpIdentityMask.end(), 0);
+  std::iota(SecondShufOpIdentityMask.begin(), SecondShufOpIdentityMask.end(),
+            VT.getVectorNumElements());
+
+  // New combined shuffle mask: one sub-mask per CONCAT_VECTORS operand.
+  SmallVector<int, 32> Mask;
+  Mask.reserve(VT.getVectorNumElements());
+  for (SDValue Op : N->ops()) {
+    assert(!Op.isUndef() && "Not expecting to concatenate UNDEF.");
+    if (Op.getNode() == SVN) {
+      append_range(Mask, AdjustedMask);
+      continue;
+    }
+    if (Op == SVN->getOperand(0)) {
+      append_range(Mask, FirstShufOpIdentityMask);
+      continue;
+    }
+    if (Op == SVN->getOperand(1)) {
+      append_range(Mask, SecondShufOpIdentityMask);
+      continue;
+    }
+    llvm_unreachable("Unexpected operand!");
+  }
+
+  // Don't create illegal shuffle masks.
+  if (!TLI.isShuffleMaskLegal(Mask, VT))
+    return SDValue();
+
+  // Pad the shuffle operands with UNDEF up to VT (original operand goes first).
+  SDLoc dl(N);
+  std::array<SDValue, 2> ShufOps;
+  for (auto I : zip(SVN->ops(), ShufOps)) {
+    SDValue ShufOp = std::get<0>(I);
+    SDValue &NewShufOp = std::get<1>(I);
+    if (ShufOp.isUndef())
+      NewShufOp = DAG.getUNDEF(VT);
+    else {
+      SmallVector<SDValue, 2> ShufOpParts(N->getNumOperands(),
+                                          DAG.getUNDEF(OpVT));
+      ShufOpParts[0] = ShufOp;
+      NewShufOp = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, ShufOpParts);
+    }
+  }
+  // Finally, create the new wide shuffle.
+  return DAG.getVectorShuffle(VT, dl, ShufOps[0], ShufOps[1], Mask);
+}
+
+
 SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
   // If we only have one input vector, we don't need to do any concatenation.
   if (N->getNumOperands() == 1)
@@ -22026,6 +22129,10 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
   if (SDValue V = combineConcatVectorOfCasts(N, DAG))
     return V;
 
+  if (SDValue V = combineConcatVectorOfShuffleAndItsOperands(
+          N, DAG, TLI, LegalTypes, LegalOperations))
+    return V;
+
   // Type legalization of vectors and DAG canonicalization of SHUFFLE_VECTOR
   // nodes often generate nop CONCAT_VECTOR nodes. Scan the CONCAT_VECTOR
   // operands and look for a CONCAT operations that place the incoming vectors

diff  --git a/llvm/test/CodeGen/X86/vector-shuffle-concatenation.ll b/llvm/test/CodeGen/X86/vector-shuffle-concatenation.ll
index 4ffe97e6de236..52fc059cc6818 100644
--- a/llvm/test/CodeGen/X86/vector-shuffle-concatenation.ll
+++ b/llvm/test/CodeGen/X86/vector-shuffle-concatenation.ll
@@ -23,32 +23,33 @@ define void @concat_a_to_shuf_of_a(ptr %a.ptr, ptr %dst) {
 ; AVX:       # %bb.0:
 ; AVX-NEXT:    vmovaps (%rdi), %xmm0
 ; AVX-NEXT:    vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX-NEXT:    vmovaps %xmm0, 16(%rsi)
-; AVX-NEXT:    vmovaps %xmm1, (%rsi)
+; AVX-NEXT:    vinsertf128 $1, %xmm0, %ymm1, %ymm0
+; AVX-NEXT:    vmovaps %ymm0, (%rsi)
+; AVX-NEXT:    vzeroupper
 ; AVX-NEXT:    retq
 ;
 ; AVX2-LABEL: concat_a_to_shuf_of_a:
 ; AVX2:       # %bb.0:
 ; AVX2-NEXT:    vmovaps (%rdi), %xmm0
-; AVX2-NEXT:    vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX2-NEXT:    vmovaps %xmm0, 16(%rsi)
-; AVX2-NEXT:    vmovaps %xmm1, (%rsi)
+; AVX2-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[1,0,0,1]
+; AVX2-NEXT:    vmovaps %ymm0, (%rsi)
+; AVX2-NEXT:    vzeroupper
 ; AVX2-NEXT:    retq
 ;
 ; AVX512F-LABEL: concat_a_to_shuf_of_a:
 ; AVX512F:       # %bb.0:
 ; AVX512F-NEXT:    vmovaps (%rdi), %xmm0
-; AVX512F-NEXT:    vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX512F-NEXT:    vmovaps %xmm0, 16(%rsi)
-; AVX512F-NEXT:    vmovaps %xmm1, (%rsi)
+; AVX512F-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[1,0,0,1]
+; AVX512F-NEXT:    vmovaps %ymm0, (%rsi)
+; AVX512F-NEXT:    vzeroupper
 ; AVX512F-NEXT:    retq
 ;
 ; AVX512BW-LABEL: concat_a_to_shuf_of_a:
 ; AVX512BW:       # %bb.0:
 ; AVX512BW-NEXT:    vmovaps (%rdi), %xmm0
-; AVX512BW-NEXT:    vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX512BW-NEXT:    vmovaps %xmm0, 16(%rsi)
-; AVX512BW-NEXT:    vmovaps %xmm1, (%rsi)
+; AVX512BW-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[1,0,0,1]
+; AVX512BW-NEXT:    vmovaps %ymm0, (%rsi)
+; AVX512BW-NEXT:    vzeroupper
 ; AVX512BW-NEXT:    retq
   %a = load <2 x i64>, ptr %a.ptr, align 64
   %shuffle = shufflevector <2 x i64> %a, <2 x i64> poison, <2 x i32> <i32 1, i32 0>
@@ -69,32 +70,33 @@ define void @concat_shuf_of_a_to_a(ptr %a.ptr, ptr %b.ptr, ptr %dst) {
 ; AVX:       # %bb.0:
 ; AVX-NEXT:    vmovaps (%rdi), %xmm0
 ; AVX-NEXT:    vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX-NEXT:    vmovaps %xmm0, (%rdx)
-; AVX-NEXT:    vmovaps %xmm1, 16(%rdx)
+; AVX-NEXT:    vinsertf128 $1, %xmm1, %ymm0, %ymm0
+; AVX-NEXT:    vmovaps %ymm0, (%rdx)
+; AVX-NEXT:    vzeroupper
 ; AVX-NEXT:    retq
 ;
 ; AVX2-LABEL: concat_shuf_of_a_to_a:
 ; AVX2:       # %bb.0:
 ; AVX2-NEXT:    vmovaps (%rdi), %xmm0
-; AVX2-NEXT:    vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX2-NEXT:    vmovaps %xmm0, (%rdx)
-; AVX2-NEXT:    vmovaps %xmm1, 16(%rdx)
+; AVX2-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[0,1,1,0]
+; AVX2-NEXT:    vmovaps %ymm0, (%rdx)
+; AVX2-NEXT:    vzeroupper
 ; AVX2-NEXT:    retq
 ;
 ; AVX512F-LABEL: concat_shuf_of_a_to_a:
 ; AVX512F:       # %bb.0:
 ; AVX512F-NEXT:    vmovaps (%rdi), %xmm0
-; AVX512F-NEXT:    vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX512F-NEXT:    vmovaps %xmm0, (%rdx)
-; AVX512F-NEXT:    vmovaps %xmm1, 16(%rdx)
+; AVX512F-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[0,1,1,0]
+; AVX512F-NEXT:    vmovaps %ymm0, (%rdx)
+; AVX512F-NEXT:    vzeroupper
 ; AVX512F-NEXT:    retq
 ;
 ; AVX512BW-LABEL: concat_shuf_of_a_to_a:
 ; AVX512BW:       # %bb.0:
 ; AVX512BW-NEXT:    vmovaps (%rdi), %xmm0
-; AVX512BW-NEXT:    vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX512BW-NEXT:    vmovaps %xmm0, (%rdx)
-; AVX512BW-NEXT:    vmovaps %xmm1, 16(%rdx)
+; AVX512BW-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[0,1,1,0]
+; AVX512BW-NEXT:    vmovaps %ymm0, (%rdx)
+; AVX512BW-NEXT:    vzeroupper
 ; AVX512BW-NEXT:    retq
   %a = load <2 x i64>, ptr %a.ptr, align 64
   %b = load <2 x i64>, ptr %b.ptr, align 64
@@ -567,29 +569,33 @@ define void @concat_shuf_of_a_to_itself(ptr %a.ptr, ptr %dst) {
 ; AVX-LABEL: concat_shuf_of_a_to_itself:
 ; AVX:       # %bb.0:
 ; AVX-NEXT:    vpermilps {{.*#+}} xmm0 = mem[2,3,0,1]
-; AVX-NEXT:    vmovaps %xmm0, 16(%rsi)
-; AVX-NEXT:    vmovaps %xmm0, (%rsi)
+; AVX-NEXT:    vinsertf128 $1, %xmm0, %ymm0, %ymm0
+; AVX-NEXT:    vmovaps %ymm0, (%rsi)
+; AVX-NEXT:    vzeroupper
 ; AVX-NEXT:    retq
 ;
 ; AVX2-LABEL: concat_shuf_of_a_to_itself:
 ; AVX2:       # %bb.0:
-; AVX2-NEXT:    vpermilps {{.*#+}} xmm0 = mem[2,3,0,1]
-; AVX2-NEXT:    vmovaps %xmm0, 16(%rsi)
-; AVX2-NEXT:    vmovaps %xmm0, (%rsi)
+; AVX2-NEXT:    vmovaps (%rdi), %xmm0
+; AVX2-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[1,0,1,0]
+; AVX2-NEXT:    vmovaps %ymm0, (%rsi)
+; AVX2-NEXT:    vzeroupper
 ; AVX2-NEXT:    retq
 ;
 ; AVX512F-LABEL: concat_shuf_of_a_to_itself:
 ; AVX512F:       # %bb.0:
-; AVX512F-NEXT:    vpermilps {{.*#+}} xmm0 = mem[2,3,0,1]
-; AVX512F-NEXT:    vmovaps %xmm0, 16(%rsi)
-; AVX512F-NEXT:    vmovaps %xmm0, (%rsi)
+; AVX512F-NEXT:    vmovaps (%rdi), %xmm0
+; AVX512F-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[1,0,1,0]
+; AVX512F-NEXT:    vmovaps %ymm0, (%rsi)
+; AVX512F-NEXT:    vzeroupper
 ; AVX512F-NEXT:    retq
 ;
 ; AVX512BW-LABEL: concat_shuf_of_a_to_itself:
 ; AVX512BW:       # %bb.0:
-; AVX512BW-NEXT:    vpermilps {{.*#+}} xmm0 = mem[2,3,0,1]
-; AVX512BW-NEXT:    vmovaps %xmm0, 16(%rsi)
-; AVX512BW-NEXT:    vmovaps %xmm0, (%rsi)
+; AVX512BW-NEXT:    vmovaps (%rdi), %xmm0
+; AVX512BW-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[1,0,1,0]
+; AVX512BW-NEXT:    vmovaps %ymm0, (%rsi)
+; AVX512BW-NEXT:    vzeroupper
 ; AVX512BW-NEXT:    retq
   %a = load <2 x i64>, ptr %a.ptr, align 64
   %shuffle = shufflevector <2 x i64> %a, <2 x i64> poison, <2 x i32> <i32 1, i32 0>
@@ -613,19 +619,18 @@ define void @concat_aaa_to_shuf_of_a(ptr %a.ptr, ptr %dst) {
 ; AVX:       # %bb.0:
 ; AVX-NEXT:    vbroadcastf128 {{.*#+}} ymm0 = mem[0,1,0,1]
 ; AVX-NEXT:    vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1]
+; AVX-NEXT:    vinsertf128 $1, %xmm0, %ymm1, %ymm1
 ; AVX-NEXT:    vmovaps %ymm0, 32(%rsi)
-; AVX-NEXT:    vmovaps %xmm0, 16(%rsi)
-; AVX-NEXT:    vmovaps %xmm1, (%rsi)
+; AVX-NEXT:    vmovaps %ymm1, (%rsi)
 ; AVX-NEXT:    vzeroupper
 ; AVX-NEXT:    retq
 ;
 ; AVX2-LABEL: concat_aaa_to_shuf_of_a:
 ; AVX2:       # %bb.0:
 ; AVX2-NEXT:    vbroadcastf128 {{.*#+}} ymm0 = mem[0,1,0,1]
-; AVX2-NEXT:    vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1]
+; AVX2-NEXT:    vpermpd {{.*#+}} ymm1 = ymm0[1,0,0,1]
 ; AVX2-NEXT:    vmovaps %ymm0, 32(%rsi)
-; AVX2-NEXT:    vmovaps %xmm0, 16(%rsi)
-; AVX2-NEXT:    vmovaps %xmm1, (%rsi)
+; AVX2-NEXT:    vmovaps %ymm1, (%rsi)
 ; AVX2-NEXT:    vzeroupper
 ; AVX2-NEXT:    retq
 ;
@@ -671,19 +676,18 @@ define void @concat_shuf_of_a_to_aaa(ptr %a.ptr, ptr %dst) {
 ; AVX:       # %bb.0:
 ; AVX-NEXT:    vbroadcastf128 {{.*#+}} ymm0 = mem[0,1,0,1]
 ; AVX-NEXT:    vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1]
+; AVX-NEXT:    vinsertf128 $1, %xmm1, %ymm0, %ymm1
 ; AVX-NEXT:    vmovaps %ymm0, (%rsi)
-; AVX-NEXT:    vmovaps %xmm0, 32(%rsi)
-; AVX-NEXT:    vmovaps %xmm1, 48(%rsi)
+; AVX-NEXT:    vmovaps %ymm1, 32(%rsi)
 ; AVX-NEXT:    vzeroupper
 ; AVX-NEXT:    retq
 ;
 ; AVX2-LABEL: concat_shuf_of_a_to_aaa:
 ; AVX2:       # %bb.0:
 ; AVX2-NEXT:    vbroadcastf128 {{.*#+}} ymm0 = mem[0,1,0,1]
-; AVX2-NEXT:    vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1]
+; AVX2-NEXT:    vpermpd {{.*#+}} ymm1 = ymm0[0,1,1,0]
 ; AVX2-NEXT:    vmovaps %ymm0, (%rsi)
-; AVX2-NEXT:    vmovaps %xmm0, 32(%rsi)
-; AVX2-NEXT:    vmovaps %xmm1, 48(%rsi)
+; AVX2-NEXT:    vmovaps %ymm1, 32(%rsi)
 ; AVX2-NEXT:    vzeroupper
 ; AVX2-NEXT:    retq
 ;


        


More information about the llvm-commits mailing list