[llvm] [VectorCombine] foldShuffleOfShuffles - fold "shuffle (shuffle x, undef), (shuffle y, undef)" -> "shuffle x, y" (PR #88743)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 15 08:00:57 PDT 2024


https://github.com/RKSimon created https://github.com/llvm/llvm-project/pull/88743

Another step towards cleaning up shuffles that have been split, often across bitcasts between SSE intrinsics.

Strip shuffles entirely if we fold to an identity shuffle.

>From aed7024246cbe7ff031a763d9d3d6aa34f4b3472 Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Mon, 15 Apr 2024 15:00:18 +0100
Subject: [PATCH] [VectorCombine] foldShuffleOfShuffles - fold "shuffle
 (shuffle x, undef), (shuffle y, undef)" -> "shuffle x, y"

Another step towards cleaning up shuffles that have been split, often across bitcasts between SSE intrinsics.

Strip shuffles entirely if we fold to an identity shuffle.
---
 .../Transforms/Vectorize/VectorCombine.cpp    | 71 ++++++++++++++++++-
 .../VectorCombine/AArch64/select-shuffle.ll   | 18 ++---
 .../AArch64/shuffletoidentity.ll              |  7 +-
 .../Transforms/VectorCombine/X86/pr67803.ll   |  5 +-
 4 files changed, 78 insertions(+), 23 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index e0e2f50c89adad..bfc23f0b1fdf3f 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -113,6 +113,7 @@ class VectorCombine {
   bool scalarizeLoadExtract(Instruction &I);
   bool foldShuffleOfBinops(Instruction &I);
   bool foldShuffleOfCastops(Instruction &I);
+  bool foldShuffleOfShuffles(Instruction &I);
   bool foldShuffleFromReductions(Instruction &I);
   bool foldTruncFromReductions(Instruction &I);
   bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
@@ -1547,7 +1548,74 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
   return true;
 }
 
-/// Given a commutative reduction, the order of the input lanes does not alter
+/// Try to convert "shuffle (shuffle x, undef), (shuffle y, undef)"
+/// into "shuffle x, y".
+///
+/// The outer mask is composed with the two inner masks so that every lane
+/// indexes x or y directly. If the composed mask is an identity the outer
+/// shuffle is replaced by x; otherwise a single merged shuffle is created when
+/// TTI reports it to be no more expensive than the three original shuffles.
+///
+/// \param I the candidate outer shufflevector instruction.
+/// \returns true if \p I was replaced, false if the pattern did not match or
+/// the fold was rejected.
+bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
+  Value *V0, *V1;
+  ArrayRef<int> OuterMask, InnerMask0, InnerMask1;
+  if (!match(&I, m_Shuffle(m_OneUse(m_Shuffle(m_Value(V0), m_Undef(),
+                                              m_Mask(InnerMask0))),
+                           m_OneUse(m_Shuffle(m_Value(V1), m_Undef(),
+                                              m_Mask(InnerMask1))),
+                           m_Mask(OuterMask))))
+    return false;
+
+  // Only fixed-width vectors are handled, and both inner sources must share a
+  // type so that they can feed a single shuffle.
+  auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
+  auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(V0->getType());
+  auto *ShuffleImmTy = dyn_cast<FixedVectorType>(I.getOperand(0)->getType());
+  if (!ShuffleDstTy || !ShuffleSrcTy || !ShuffleImmTy ||
+      V0->getType() != V1->getType())
+    return false;
+
+  int NumSrcElts = ShuffleSrcTy->getNumElements();
+  int NumImmElts = ShuffleImmTy->getNumElements();
+
+  // Compose the outer mask with the inner masks. A lane whose inner mask entry
+  // is poison, or reads the inner shuffle's undef/poison RHS, must not be
+  // offset like a real lane: that would turn it into a read of the last lane
+  // of V0 or into an index beyond both sources. Such lanes become poison
+  // instead; if the RHS is undef rather than poison that is not a valid
+  // replacement, so give up.
+  SmallVector<int, 16> NewMask(OuterMask.begin(), OuterMask.end());
+  for (int &M : NewMask) {
+    if (M < 0)
+      continue; // Outer lane is already poison.
+
+    bool FromOp0 = M < NumImmElts;
+    int Inner = FromOp0 ? InnerMask0[M] : InnerMask1[M - NumImmElts];
+    if (Inner >= NumSrcElts) {
+      // Lane comes from the inner RHS: only foldable if that RHS is poison.
+      auto *InnerShuf = cast<Instruction>(I.getOperand(FromOp0 ? 0 : 1));
+      if (!isa<PoisonValue>(InnerShuf->getOperand(1)))
+        return false;
+      M = PoisonMaskElem;
+    } else if (Inner < 0) {
+      M = PoisonMaskElem;
+    } else if (FromOp0 || V0 == V1) {
+      // Lanes of V0 (or of V1 when it is the same value) index the first
+      // operand of the merged shuffle.
+      M = Inner;
+    } else {
+      // Lanes of a distinct V1 index the second operand.
+      M = Inner + NumSrcElts;
+    }
+  }
+
+  // Have we folded to an Identity shuffle?
+  if (ShuffleVectorInst::isIdentityMask(NewMask, NumSrcElts)) {
+    replaceValue(I, *V0);
+    return true;
+  }
+
+  // Try to merge the shuffles if the new shuffle is not costly.
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+
+  InstructionCost OldCost =
+      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy,
+                         InnerMask0, CostKind) +
+      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy,
+                         InnerMask1, CostKind) +
+      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, ShuffleImmTy,
+                         OuterMask, CostKind, 0, nullptr, std::nullopt, &I);
+
+  InstructionCost NewCost = TTI.getShuffleCost(
+      TargetTransformInfo::SK_PermuteTwoSrc, ShuffleSrcTy, NewMask, CostKind);
+
+  LLVM_DEBUG(dbgs() << "Found a shuffle feeding two shuffles: " << I
+                    << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
+                    << "\n");
+  if (NewCost > OldCost)
+    return false;
+
+  // Clear unused sources to undef.
+  if (none_of(NewMask, [&](int M) { return 0 <= M && M < NumSrcElts; }))
+    V0 = UndefValue::get(ShuffleSrcTy);
+  if (none_of(NewMask, [&](int M) { return NumSrcElts <= M; }))
+    V1 = UndefValue::get(ShuffleSrcTy);
+
+  Value *Shuf = Builder.CreateShuffleVector(V0, V1, NewMask);
+  replaceValue(I, *Shuf);
+  return true;
+}
+
+/// Given a commutative reduction, the order of the input lanes does not alter
 /// the results. We can use this to remove certain shuffles feeding the
 /// reduction, removing the need to shuffle at all.
 bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
