[llvm] [VectorCombine] foldShuffleOfShuffles - fold "shuffle (shuffle x, undef), (shuffle y, undef)" -> "shuffle x, y" (PR #88743)
Alexey Bataev via llvm-commits
llvm-commits at lists.llvm.org
Tue Apr 16 07:27:30 PDT 2024
================
@@ -1547,6 +1548,73 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
return true;
}
+/// Try to convert "shuffle (shuffle x, undef), (shuffle y, undef)"
+/// into "shuffle x, y".
+bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
+ Value *V0, *V1;
+ ArrayRef<int> OuterMask, InnerMask0, InnerMask1;
+ if (!match(&I, m_Shuffle(m_OneUse(m_Shuffle(m_Value(V0), m_Undef(),
+ m_Mask(InnerMask0))),
+ m_OneUse(m_Shuffle(m_Value(V1), m_Undef(),
+ m_Mask(InnerMask1))),
+ m_Mask(OuterMask))))
+ return false;
+
+ auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
+ auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(V0->getType());
+ auto *ShuffleImmTy = dyn_cast<FixedVectorType>(I.getOperand(0)->getType());
+ if (!ShuffleDstTy || !ShuffleSrcTy || !ShuffleImmTy ||
+ V0->getType() != V1->getType())
+ return false;
+
+ unsigned NumSrcElts = ShuffleSrcTy->getNumElements();
+ unsigned NumImmElts = ShuffleImmTy->getNumElements();
+
+ SmallVector<int, 16> NewMask(OuterMask.begin(), OuterMask.end());
+ for (int &M : NewMask) {
+ if (0 <= M && M < (int)NumImmElts)
+ M = InnerMask0[M];
+ else if ((int)NumImmElts <= M)
+ M = InnerMask1[M - NumImmElts] + (V0 == V1 ? 0 : NumSrcElts);
+ }
+
+ // Have we folded to an Identity shuffle?
+ if (ShuffleVectorInst::isIdentityMask(NewMask, NumSrcElts)) {
+ replaceValue(I, *V0);
+ return true;
+ }
+
+ // Try to merge the shuffles if the new shuffle is not costly.
+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+
+ InstructionCost OldCost =
+ TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy,
+ InnerMask0, CostKind) +
+ TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy,
+ InnerMask1, CostKind) +
+ TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, ShuffleImmTy,
+ OuterMask, CostKind, 0, nullptr, std::nullopt, &I);
+
+ InstructionCost NewCost = TTI.getShuffleCost(
+ TargetTransformInfo::SK_PermuteTwoSrc, ShuffleSrcTy, NewMask, CostKind);
+
+ LLVM_DEBUG(dbgs() << "Found a shuffle feeding two shuffles: " << I
+ << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
+ << "\n");
+ if (NewCost > OldCost)
+ return false;
+
+ // Clear unused sources to undef.
+ if (none_of(NewMask, [&](int M) { return 0 <= M && M < (int)NumSrcElts; }))
+ V0 = UndefValue::get(ShuffleSrcTy);
+ if (none_of(NewMask, [&](int M) { return (int)NumSrcElts <= M; }))
+ V1 = UndefValue::get(ShuffleSrcTy);
----------------
alexey-bataev wrote:
PoisonValue::get? (i.e. use PoisonValue::get instead of UndefValue::get for the unused shuffle operands)
https://github.com/llvm/llvm-project/pull/88743
More information about the llvm-commits
mailing list