[llvm] [WIP][VectorCombine] Fold "shuffle (binop (shuffle, shuffle)), undef" --> "binop (shuffle), (shuffle)" (PR #114101)
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Wed Oct 30 04:52:52 PDT 2024
https://github.com/RKSimon updated https://github.com/llvm/llvm-project/pull/114101
>From daccb79bcd68ca7dac38d3d1d4242159317fc208 Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Tue, 29 Oct 2024 17:36:00 +0000
Subject: [PATCH] [WIP][VectorCombine] Fold "shuffle (binop (shuffle,
shuffle)), undef" --> "binop (shuffle), (shuffle)"
Add foldPermuteOfBinops to fold a permute (single-source shuffle) through a binary op whose operands are both single-use shuffles.
WIP - still need to add additional test coverage.
Fixes #94546
---
.../Transforms/Vectorize/VectorCombine.cpp | 89 +++++++++++++++++++
.../X86/horiz-math-inseltpoison.ll | 7 +-
.../PhaseOrdering/X86/horiz-math.ll | 7 +-
.../Transforms/PhaseOrdering/X86/pr50392.ll | 7 +-
4 files changed, 98 insertions(+), 12 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 58145c7e3c5913..c7ac2f3046a94d 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -112,6 +112,7 @@ class VectorCombine {
bool foldExtractedCmps(Instruction &I);
bool foldSingleElementStore(Instruction &I);
bool scalarizeLoadExtract(Instruction &I);
+ bool foldPermuteOfBinops(Instruction &I);
bool foldShuffleOfBinops(Instruction &I);
bool foldShuffleOfCastops(Instruction &I);
bool foldShuffleOfShuffles(Instruction &I);
@@ -1400,6 +1401,93 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
return true;
}
+/// Try to convert "shuffle (binop (shuffle, shuffle)), undef"
+/// --> "binop (shuffle), (shuffle)".
+/// The outer single-source permute is merged into the masks of the two inner
+/// shuffles, so the binop operates directly on the permuted operands and the
+/// outer shuffle disappears. Only performed if the cost model reports a win.
+bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
+  BinaryOperator *BinOp;
+  ArrayRef<int> OuterMask;
+  if (!match(&I,
+             m_Shuffle(m_OneUse(m_BinOp(BinOp)), m_Undef(), m_Mask(OuterMask))))
+    return false;
+
+  // Don't introduce poison into div/rem: a poison lane in the outer mask would
+  // become a poison divisor lane in the new binop, which is immediate UB.
+  if (llvm::is_contained(OuterMask, PoisonMaskElem) && BinOp->isIntDivRem())
+    return false;
+
+  // Both binop operands must be single-use shuffles, so that they (and the
+  // binop) become dead once the outer shuffle is replaced.
+  Value *Op00, *Op01;
+  ArrayRef<int> Mask0;
+  if (!match(BinOp->getOperand(0),
+             m_OneUse(m_Shuffle(m_Value(Op00), m_Value(Op01), m_Mask(Mask0)))))
+    return false;
+
+  Value *Op10, *Op11;
+  ArrayRef<int> Mask1;
+  if (!match(BinOp->getOperand(1),
+             m_OneUse(m_Shuffle(m_Value(Op10), m_Value(Op11), m_Mask(Mask1)))))
+    return false;
+
+  Instruction::BinaryOps Opcode = BinOp->getOpcode();
+  auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
+  auto *BinOpTy = dyn_cast<FixedVectorType>(BinOp->getType());
+  auto *Op0Ty = dyn_cast<FixedVectorType>(Op00->getType());
+  auto *Op1Ty = dyn_cast<FixedVectorType>(Op10->getType());
+  if (!ShuffleDstTy || !BinOpTy || !Op0Ty || !Op1Ty)
+    return false;
+
+  unsigned NumSrcElts = BinOpTy->getNumElements();
+
+  // Don't accept shuffles that reference the second (undef/poison) operand.
+  if (any_of(OuterMask, [NumSrcElts](int M) { return M >= (int)NumSrcElts; }))
+    return false;
+
+  // Merge outer / inner shuffles: outer lane value M selects lane M of the
+  // binop, i.e. lane Mask0[M] / Mask1[M] of the inner shuffles' sources.
+  // Poison outer lanes stay poison.
+  SmallVector<int> NewMask0, NewMask1;
+  for (int M : OuterMask) {
+    NewMask0.push_back(M >= 0 ? Mask0[M] : PoisonMaskElem);
+    NewMask1.push_back(M >= 0 ? Mask1[M] : PoisonMaskElem);
+  }
+
+  // Try to merge shuffles across the binop if the new shuffles are not costly.
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+
+  // Old: binop + outer permute + both inner shuffles (all removed by the fold).
+  InstructionCost OldCost =
+      TTI.getArithmeticInstrCost(Opcode, BinOpTy, CostKind) +
+      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, BinOpTy,
+                         OuterMask, CostKind, 0, nullptr, {BinOp}, &I) +
+      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, Mask0,
+                         CostKind, 0, nullptr, {Op00, Op01},
+                         cast<Instruction>(BinOp->getOperand(0))) +
+      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, Mask1,
+                         CostKind, 0, nullptr, {Op10, Op11},
+                         cast<Instruction>(BinOp->getOperand(1)));
+
+  // New: two merged shuffles + the binop at the outer shuffle's result type.
+  InstructionCost NewCost =
+      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, NewMask0,
+                         CostKind, 0, nullptr, {Op00, Op01}) +
+      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, NewMask1,
+                         CostKind, 0, nullptr, {Op10, Op11}) +
+      TTI.getArithmeticInstrCost(Opcode, ShuffleDstTy, CostKind);
+
+  LLVM_DEBUG(dbgs() << "Found a shuffle feeding a shuffled binop: " << I
+                    << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
+                    << "\n");
+  if (NewCost >= OldCost)
+    return false;
+
+  Value *Shuf0 = Builder.CreateShuffleVector(Op00, Op01, NewMask0);
+  Value *Shuf1 = Builder.CreateShuffleVector(Op10, Op11, NewMask1);
+  Value *NewBO = Builder.CreateBinOp(Opcode, Shuf0, Shuf1);
+
+  // Copy the IR flags (e.g. wrap/exact/fast-math) from the single original
+  // binop. The builder may constant-fold, so NewBO need not be an Instruction.
+  if (auto *NewInst = dyn_cast<Instruction>(NewBO))
+    NewInst->copyIRFlags(BinOp);
+
+  // Queue the new shuffles so later folds can revisit them.
+  Worklist.pushValue(Shuf0);
+  Worklist.pushValue(Shuf1);
+  replaceValue(I, *NewBO);
+  return true;
+}
+
/// Try to convert "shuffle (binop), (binop)" into "binop (shuffle), (shuffle)".
bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
BinaryOperator *B0, *B1;
@@ -2736,6 +2824,7 @@ bool VectorCombine::run() {
MadeChange |= foldInsExtFNeg(I);
break;
case Instruction::ShuffleVector:
+ MadeChange |= foldPermuteOfBinops(I);
MadeChange |= foldShuffleOfBinops(I);
MadeChange |= foldShuffleOfCastops(I);
MadeChange |= foldShuffleOfShuffles(I);
diff --git a/llvm/test/Transforms/PhaseOrdering/X86/horiz-math-inseltpoison.ll b/llvm/test/Transforms/PhaseOrdering/X86/horiz-math-inseltpoison.ll
index 1d1c9d1f1d18c3..324503a30783d1 100644
--- a/llvm/test/Transforms/PhaseOrdering/X86/horiz-math-inseltpoison.ll
+++ b/llvm/test/Transforms/PhaseOrdering/X86/horiz-math-inseltpoison.ll
@@ -108,11 +108,10 @@ define <8 x float> @hadd_reverse_v8f32(<8 x float> %a, <8 x float> %b) #0 {
define <8 x float> @reverse_hadd_v8f32(<8 x float> %a, <8 x float> %b) #0 {
; CHECK-LABEL: @reverse_hadd_v8f32(
-; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x float> [[A:%.*]], <8 x float> [[B:%.*]], <8 x i32> <i32 0, i32 2, i32 8, i32 10, i32 4, i32 6, i32 12, i32 14>
-; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x float> [[A]], <8 x float> [[B]], <8 x i32> <i32 1, i32 3, i32 9, i32 11, i32 5, i32 7, i32 13, i32 15>
+; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x float> [[A:%.*]], <8 x float> [[B:%.*]], <8 x i32> <i32 14, i32 12, i32 6, i32 4, i32 10, i32 8, i32 2, i32 0>
+; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x float> [[A]], <8 x float> [[B]], <8 x i32> <i32 15, i32 13, i32 7, i32 5, i32 11, i32 9, i32 3, i32 1>
; CHECK-NEXT: [[TMP3:%.*]] = fadd <8 x float> [[TMP1]], [[TMP2]]
-; CHECK-NEXT: [[SHUFFLE:%.*]] = shufflevector <8 x float> [[TMP3]], <8 x float> poison, <8 x i32> <i32 7, i32 6, i32 5, i32 4, i32 3, i32 2, i32 1, i32 0>
-; CHECK-NEXT: ret <8 x float> [[SHUFFLE]]
+; CHECK-NEXT: ret <8 x float> [[TMP3]]
;
%vecext = extractelement <8 x float> %a, i32 0
%vecext1 = extractelement <8 x float> %a, i32 1
diff --git a/llvm/test/Transforms/PhaseOrdering/X86/horiz-math.ll b/llvm/test/Transforms/PhaseOrdering/X86/horiz-math.ll
index 4f8f04ec42497b..9d3b69218313e8 100644
--- a/llvm/test/Transforms/PhaseOrdering/X86/horiz-math.ll
+++ b/llvm/test/Transforms/PhaseOrdering/X86/horiz-math.ll
@@ -108,11 +108,10 @@ define <8 x float> @hadd_reverse_v8f32(<8 x float> %a, <8 x float> %b) #0 {
define <8 x float> @reverse_hadd_v8f32(<8 x float> %a, <8 x float> %b) #0 {
; CHECK-LABEL: @reverse_hadd_v8f32(
-; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x float> [[A:%.*]], <8 x float> [[B:%.*]], <8 x i32> <i32 0, i32 2, i32 8, i32 10, i32 4, i32 6, i32 12, i32 14>
-; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x float> [[A]], <8 x float> [[B]], <8 x i32> <i32 1, i32 3, i32 9, i32 11, i32 5, i32 7, i32 13, i32 15>
+; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x float> [[A:%.*]], <8 x float> [[B:%.*]], <8 x i32> <i32 14, i32 12, i32 6, i32 4, i32 10, i32 8, i32 2, i32 0>
+; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x float> [[A]], <8 x float> [[B]], <8 x i32> <i32 15, i32 13, i32 7, i32 5, i32 11, i32 9, i32 3, i32 1>
; CHECK-NEXT: [[TMP3:%.*]] = fadd <8 x float> [[TMP1]], [[TMP2]]
-; CHECK-NEXT: [[SHUFFLE:%.*]] = shufflevector <8 x float> [[TMP3]], <8 x float> poison, <8 x i32> <i32 7, i32 6, i32 5, i32 4, i32 3, i32 2, i32 1, i32 0>
-; CHECK-NEXT: ret <8 x float> [[SHUFFLE]]
+; CHECK-NEXT: ret <8 x float> [[TMP3]]
;
%vecext = extractelement <8 x float> %a, i32 0
%vecext1 = extractelement <8 x float> %a, i32 1
diff --git a/llvm/test/Transforms/PhaseOrdering/X86/pr50392.ll b/llvm/test/Transforms/PhaseOrdering/X86/pr50392.ll
index 4a024cc4c0309c..53d4b1ad96cb82 100644
--- a/llvm/test/Transforms/PhaseOrdering/X86/pr50392.ll
+++ b/llvm/test/Transforms/PhaseOrdering/X86/pr50392.ll
@@ -32,10 +32,9 @@ define <4 x double> @PR50392(<4 x double> %a, <4 x double> %b) {
; AVX1-NEXT: ret <4 x double> [[SHUFFLE]]
;
; AVX2-LABEL: @PR50392(
-; AVX2-NEXT: [[TMP1:%.*]] = shufflevector <4 x double> [[A:%.*]], <4 x double> [[B:%.*]], <2 x i32> <i32 0, i32 4>
-; AVX2-NEXT: [[TMP2:%.*]] = shufflevector <4 x double> [[A]], <4 x double> [[B]], <2 x i32> <i32 1, i32 5>
-; AVX2-NEXT: [[TMP3:%.*]] = fadd <2 x double> [[TMP1]], [[TMP2]]
-; AVX2-NEXT: [[TMP4:%.*]] = shufflevector <2 x double> [[TMP3]], <2 x double> poison, <4 x i32> <i32 0, i32 poison, i32 1, i32 poison>
+; AVX2-NEXT: [[TMP1:%.*]] = shufflevector <4 x double> [[A:%.*]], <4 x double> [[B:%.*]], <4 x i32> <i32 0, i32 poison, i32 4, i32 poison>
+; AVX2-NEXT: [[TMP2:%.*]] = shufflevector <4 x double> [[A]], <4 x double> [[B]], <4 x i32> <i32 1, i32 poison, i32 5, i32 poison>
+; AVX2-NEXT: [[TMP4:%.*]] = fadd <4 x double> [[TMP1]], [[TMP2]]
; AVX2-NEXT: [[SHIFT:%.*]] = shufflevector <4 x double> [[B]], <4 x double> poison, <4 x i32> <i32 poison, i32 poison, i32 3, i32 poison>
; AVX2-NEXT: [[TMP5:%.*]] = fadd <4 x double> [[B]], [[SHIFT]]
; AVX2-NEXT: [[SHUFFLE:%.*]] = shufflevector <4 x double> [[TMP4]], <4 x double> [[TMP5]], <4 x i32> <i32 0, i32 poison, i32 2, i32 6>
More information about the llvm-commits
mailing list