@@ -2102,6 +2170,7 @@ bool VectorCombine::run() {
       case Instruction::ShuffleVector:
         MadeChange |= foldShuffleOfBinops(I);
         MadeChange |= foldShuffleOfCastops(I);
+        MadeChange |= foldShuffleOfShuffles(I);
         MadeChange |= foldSelectShuffle(I);
         break;
       case Instruction::BitCast:
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/select-shuffle.ll b/llvm/test/Transforms/VectorCombine/AArch64/select-shuffle.ll
index eae08790048394..b49f3c9f3eeb27 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/select-shuffle.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/select-shuffle.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt -passes='vector-combine' -S %s | FileCheck %s
+; RUN: opt -passes=vector-combine -S %s | FileCheck %s
 
 target triple = "aarch64"
 
@@ -741,18 +741,14 @@ define i32 @full_reorder(ptr nocapture noundef readonly %pix1, i32 noundef %i_pi
 ; CHECK-NEXT:    [[TMP10:%.*]] = load <4 x i8>, ptr [[ARRAYIDX3_2]], align 1
 ; CHECK-NEXT:    [[TMP11:%.*]] = load <4 x i8>, ptr [[ARRAYIDX5_2]], align 1
 ; CHECK-NEXT:    [[TMP12:%.*]] = load <4 x i8>, ptr [[ADD_PTR_2]], align 1
-; CHECK-NEXT:    [[TMP13:%.*]] = shufflevector <4 x i8> [[TMP12]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP14:%.*]] = shufflevector <4 x i8> [[TMP8]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP15:%.*]] = shufflevector <16 x i8> [[TMP13]], <16 x i8> [[TMP14]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 16, i32 17, i32 18, i32 19, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+; CHECK-NEXT:    [[TMP15:%.*]] = shufflevector <4 x i8> [[TMP12]], <4 x i8> [[TMP8]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP16:%.*]] = shufflevector <4 x i8> [[TMP4]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP17:%.*]] = shufflevector <16 x i8> [[TMP15]], <16 x i8> [[TMP16]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 16, i32 17, i32 18, i32 19, i32 12, i32 13, i32 14, i32 15>
 ; CHECK-NEXT:    [[TMP18:%.*]] = shufflevector <4 x i8> [[TMP0]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP19:%.*]] = shufflevector <16 x i8> [[TMP17]], <16 x i8> [[TMP18]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 16, i32 17, i32 18, i32 19>
 ; CHECK-NEXT:    [[TMP20:%.*]] = zext <16 x i8> [[TMP19]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP21:%.*]] = load <4 x i8>, ptr [[ADD_PTR64_2]], align 1
