[llvm] 619d7dc - [DAGCombiner] recognize shuffle (shuffle X, Mask0), Mask --> splat X

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Sun Mar 1 06:14:56 PST 2020


Author: Sanjay Patel
Date: 2020-03-01T09:10:25-05:00
New Revision: 619d7dc39a1bb875cca77fe06c7dc670c380c861

URL: https://github.com/llvm/llvm-project/commit/619d7dc39a1bb875cca77fe06c7dc670c380c861
DIFF: https://github.com/llvm/llvm-project/commit/619d7dc39a1bb875cca77fe06c7dc670c380c861.diff

LOG: [DAGCombiner] recognize shuffle (shuffle X, Mask0), Mask --> splat X

We get the simple cases of this via demanded elements and other folds,
but that doesn't work if the values have >1 use, so add a dedicated
match for the pattern.

We already have this transform in IR, but it doesn't help the
motivating x86 tests (based on PR42024) because the shuffles don't
exist until after legalization and other combines have happened.
The AArch64 test shows a minimal IR example of the problem.

Differential Revision: https://reviews.llvm.org/D75348

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/test/CodeGen/AArch64/arm64-dup.ll
    llvm/test/CodeGen/X86/vector-reduce-mul.ll
    llvm/test/CodeGen/X86/x86-interleaved-access.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index f95519b8eba5..f5205687ae00 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -30,6 +30,7 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/MemoryLocation.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/CodeGen/DAGCombine.h"
 #include "llvm/CodeGen/ISDOpcodes.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
@@ -19259,6 +19260,56 @@ static SDValue combineShuffleOfSplatVal(ShuffleVectorSDNode *Shuf,
                               NewMask);
 }
 
+/// Fold a unary shuffle of a unary shuffle that selects a single source lane:
+///   shuf (shuf X, undef, InnerMask), undef, OuterMask --> splat X
+static SDValue formSplatFromShuffles(ShuffleVectorSDNode *OuterShuf,
+                                     SelectionDAG &DAG) {
+  if (!OuterShuf->getOperand(1).isUndef())
+    return SDValue();
+  auto *InnerShuf = dyn_cast<ShuffleVectorSDNode>(OuterShuf->getOperand(0));
+  if (!InnerShuf || !InnerShuf->getOperand(1).isUndef())
+    return SDValue();
+
+  ArrayRef<int> OuterMask = OuterShuf->getMask();
+  ArrayRef<int> InnerMask = InnerShuf->getMask();
+  unsigned NumElts = OuterMask.size();
+  assert(NumElts == InnerMask.size() && "Mask length mismatch");
+  SmallVector<int, 32> CombinedMask(NumElts, -1);
+  int SplatIndex = -1;
+  for (unsigned i = 0; i != NumElts; ++i) {
+    // An undef lane in the outer mask stays undef in the combined mask.
+    int OuterMaskElt = OuterMask[i];
+    if (OuterMaskElt == -1)
+      continue;
+
+    // Compose the two masks to find which element of X feeds this lane.
+    int InnerMaskElt = InnerMask[OuterMaskElt];
+    if (InnerMaskElt == -1)
+      continue;
+
+    // The first defined lane determines the splatted element.
+    if (SplatIndex == -1)
+      SplatIndex = InnerMaskElt;
+
+    // A lane reading a different element of X means this is not a splat.
+    if (SplatIndex != InnerMaskElt)
+      return SDValue();
+
+    CombinedMask[i] = InnerMaskElt;
+  }
+  assert((all_of(CombinedMask, [](int M) { return M == -1; }) ||
+          getSplatIndex(CombinedMask) != -1) && "Expected a splat mask");
+
+  // TODO: The transform may be a win even if the mask is not legal.
+  EVT VT = OuterShuf->getValueType(0);
+  assert(VT == InnerShuf->getValueType(0) && "Expected matching shuffle types");
+  if (!DAG.getTargetLoweringInfo().isShuffleMaskLegal(CombinedMask, VT))
+    return SDValue();
+
+  return DAG.getVectorShuffle(VT, SDLoc(OuterShuf), InnerShuf->getOperand(0),
+                              InnerShuf->getOperand(1), CombinedMask);
+}
+
 /// If the shuffle mask is taking exactly one element from the first vector
 /// operand and passing through all other elements from the second vector
 /// operand, return the index of the mask element that is choosing an element
