[llvm] 1337821 - [DAGCombiner][X86] Fold a CONCAT_VECTORS of SHUFFLE_VECTOR and its operand into wider SHUFFLE_VECTOR
Roman Lebedev via llvm-commits
llvm-commits at lists.llvm.org
Sun Jan 1 12:19:32 PST 2023
Author: Roman Lebedev
Date: 2023-01-01T23:18:42+03:00
New Revision: 1337821f11902e219fe7720494879799219f2fc5
URL: https://github.com/llvm/llvm-project/commit/1337821f11902e219fe7720494879799219f2fc5
DIFF: https://github.com/llvm/llvm-project/commit/1337821f11902e219fe7720494879799219f2fc5.diff
LOG: [DAGCombiner][X86] Fold a CONCAT_VECTORS of SHUFFLE_VECTOR and its operand into wider SHUFFLE_VECTOR
This was showing up as a source of *many* regressions
with more aggressive ZERO_EXTEND_VECTOR_INREG recognition.
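
As a minimal sketch of the kind of pattern this fold targets (the function
below is hypothetical, modeled on the existing tests in
llvm/test/CodeGen/X86/vector-shuffle-concatenation.ll, and is not part of
this commit):

  define void @concat_a_to_shuf_of_a_sketch(ptr %a.ptr, ptr %dst) {
    %a = load <2 x i64>, ptr %a.ptr, align 64
    %shuffle = shufflevector <2 x i64> %a, <2 x i64> poison, <2 x i32> <i32 1, i32 0>
    %concat = shufflevector <2 x i64> %shuffle, <2 x i64> %a, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
    store <4 x i64> %concat, ptr %dst, align 64
    ret void
  }

Here the CONCAT_VECTORS of the v2i64 SHUFFLE_VECTOR and its own operand %a
can be rewritten as a single v4i64 shuffle of (concat_vectors %a, undef)
with mask <1,0,0,1>, which on AVX2 lowers to a single vpermpd (see the
updated CHECK lines below).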
Added:
Modified:
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/test/CodeGen/X86/vector-shuffle-concatenation.ll
Removed:
################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index ee2918e419404..9235042c54a4a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -21891,6 +21891,109 @@ static SDValue combineConcatVectorOfCasts(SDNode *N, SelectionDAG &DAG) {
   return DAG.getNode(CastOpcode, DL, VT, NewConcat);
 }
 
+// See if this is a simple CONCAT_VECTORS with no UNDEF operands, and if one of
+// the operands is a SHUFFLE_VECTOR and all other operands are also operands
+// of that SHUFFLE_VECTOR; if so, create a wider SHUFFLE_VECTOR.
+static SDValue combineConcatVectorOfShuffleAndItsOperands(
+    SDNode *N, SelectionDAG &DAG, const TargetLowering &TLI, bool LegalTypes,
+    bool LegalOperations) {
+  EVT VT = N->getValueType(0);
+  EVT OpVT = N->getOperand(0).getValueType();
+  if (VT.isScalableVector())
+    return SDValue();
+
+  // For now, only allow simple 2-operand concatenations.
+  if (N->getNumOperands() != 2)
+    return SDValue();
+
+  // Don't create illegal types/shuffles when not allowed to.
+  if ((LegalTypes && !TLI.isTypeLegal(VT)) ||
+      (LegalOperations &&
+       !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, VT)))
+    return SDValue();
+
+  // Analyze all of the operands of the CONCAT_VECTORS. Out of all of them,
+  // we want to find one that is: (1) a SHUFFLE_VECTOR, (2) only used by us,
+  // and (3) such that all operands of the CONCAT_VECTORS are either that
+  // SHUFFLE_VECTOR or one of the operands of that SHUFFLE_VECTOR (but not
+  // UNDEF!). (4) For now, the SHUFFLE_VECTOR must also be unary.
+  ShuffleVectorSDNode *SVN = nullptr;
+  for (SDValue Op : N->ops()) {
+    if (auto *CurSVN = dyn_cast<ShuffleVectorSDNode>(Op);
+        CurSVN && CurSVN->getOperand(1).isUndef() && N->isOnlyUserOf(CurSVN) &&
+        all_of(N->ops(), [CurSVN](SDValue Op) {
+          // FIXME: can we allow UNDEF operands?
+          return !Op.isUndef() &&
+                 (Op.getNode() == CurSVN || is_contained(CurSVN->ops(), Op));
+        })) {
+      SVN = CurSVN;
+      break;
+    }
+  }
+  if (!SVN)
+    return SDValue();
+
+  // We are going to pad the shuffle operands, so any index that was picking
+  // from the second operand must be adjusted.
+  SmallVector<int, 16> AdjustedMask;
+  AdjustedMask.reserve(SVN->getMask().size());
+  assert(SVN->getOperand(1).isUndef() && "Expected unary shuffle!");
+  append_range(AdjustedMask, SVN->getMask());
+
+  // Identity masks for the operands of the (padded) shuffle.
+  SmallVector<int, 32> IdentityMask(2 * OpVT.getVectorNumElements());
+  MutableArrayRef<int> FirstShufOpIdentityMask =
+      MutableArrayRef<int>(IdentityMask)
+          .take_front(OpVT.getVectorNumElements());
+  MutableArrayRef<int> SecondShufOpIdentityMask =
+      MutableArrayRef<int>(IdentityMask).take_back(OpVT.getVectorNumElements());
+  std::iota(FirstShufOpIdentityMask.begin(), FirstShufOpIdentityMask.end(), 0);
+  std::iota(SecondShufOpIdentityMask.begin(), SecondShufOpIdentityMask.end(),
+            VT.getVectorNumElements());
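+  // For example, when concatenating two v2i64 operands into v4i64,
+  // IdentityMask becomes [0, 1, 4, 5]: [0, 1] selects the low (defined) half
+  // of the widened first shuffle operand, and [4, 5] the low half of the
+  // widened second one.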
+
+  // New combined shuffle mask.
+  SmallVector<int, 32> Mask;
+  Mask.reserve(VT.getVectorNumElements());
+  for (SDValue Op : N->ops()) {
+    assert(!Op.isUndef() && "Not expecting to concatenate UNDEF.");
+    if (Op.getNode() == SVN) {
+      append_range(Mask, AdjustedMask);
+      continue;
+    }
+    if (Op == SVN->getOperand(0)) {
+      append_range(Mask, FirstShufOpIdentityMask);
+      continue;
+    }
+    if (Op == SVN->getOperand(1)) {
+      append_range(Mask, SecondShufOpIdentityMask);
+      continue;
+    }
+    llvm_unreachable("Unexpected operand!");
+  }
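+  // E.g., for (concat_vectors A, (shuffle<1,0> A, undef)) with v2i64
+  // operands, Mask is now [0, 1, 1, 0] over the widened v4i64 operands.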
+
+  // Don't create illegal shuffle masks.
+  if (!TLI.isShuffleMaskLegal(Mask, VT))
+    return SDValue();
+
+  // Pad the shuffle operands with UNDEF.
+  SDLoc dl(N);
+  std::array<SDValue, 2> ShufOps;
+  for (auto I : zip(SVN->ops(), ShufOps)) {
+    SDValue ShufOp = std::get<0>(I);
+    SDValue &NewShufOp = std::get<1>(I);
+    if (ShufOp.isUndef())
+      NewShufOp = DAG.getUNDEF(VT);
+    else {
+      SmallVector<SDValue, 2> ShufOpParts(N->getNumOperands(),
+                                          DAG.getUNDEF(OpVT));
+      ShufOpParts[0] = ShufOp;
+      NewShufOp = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, ShufOpParts);
+    }
+  }
+  // Finally, create the new wide shuffle.
+  return DAG.getVectorShuffle(VT, dl, ShufOps[0], ShufOps[1], Mask);
+}
+
 SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
   // If we only have one input vector, we don't need to do any concatenation.
   if (N->getNumOperands() == 1)
@@ -22026,6 +22129,10 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
   if (SDValue V = combineConcatVectorOfCasts(N, DAG))
     return V;
 
+  if (SDValue V = combineConcatVectorOfShuffleAndItsOperands(
+          N, DAG, TLI, LegalTypes, LegalOperations))
+    return V;
+
   // Type legalization of vectors and DAG canonicalization of SHUFFLE_VECTOR
   // nodes often generate nop CONCAT_VECTOR nodes. Scan the CONCAT_VECTOR
   // operands and look for a CONCAT operations that place the incoming vectors
diff --git a/llvm/test/CodeGen/X86/vector-shuffle-concatenation.ll b/llvm/test/CodeGen/X86/vector-shuffle-concatenation.ll
index 4ffe97e6de236..52fc059cc6818 100644
--- a/llvm/test/CodeGen/X86/vector-shuffle-concatenation.ll
+++ b/llvm/test/CodeGen/X86/vector-shuffle-concatenation.ll
@@ -23,32 +23,33 @@ define void @concat_a_to_shuf_of_a(ptr %a.ptr, ptr %dst) {
; AVX: # %bb.0:
; AVX-NEXT: vmovaps (%rdi), %xmm0
; AVX-NEXT: vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX-NEXT: vmovaps %xmm0, 16(%rsi)
-; AVX-NEXT: vmovaps %xmm1, (%rsi)
+; AVX-NEXT: vinsertf128 $1, %xmm0, %ymm1, %ymm0
+; AVX-NEXT: vmovaps %ymm0, (%rsi)
+; AVX-NEXT: vzeroupper
; AVX-NEXT: retq
;
; AVX2-LABEL: concat_a_to_shuf_of_a:
; AVX2: # %bb.0:
; AVX2-NEXT: vmovaps (%rdi), %xmm0
-; AVX2-NEXT: vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX2-NEXT: vmovaps %xmm0, 16(%rsi)
-; AVX2-NEXT: vmovaps %xmm1, (%rsi)
+; AVX2-NEXT: vpermpd {{.*#+}} ymm0 = ymm0[1,0,0,1]
+; AVX2-NEXT: vmovaps %ymm0, (%rsi)
+; AVX2-NEXT: vzeroupper
; AVX2-NEXT: retq
;
; AVX512F-LABEL: concat_a_to_shuf_of_a:
; AVX512F: # %bb.0:
; AVX512F-NEXT: vmovaps (%rdi), %xmm0
-; AVX512F-NEXT: vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX512F-NEXT: vmovaps %xmm0, 16(%rsi)
-; AVX512F-NEXT: vmovaps %xmm1, (%rsi)
+; AVX512F-NEXT: vpermpd {{.*#+}} ymm0 = ymm0[1,0,0,1]
+; AVX512F-NEXT: vmovaps %ymm0, (%rsi)
+; AVX512F-NEXT: vzeroupper
; AVX512F-NEXT: retq
;
; AVX512BW-LABEL: concat_a_to_shuf_of_a:
; AVX512BW: # %bb.0:
; AVX512BW-NEXT: vmovaps (%rdi), %xmm0
-; AVX512BW-NEXT: vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX512BW-NEXT: vmovaps %xmm0, 16(%rsi)
-; AVX512BW-NEXT: vmovaps %xmm1, (%rsi)
+; AVX512BW-NEXT: vpermpd {{.*#+}} ymm0 = ymm0[1,0,0,1]
+; AVX512BW-NEXT: vmovaps %ymm0, (%rsi)
+; AVX512BW-NEXT: vzeroupper
; AVX512BW-NEXT: retq
%a = load <2 x i64>, ptr %a.ptr, align 64
%shuffle = shufflevector <2 x i64> %a, <2 x i64> poison, <2 x i32> <i32 1, i32 0>
@@ -69,32 +70,33 @@ define void @concat_shuf_of_a_to_a(ptr %a.ptr, ptr %b.ptr, ptr %dst) {
; AVX: # %bb.0:
; AVX-NEXT: vmovaps (%rdi), %xmm0
; AVX-NEXT: vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX-NEXT: vmovaps %xmm0, (%rdx)
-; AVX-NEXT: vmovaps %xmm1, 16(%rdx)
+; AVX-NEXT: vinsertf128 $1, %xmm1, %ymm0, %ymm0
+; AVX-NEXT: vmovaps %ymm0, (%rdx)
+; AVX-NEXT: vzeroupper
; AVX-NEXT: retq
;
; AVX2-LABEL: concat_shuf_of_a_to_a:
; AVX2: # %bb.0:
; AVX2-NEXT: vmovaps (%rdi), %xmm0
-; AVX2-NEXT: vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX2-NEXT: vmovaps %xmm0, (%rdx)
-; AVX2-NEXT: vmovaps %xmm1, 16(%rdx)
+; AVX2-NEXT: vpermpd {{.*#+}} ymm0 = ymm0[0,1,1,0]
+; AVX2-NEXT: vmovaps %ymm0, (%rdx)
+; AVX2-NEXT: vzeroupper
; AVX2-NEXT: retq
;
; AVX512F-LABEL: concat_shuf_of_a_to_a:
; AVX512F: # %bb.0:
; AVX512F-NEXT: vmovaps (%rdi), %xmm0
-; AVX512F-NEXT: vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX512F-NEXT: vmovaps %xmm0, (%rdx)
-; AVX512F-NEXT: vmovaps %xmm1, 16(%rdx)
+; AVX512F-NEXT: vpermpd {{.*#+}} ymm0 = ymm0[0,1,1,0]
+; AVX512F-NEXT: vmovaps %ymm0, (%rdx)
+; AVX512F-NEXT: vzeroupper
; AVX512F-NEXT: retq
;
; AVX512BW-LABEL: concat_shuf_of_a_to_a:
; AVX512BW: # %bb.0:
; AVX512BW-NEXT: vmovaps (%rdi), %xmm0
-; AVX512BW-NEXT: vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX512BW-NEXT: vmovaps %xmm0, (%rdx)
-; AVX512BW-NEXT: vmovaps %xmm1, 16(%rdx)
+; AVX512BW-NEXT: vpermpd {{.*#+}} ymm0 = ymm0[0,1,1,0]
+; AVX512BW-NEXT: vmovaps %ymm0, (%rdx)
+; AVX512BW-NEXT: vzeroupper
; AVX512BW-NEXT: retq
%a = load <2 x i64>, ptr %a.ptr, align 64
%b = load <2 x i64>, ptr %b.ptr, align 64
@@ -567,29 +569,33 @@ define void @concat_shuf_of_a_to_itself(ptr %a.ptr, ptr %dst) {
; AVX-LABEL: concat_shuf_of_a_to_itself:
; AVX: # %bb.0:
; AVX-NEXT: vpermilps {{.*#+}} xmm0 = mem[2,3,0,1]
-; AVX-NEXT: vmovaps %xmm0, 16(%rsi)
-; AVX-NEXT: vmovaps %xmm0, (%rsi)
+; AVX-NEXT: vinsertf128 $1, %xmm0, %ymm0, %ymm0
+; AVX-NEXT: vmovaps %ymm0, (%rsi)
+; AVX-NEXT: vzeroupper
; AVX-NEXT: retq
;
; AVX2-LABEL: concat_shuf_of_a_to_itself:
; AVX2: # %bb.0:
-; AVX2-NEXT: vpermilps {{.*#+}} xmm0 = mem[2,3,0,1]
-; AVX2-NEXT: vmovaps %xmm0, 16(%rsi)
-; AVX2-NEXT: vmovaps %xmm0, (%rsi)
+; AVX2-NEXT: vmovaps (%rdi), %xmm0
+; AVX2-NEXT: vpermpd {{.*#+}} ymm0 = ymm0[1,0,1,0]
+; AVX2-NEXT: vmovaps %ymm0, (%rsi)
+; AVX2-NEXT: vzeroupper
; AVX2-NEXT: retq
;
; AVX512F-LABEL: concat_shuf_of_a_to_itself:
; AVX512F: # %bb.0:
-; AVX512F-NEXT: vpermilps {{.*#+}} xmm0 = mem[2,3,0,1]
-; AVX512F-NEXT: vmovaps %xmm0, 16(%rsi)
-; AVX512F-NEXT: vmovaps %xmm0, (%rsi)
+; AVX512F-NEXT: vmovaps (%rdi), %xmm0
+; AVX512F-NEXT: vpermpd {{.*#+}} ymm0 = ymm0[1,0,1,0]
+; AVX512F-NEXT: vmovaps %ymm0, (%rsi)
+; AVX512F-NEXT: vzeroupper
; AVX512F-NEXT: retq
;
; AVX512BW-LABEL: concat_shuf_of_a_to_itself:
; AVX512BW: # %bb.0:
-; AVX512BW-NEXT: vpermilps {{.*#+}} xmm0 = mem[2,3,0,1]
-; AVX512BW-NEXT: vmovaps %xmm0, 16(%rsi)
-; AVX512BW-NEXT: vmovaps %xmm0, (%rsi)
+; AVX512BW-NEXT: vmovaps (%rdi), %xmm0
+; AVX512BW-NEXT: vpermpd {{.*#+}} ymm0 = ymm0[1,0,1,0]
+; AVX512BW-NEXT: vmovaps %ymm0, (%rsi)
+; AVX512BW-NEXT: vzeroupper
; AVX512BW-NEXT: retq
%a = load <2 x i64>, ptr %a.ptr, align 64
%shuffle = shufflevector <2 x i64> %a, <2 x i64> poison, <2 x i32> <i32 1, i32 0>
@@ -613,19 +619,18 @@ define void @concat_aaa_to_shuf_of_a(ptr %a.ptr, ptr %dst) {
; AVX: # %bb.0:
; AVX-NEXT: vbroadcastf128 {{.*#+}} ymm0 = mem[0,1,0,1]
; AVX-NEXT: vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1]
+; AVX-NEXT: vinsertf128 $1, %xmm0, %ymm1, %ymm1
; AVX-NEXT: vmovaps %ymm0, 32(%rsi)
-; AVX-NEXT: vmovaps %xmm0, 16(%rsi)
-; AVX-NEXT: vmovaps %xmm1, (%rsi)
+; AVX-NEXT: vmovaps %ymm1, (%rsi)
; AVX-NEXT: vzeroupper
; AVX-NEXT: retq
;
; AVX2-LABEL: concat_aaa_to_shuf_of_a:
; AVX2: # %bb.0:
; AVX2-NEXT: vbroadcastf128 {{.*#+}} ymm0 = mem[0,1,0,1]
-; AVX2-NEXT: vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1]
+; AVX2-NEXT: vpermpd {{.*#+}} ymm1 = ymm0[1,0,0,1]
; AVX2-NEXT: vmovaps %ymm0, 32(%rsi)
-; AVX2-NEXT: vmovaps %xmm0, 16(%rsi)
-; AVX2-NEXT: vmovaps %xmm1, (%rsi)
+; AVX2-NEXT: vmovaps %ymm1, (%rsi)
; AVX2-NEXT: vzeroupper
; AVX2-NEXT: retq
;
@@ -671,19 +676,18 @@ define void @concat_shuf_of_a_to_aaa(ptr %a.ptr, ptr %dst) {
; AVX: # %bb.0:
; AVX-NEXT: vbroadcastf128 {{.*#+}} ymm0 = mem[0,1,0,1]
; AVX-NEXT: vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1]
+; AVX-NEXT: vinsertf128 $1, %xmm1, %ymm0, %ymm1
; AVX-NEXT: vmovaps %ymm0, (%rsi)
-; AVX-NEXT: vmovaps %xmm0, 32(%rsi)
-; AVX-NEXT: vmovaps %xmm1, 48(%rsi)
+; AVX-NEXT: vmovaps %ymm1, 32(%rsi)
; AVX-NEXT: vzeroupper
; AVX-NEXT: retq
;
; AVX2-LABEL: concat_shuf_of_a_to_aaa:
; AVX2: # %bb.0:
; AVX2-NEXT: vbroadcastf128 {{.*#+}} ymm0 = mem[0,1,0,1]
-; AVX2-NEXT: vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1]
+; AVX2-NEXT: vpermpd {{.*#+}} ymm1 = ymm0[0,1,1,0]
; AVX2-NEXT: vmovaps %ymm0, (%rsi)
-; AVX2-NEXT: vmovaps %xmm0, 32(%rsi)
-; AVX2-NEXT: vmovaps %xmm1, 48(%rsi)
+; AVX2-NEXT: vmovaps %ymm1, 32(%rsi)
; AVX2-NEXT: vzeroupper
; AVX2-NEXT: retq
;