-; CHECK-NEXT:    [[TMP22:%.*]] = shufflevector <4 x i8> [[TMP21]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP23:%.*]] = shufflevector <4 x i8> [[TMP9]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP24:%.*]] = shufflevector <16 x i8> [[TMP22]], <16 x i8> [[TMP23]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 16, i32 17, i32 18, i32 19, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+; CHECK-NEXT:    [[TMP24:%.*]] = shufflevector <4 x i8> [[TMP21]], <4 x i8> [[TMP9]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP25:%.*]] = shufflevector <4 x i8> [[TMP5]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP26:%.*]] = shufflevector <16 x i8> [[TMP24]], <16 x i8> [[TMP25]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 16, i32 17, i32 18, i32 19, i32 12, i32 13, i32 14, i32 15>
 ; CHECK-NEXT:    [[TMP27:%.*]] = shufflevector <4 x i8> [[TMP1]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
@@ -760,18 +756,14 @@ define i32 @full_reorder(ptr nocapture noundef readonly %pix1, i32 noundef %i_pi
 ; CHECK-NEXT:    [[TMP29:%.*]] = zext <16 x i8> [[TMP28]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP30:%.*]] = sub nsw <16 x i32> [[TMP20]], [[TMP29]]
 ; CHECK-NEXT:    [[TMP31:%.*]] = load <4 x i8>, ptr [[ARRAYIDX3_3]], align 1
-; CHECK-NEXT:    [[TMP32:%.*]] = shufflevector <4 x i8> [[TMP31]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP33:%.*]] = shufflevector <4 x i8> [[TMP10]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP34:%.*]] = shufflevector <16 x i8> [[TMP32]], <16 x i8> [[TMP33]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 16, i32 17, i32 18, i32 19, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+; CHECK-NEXT:    [[TMP34:%.*]] = shufflevector <4 x i8> [[TMP31]], <4 x i8> [[TMP10]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP35:%.*]] = shufflevector <4 x i8> [[TMP6]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP36:%.*]] = shufflevector <16 x i8> [[TMP34]], <16 x i8> [[TMP35]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 16, i32 17, i32 18, i32 19, i32 12, i32 13, i32 14, i32 15>
 ; CHECK-NEXT:    [[TMP37:%.*]] = shufflevector <4 x i8> [[TMP2]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP38:%.*]] = shufflevector <16 x i8> [[TMP36]], <16 x i8> [[TMP37]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 16, i32 17, i32 18, i32 19>
 ; CHECK-NEXT:    [[TMP39:%.*]] = zext <16 x i8> [[TMP38]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP40:%.*]] = load <4 x i8>, ptr [[ARRAYIDX5_3]], align 1