@@ -19417,6 +19468,9 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
   if (SDValue V = combineShuffleOfSplatVal(SVN, DAG))
     return V;
 
+  if (SDValue V = formSplatFromShuffles(SVN, DAG))
+    return V;
+
   // If it is a splat, check if the argument vector is another splat or a
   // build_vector.
   if (SVN->isSplat() && SVN->getSplatIndex() < (int)NumElts) {

diff  --git a/llvm/test/CodeGen/AArch64/arm64-dup.ll b/llvm/test/CodeGen/AArch64/arm64-dup.ll
index f0104a9ec9ac..c635383b7e67 100644
--- a/llvm/test/CodeGen/AArch64/arm64-dup.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-dup.ll
@@ -449,8 +449,6 @@ define void @disguised_dup(<4 x float> %x, <4 x float>* %p1, <4 x float>* %p2) {
 ; CHECK-NEXT:    dup.4s v1, v0[0]
 ; CHECK-NEXT:    ext.16b v0, v0, v0, #12
 ; CHECK-NEXT:    ext.16b v0, v0, v1, #8
-; CHECK-NEXT:    zip2.4s v1, v0, v0
-; CHECK-NEXT:    ext.16b v1, v0, v1, #12
 ; CHECK-NEXT:    str q0, [x0]
 ; CHECK-NEXT:    str q1, [x1]
 ; CHECK-NEXT:    ret

diff  --git a/llvm/test/CodeGen/X86/vector-reduce-mul.ll b/llvm/test/CodeGen/X86/vector-reduce-mul.ll
index 2da46c0f5d33..ad229da8241c 100644
--- a/llvm/test/CodeGen/X86/vector-reduce-mul.ll
+++ b/llvm/test/CodeGen/X86/vector-reduce-mul.ll
@@ -811,7 +811,7 @@ define i32 @test_v4i32(<4 x i32> %a0) {
 ; SSE2-LABEL: test_v4i32:
 ; SSE2:       # %bb.0:
 ; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; SSE2-NEXT:    pshufd {{.*#+}} xmm2 = xmm0[3,3,1,1]
+; SSE2-NEXT:    pshufd {{.*#+}} xmm2 = xmm0[3,1,2,3]
 ; SSE2-NEXT:    pshufd {{.*#+}} xmm3 = xmm0[1,1,3,3]
 ; SSE2-NEXT:    pmuludq %xmm2, %xmm3
 ; SSE2-NEXT:    pmuludq %xmm0, %xmm1
@@ -858,7 +858,7 @@ define i32 @test_v8i32(<8 x i32> %a0) {
 ; SSE2-NEXT:    pmuludq %xmm1, %xmm0
 ; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
 ; SSE2-NEXT:    pmuludq %xmm0, %xmm1
-; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm3[2,2,0,0]
+; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm3[2,0,2,2]
 ; SSE2-NEXT:    pmuludq %xmm3, %xmm0
 ; SSE2-NEXT:    pmuludq %xmm1, %xmm0
 ; SSE2-NEXT:    movd %xmm0, %eax
@@ -928,7 +928,7 @@ define i32 @test_v16i32(<16 x i32> %a0) {
 ; SSE2-NEXT:    pmuludq %xmm1, %xmm2
 ; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
 ; SSE2-NEXT:    pmuludq %xmm0, %xmm1
-; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm2[2,2,0,0]
+; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm2[2,0,2,2]
 ; SSE2-NEXT:    pmuludq %xmm2, %xmm0
 ; SSE2-NEXT:    pmuludq %xmm1, %xmm0
 ; SSE2-NEXT:    movd %xmm0, %eax
@@ -1018,7 +1018,7 @@ define i32 @test_v32i32(<32 x i32> %a0) {
 ; SSE2-NEXT:    pmuludq %xmm0, %xmm1
 ; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm1[2,3,0,1]
 ; SSE2-NEXT:    pmuludq %xmm1, %xmm0
-; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm11[2,2,0,0]
+; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm11[2,0,2,2]
 ; SSE2-NEXT:    pmuludq %xmm11, %xmm1
 ; SSE2-NEXT:    pmuludq %xmm0, %xmm1
 ; SSE2-NEXT:    movd %xmm1, %eax

diff  --git a/llvm/test/CodeGen/X86/x86-interleaved-access.ll b/llvm/test/CodeGen/X86/x86-interleaved-access.ll
index 58cef0725c0d..e557a9d9e162 100644
--- a/llvm/test/CodeGen/X86/x86-interleaved-access.ll
+++ b/llvm/test/CodeGen/X86/x86-interleaved-access.ll
@@ -1826,45 +1826,40 @@ define void @splat4_v8i32_load_store(<8 x i32>* %s, <32 x i32>* %d) {
 define void @splat4_v4f64_load_store(<4 x double>* %s, <16 x double>* %d) {
 ; AVX1-LABEL: splat4_v4f64_load_store:
 ; AVX1:       # %bb.0:
-; AVX1-NEXT:    vmovupd (%rdi), %ymm0
-; AVX1-NEXT:    vperm2f128 {{.*#+}} ymm1 = ymm0[0,1,0,1]
-; AVX1-NEXT:    vperm2f128 {{.*#+}} ymm0 = ymm0[2,3,2,3]
-; AVX1-NEXT:    vmovddup {{.*#+}} ymm2 = ymm1[0,0,2,2]
-; AVX1-NEXT:    vmovddup {{.*#+}} ymm3 = ymm0[0,0,2,2]
-; AVX1-NEXT:    vpermilpd {{.*#+}} ymm1 = ymm1[1,1,3,3]
-; AVX1-NEXT:    vpermilpd {{.*#+}} ymm0 = ymm0[1,1,3,3]
-; AVX1-NEXT:    vmovupd %ymm0, 96(%rsi)
-; AVX1-NEXT:    vmovupd %ymm3, 64(%rsi)
-; AVX1-NEXT:    vmovupd %ymm1, 32(%rsi)
-; AVX1-NEXT:    vmovupd %ymm2, (%rsi)
+; AVX1-NEXT:    vbroadcastsd (%rdi), %ymm0
+; AVX1-NEXT:    vbroadcastsd 16(%rdi), %ymm1
+; AVX1-NEXT:    vbroadcastsd 8(%rdi), %ymm2
+; AVX1-NEXT:    vbroadcastsd 24(%rdi), %ymm3
+; AVX1-NEXT:    vmovups %ymm3, 96(%rsi)
+; AVX1-NEXT:    vmovups %ymm1, 64(%rsi)
+; AVX1-NEXT:    vmovups %ymm2, 32(%rsi)
+; AVX1-NEXT:    vmovups %ymm0, (%rsi)
 ; AVX1-NEXT:    vzeroupper
 ; AVX1-NEXT:    retq
 ;
 ; AVX2-LABEL: splat4_v4f64_load_store:
 ; AVX2:       # %bb.0:
-; AVX2-NEXT:    vmovups (%rdi), %ymm0
-; AVX2-NEXT:    vbroadcastsd (%rdi), %ymm1
-; AVX2-NEXT:    vpermpd {{.*#+}} ymm2 = ymm0[2,2,2,2]
-; AVX2-NEXT:    vpermpd {{.*#+}} ymm3 = ymm0[1,1,1,1]
-; AVX2-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[3,3,3,3]
-; AVX2-NEXT:    vmovups %ymm0, 96(%rsi)
-; AVX2-NEXT:    vmovups %ymm2, 64(%rsi)
-; AVX2-NEXT:    vmovups %ymm3, 32(%rsi)
-; AVX2-NEXT:    vmovups %ymm1, (%rsi)
+; AVX2-NEXT:    vbroadcastsd (%rdi), %ymm0
+; AVX2-NEXT:    vbroadcastsd 16(%rdi), %ymm1
+; AVX2-NEXT:    vbroadcastsd 8(%rdi), %ymm2
+; AVX2-NEXT:    vbroadcastsd 24(%rdi), %ymm3
+; AVX2-NEXT:    vmovups %ymm3, 96(%rsi)
+; AVX2-NEXT:    vmovups %ymm1, 64(%rsi)
+; AVX2-NEXT:    vmovups %ymm2, 32(%rsi)
+; AVX2-NEXT:    vmovups %ymm0, (%rsi)
 ; AVX2-NEXT:    vzeroupper
 ; AVX2-NEXT:    retq
 ;
 ; AVX512-LABEL: splat4_v4f64_load_store:
 ; AVX512:       # %bb.0:
-; AVX512-NEXT:    vmovups (%rdi), %ymm0
-; AVX512-NEXT:    vbroadcastsd (%rdi), %ymm1
-; AVX512-NEXT:    vpermpd {{.*#+}} ymm2 = ymm0[2,2,2,2]
-; AVX512-NEXT:    vpermpd {{.*#+}} ymm3 = ymm0[1,1,1,1]
-; AVX512-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[3,3,3,3]
+; AVX512-NEXT:    vbroadcastsd (%rdi), %ymm0
+; AVX512-NEXT:    vbroadcastsd 16(%rdi), %ymm1
+; AVX512-NEXT:    vbroadcastsd 8(%rdi), %ymm2
+; AVX512-NEXT:    vbroadcastsd 24(%rdi), %ymm3
+; AVX512-NEXT:    vinsertf64x4 $1, %ymm2, %zmm0, %zmm0
 ; AVX512-NEXT:    vinsertf64x4 $1, %ymm3, %zmm1, %zmm1
-; AVX512-NEXT:    vinsertf64x4 $1, %ymm0, %zmm2, %zmm0
-; AVX512-NEXT:    vmovups %zmm0, 64(%rsi)
-; AVX512-NEXT:    vmovups %zmm1, (%rsi)
+; AVX512-NEXT:    vmovups %zmm1, 64(%rsi)
+; AVX512-NEXT:    vmovups %zmm0, (%rsi)
 ; AVX512-NEXT:    vzeroupper
 ; AVX512-NEXT:    retq
   %x = load <4 x double>, <4 x double>* %s, align 8
@@ -1878,45 +1873,40 @@ define void @splat4_v4f64_load_store(<4 x double>* %s, <16 x double>* %d) {
 define void @splat4_v4i64_load_store(<4 x i64>* %s, <16 x i64>* %d) {
 ; AVX1-LABEL: splat4_v4i64_load_store:
 ; AVX1:       # %bb.0:
-; AVX1-NEXT:    vmovupd (%rdi), %ymm0
-; AVX1-NEXT:    vperm2f128 {{.*#+}} ymm1 = ymm0[0,1,0,1]
-; AVX1-NEXT:    vperm2f128 {{.*#+}} ymm0 = ymm0[2,3,2,3]
-; AVX1-NEXT:    vmovddup {{.*#+}} ymm2 = ymm1[0,0,2,2]
-; AVX1-NEXT:    vmovddup {{.*#+}} ymm3 = ymm0[0,0,2,2]
-; AVX1-NEXT:    vpermilpd {{.*#+}} ymm1 = ymm1[1,1,3,3]
-; AVX1-NEXT:    vpermilpd {{.*#+}} ymm0 = ymm0[1,1,3,3]
-; AVX1-NEXT:    vmovupd %ymm0, 96(%rsi)
-; AVX1-NEXT:    vmovupd %ymm3, 64(%rsi)
-; AVX1-NEXT:    vmovupd %ymm1, 32(%rsi)
-; AVX1-NEXT:    vmovupd %ymm2, (%rsi)
+; AVX1-NEXT:    vbroadcastsd (%rdi), %ymm0
+; AVX1-NEXT:    vbroadcastsd 16(%rdi), %ymm1
+; AVX1-NEXT:    vbroadcastsd 8(%rdi), %ymm2
+; AVX1-NEXT:    vbroadcastsd 24(%rdi), %ymm3
+; AVX1-NEXT:    vmovups %ymm3, 96(%rsi)
+; AVX1-NEXT:    vmovups %ymm1, 64(%rsi)
+; AVX1-NEXT:    vmovups %ymm2, 32(%rsi)
+; AVX1-NEXT:    vmovups %ymm0, (%rsi)
 ; AVX1-NEXT:    vzeroupper
 ; AVX1-NEXT:    retq
 ;
 ; AVX2-LABEL: splat4_v4i64_load_store:
 ; AVX2:       # %bb.0:
-; AVX2-NEXT:    vmovups (%rdi), %ymm0
-; AVX2-NEXT:    vbroadcastsd (%rdi), %ymm1
-; AVX2-NEXT:    vpermpd {{.*#+}} ymm2 = ymm0[2,2,2,2]
-; AVX2-NEXT:    vpermpd {{.*#+}} ymm3 = ymm0[1,1,1,1]
-; AVX2-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[3,3,3,3]
-; AVX2-NEXT:    vmovups %ymm0, 96(%rsi)
-; AVX2-NEXT:    vmovups %ymm2, 64(%rsi)
-; AVX2-NEXT:    vmovups %ymm3, 32(%rsi)
-; AVX2-NEXT:    vmovups %ymm1, (%rsi)
+; AVX2-NEXT:    vbroadcastsd (%rdi), %ymm0
+; AVX2-NEXT:    vbroadcastsd 16(%rdi), %ymm1
+; AVX2-NEXT:    vbroadcastsd 8(%rdi), %ymm2
+; AVX2-NEXT:    vbroadcastsd 24(%rdi), %ymm3
+; AVX2-NEXT:    vmovups %ymm3, 96(%rsi)
+; AVX2-NEXT:    vmovups %ymm1, 64(%rsi)
+; AVX2-NEXT:    vmovups %ymm2, 32(%rsi)
+; AVX2-NEXT:    vmovups %ymm0, (%rsi)
 ; AVX2-NEXT:    vzeroupper
 ; AVX2-NEXT:    retq
 ;
 ; AVX512-LABEL: splat4_v4i64_load_store:
 ; AVX512:       # %bb.0:
-; AVX512-NEXT:    vmovups (%rdi), %ymm0
-; AVX512-NEXT:    vbroadcastsd (%rdi), %ymm1
-; AVX512-NEXT:    vpermpd {{.*#+}} ymm2 = ymm0[2,2,2,2]
-; AVX512-NEXT:    vpermpd {{.*#+}} ymm3 = ymm0[1,1,1,1]
-; AVX512-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[3,3,3,3]
+; AVX512-NEXT:    vbroadcastsd (%rdi), %ymm0
+; AVX512-NEXT:    vbroadcastsd 16(%rdi), %ymm1
+; AVX512-NEXT:    vbroadcastsd 8(%rdi), %ymm2
+; AVX512-NEXT:    vbroadcastsd 24(%rdi), %ymm3
+; AVX512-NEXT:    vinsertf64x4 $1, %ymm2, %zmm0, %zmm0
 ; AVX512-NEXT:    vinsertf64x4 $1, %ymm3, %zmm1, %zmm1
-; AVX512-NEXT:    vinsertf64x4 $1, %ymm0, %zmm2, %zmm0
-; AVX512-NEXT:    vmovups %zmm0, 64(%rsi)
-; AVX512-NEXT:    vmovups %zmm1, (%rsi)
+; AVX512-NEXT:    vmovups %zmm1, 64(%rsi)
+; AVX512-NEXT:    vmovups %zmm0, (%rsi)
 ; AVX512-NEXT:    vzeroupper
 ; AVX512-NEXT:    retq
   %x = load <4 x i64>, <4 x i64>* %s, align 8


        


More information about the llvm-commits mailing list