[llvm] [VectorCombine] foldShuffleOfShuffles - fold "shuffle (shuffle x, undef), (shuffle y, undef)" -> "shuffle x, y" (PR #88743)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 22 05:01:25 PDT 2024


================
@@ -1552,6 +1553,86 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
   return true;
 }
 
+/// Try to convert "shuffle (shuffle x, undef), (shuffle y, undef)"
+/// into "shuffle x, y".
+bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
+  Value *V0, *V1;
+  UndefValue *U0, *U1;
+  ArrayRef<int> OuterMask, InnerMask0, InnerMask1;
+  if (!match(&I, m_Shuffle(m_OneUse(m_Shuffle(m_Value(V0), m_UndefValue(U0),
+                                              m_Mask(InnerMask0))),
+                           m_OneUse(m_Shuffle(m_Value(V1), m_UndefValue(U1),
+                                              m_Mask(InnerMask1))),
+                           m_Mask(OuterMask))))
+    return false;
+
+  auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
+  auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(V0->getType());
+  auto *ShuffleImmTy = dyn_cast<FixedVectorType>(I.getOperand(0)->getType());
+  if (!ShuffleDstTy || !ShuffleSrcTy || !ShuffleImmTy ||
+      V0->getType() != V1->getType())
+    return false;
+
+  unsigned NumSrcElts = ShuffleSrcTy->getNumElements();
+  unsigned NumImmElts = ShuffleImmTy->getNumElements();
+
+  // Bail if either inner mask references a RHS undef arg.
+  if ((!isa<PoisonValue>(U0) &&
+       any_of(InnerMask0, [&](int M) { return M >= (int)NumSrcElts; })) ||
+      (!isa<PoisonValue>(U1) &&
+       any_of(InnerMask1, [&](int M) { return M >= (int)NumSrcElts; })))
+    return false;
+
+  // Merge shuffles - replace index to the RHS poison arg with PoisonMaskElem.
+  SmallVector<int, 16> NewMask(OuterMask.begin(), OuterMask.end());
+  for (int &M : NewMask) {
+    if (0 <= M && M < (int)NumImmElts) {
+      M = (InnerMask0[M] >= (int)NumSrcElts) ? PoisonMaskElem : InnerMask0[M];
+    } else if (M >= (int)NumImmElts) {
+      if (InnerMask1[M - NumImmElts] >= (int)NumSrcElts)
+        M = PoisonMaskElem;
+      else
+        M = InnerMask1[M - NumImmElts] + (V0 == V1 ? 0 : NumSrcElts);
+    }
+  }
+
+  // Have we folded to an Identity shuffle?
+  if (ShuffleVectorInst::isIdentityMask(NewMask, NumSrcElts)) {
+    replaceValue(I, *V0);
+    return true;
+  }
+
+  // Try to merge the shuffles if the new shuffle is not costly.
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+
+  InstructionCost OldCost =
+      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy,
+                         InnerMask0, CostKind) +
+      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy,
+                         InnerMask1, CostKind) +
+      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, ShuffleImmTy,
+                         OuterMask, CostKind, 0, nullptr, std::nullopt, &I);
----------------
RKSimon wrote:

All getShuffleCost calls have the correct shuffle masks - what am I missing?

https://github.com/llvm/llvm-project/pull/88743


More information about the llvm-commits mailing list