-; CHECK-NEXT:    [[TMP41:%.*]] = shufflevector <4 x i8> [[TMP40]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP42:%.*]] = shufflevector <4 x i8> [[TMP11]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP43:%.*]] = shufflevector <16 x i8> [[TMP41]], <16 x i8> [[TMP42]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 16, i32 17, i32 18, i32 19, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+; CHECK-NEXT:    [[TMP43:%.*]] = shufflevector <4 x i8> [[TMP40]], <4 x i8> [[TMP11]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP44:%.*]] = shufflevector <4 x i8> [[TMP7]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP45:%.*]] = shufflevector <16 x i8> [[TMP43]], <16 x i8> [[TMP44]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 16, i32 17, i32 18, i32 19, i32 12, i32 13, i32 14, i32 15>
 ; CHECK-NEXT:    [[TMP46:%.*]] = shufflevector <4 x i8> [[TMP3]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll b/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll
index d96dfec849167d..cd78bea2f45a15 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll
@@ -1,14 +1,11 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt -passes='vector-combine' -S %s | FileCheck %s
+; RUN: opt -passes=vector-combine -S %s | FileCheck %s
 
 target triple = "aarch64"
 
 define <8 x i8> @trivial(<8 x i8> %a) {
 ; CHECK-LABEL: @trivial(
-; CHECK-NEXT:    [[AB:%.*]] = shufflevector <8 x i8> [[A:%.*]], <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
-; CHECK-NEXT:    [[AT:%.*]] = shufflevector <8 x i8> [[A]], <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
-; CHECK-NEXT:    [[R:%.*]] = shufflevector <4 x i8> [[AT]], <4 x i8> [[AB]], <8 x i32> <i32 7, i32 6, i32 5, i32 4, i32 3, i32 2, i32 1, i32 0>
-; CHECK-NEXT:    ret <8 x i8> [[R]]
+; CHECK-NEXT:    ret <8 x i8> [[R:%.*]]
 ;
   %ab = shufflevector <8 x i8> %a, <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
   %at = shufflevector <8 x i8> %a, <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
diff --git a/llvm/test/Transforms/VectorCombine/X86/pr67803.ll b/llvm/test/Transforms/VectorCombine/X86/pr67803.ll
index 69fd6f6a10e2a6..0277580d21fcb7 100644
--- a/llvm/test/Transforms/VectorCombine/X86/pr67803.ll
+++ b/llvm/test/Transforms/VectorCombine/X86/pr67803.ll
@@ -6,10 +6,7 @@ define <4 x i64> @PR67803(<8 x i32> %x, <8 x i32> %y, <8 x float> %a, <8 x float
 ; CHECK-LABEL: @PR67803(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt <8 x i32> [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[CMP_LO:%.*]] = shufflevector <8 x i1> [[CMP]], <8 x i1> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[CMP_HI:%.*]] = shufflevector <8 x i1> [[CMP]], <8 x i1> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    [[TMP0:%.*]] = shufflevector <4 x i1> [[CMP_LO]], <4 x i1> [[CMP_HI]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i1> [[TMP0]] to <8 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i1> [[CMP]] to <8 x i32>
 ; CHECK-NEXT:    [[CONCAT:%.*]] = bitcast <8 x i32> [[TMP1]] to <4 x i64>
 ; CHECK-NEXT:    [[MASK:%.*]] = bitcast <4 x i64> [[CONCAT]] to <8 x float>
 ; CHECK-NEXT:    [[SEL:%.*]] = tail call noundef <8 x float> @llvm.x86.avx.blendv.ps.256(<8 x float> [[A:%.*]], <8 x float> [[B:%.*]], <8 x float> [[MASK]])



More information about the llvm-commits mailing list