[llvm] [SLP][NFC]Extract a check for a SplitVectorize node, NFC (PR #134896)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Apr 8 10:58:05 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Alexey Bataev (alexey-bataev)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/134896.diff
1 file affected:
- (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+130-109)
``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 8d411f2cb203a..c7d0681b56e39 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -3125,6 +3125,18 @@ class BoUpSLP {
ArrayRef<Value *> VectorizedVals,
SmallPtrSetImpl<Value *> &CheckedExtracts);
+ /// Checks if it is legal and profitable to build SplitVectorize node for the
+ /// given \p VL.
+ /// \param Op1 first homogeneous scalars.
+ /// \param Op2 second homogeneous scalars.
+ /// \param ReorderIndices indices to reorder the scalars.
+ /// \returns true if the node was successfully built.
+ bool canBuildSplitNode(ArrayRef<Value *> VL,
+ const InstructionsState &LocalState,
+ SmallVectorImpl<Value *> &Op1,
+ SmallVectorImpl<Value *> &Op2,
+ OrdersType &ReorderIndices) const;
+
/// This is the recursive part of buildTree.
void buildTree_rec(ArrayRef<Value *> Roots, unsigned Depth,
const EdgeInfo &EI, unsigned InterleaveFactor = 0);
@@ -9163,6 +9175,117 @@ static bool tryToFindDuplicates(SmallVectorImpl<Value *> &VL,
return true;
}
+bool BoUpSLP::canBuildSplitNode(ArrayRef<Value *> VL,
+ const InstructionsState &LocalState,
+ SmallVectorImpl<Value *> &Op1,
+ SmallVectorImpl<Value *> &Op2,
+ OrdersType &ReorderIndices) const {
+ constexpr unsigned SmallNodeSize = 4;
+ if (VL.size() <= SmallNodeSize || TTI->preferAlternateOpcodeVectorization() ||
+ !SplitAlternateInstructions)
+ return false;
+
+ ReorderIndices.assign(VL.size(), VL.size());
+ SmallBitVector Op1Indices(VL.size());
+ for (auto [Idx, V] : enumerate(VL)) {
+ auto *I = dyn_cast<Instruction>(V);
+ if (!I) {
+ Op1.push_back(V);
+ Op1Indices.set(Idx);
+ continue;
+ }
+ if ((LocalState.getAltOpcode() != LocalState.getOpcode() &&
+ I->getOpcode() == LocalState.getOpcode()) ||
+ (LocalState.getAltOpcode() == LocalState.getOpcode() &&
+ !isAlternateInstruction(I, LocalState.getMainOp(),
+ LocalState.getAltOp(), *TLI))) {
+ Op1.push_back(V);
+ Op1Indices.set(Idx);
+ continue;
+ }
+ Op2.push_back(V);
+ }
+ Type *ScalarTy = getValueType(VL.front());
+ VectorType *VecTy = getWidenedType(ScalarTy, VL.size());
+ unsigned Opcode0 = LocalState.getOpcode();
+ unsigned Opcode1 = LocalState.getAltOpcode();
+ SmallBitVector OpcodeMask(getAltInstrMask(VL, ScalarTy, Opcode0, Opcode1));
+ // Enable split node, only if all nodes do not form legal alternate
+ // instruction (like X86 addsub).
+ SmallPtrSet<Value *, 4> UOp1(llvm::from_range, Op1);
+ SmallPtrSet<Value *, 4> UOp2(llvm::from_range, Op2);
+ if (UOp1.size() <= 1 || UOp2.size() <= 1 ||
+ TTI->isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask) ||
+ !hasFullVectorsOrPowerOf2(*TTI, Op1.front()->getType(), Op1.size()) ||
+ !hasFullVectorsOrPowerOf2(*TTI, Op2.front()->getType(), Op2.size()))
+ return false;
+ // Enable split node, only if all nodes are power-of-2/full registers.
+ unsigned Op1Cnt = 0, Op2Cnt = Op1.size();
+ for (unsigned Idx : seq<unsigned>(VL.size())) {
+ if (Op1Indices.test(Idx)) {
+ ReorderIndices[Op1Cnt] = Idx;
+ ++Op1Cnt;
+ } else {
+ ReorderIndices[Op2Cnt] = Idx;
+ ++Op2Cnt;
+ }
+ }
+ if (isIdentityOrder(ReorderIndices))
+ ReorderIndices.clear();
+ SmallVector<int> Mask;
+ if (!ReorderIndices.empty())
+ inversePermutation(ReorderIndices, Mask);
+ unsigned NumParts = TTI->getNumberOfParts(VecTy);
+ VectorType *Op1VecTy = getWidenedType(ScalarTy, Op1.size());
+ VectorType *Op2VecTy = getWidenedType(ScalarTy, Op2.size());
+ // Check non-profitable single register ops, which better to be represented
+ // as alternate ops.
+ if (NumParts >= VL.size())
+ return false;
+ if ((LocalState.getMainOp()->isBinaryOp() &&
+ LocalState.getAltOp()->isBinaryOp() &&
+ (LocalState.isShiftOp() || LocalState.isBitwiseLogicOp() ||
+ LocalState.isAddSubLikeOp() || LocalState.isMulDivLikeOp())) ||
+ (LocalState.getMainOp()->isCast() && LocalState.getAltOp()->isCast()) ||
+ (LocalState.getMainOp()->isUnaryOp() &&
+ LocalState.getAltOp()->isUnaryOp())) {
+ constexpr TTI::TargetCostKind Kind = TTI::TCK_RecipThroughput;
+ InstructionCost InsertCost = ::getShuffleCost(
+ *TTI, TTI::SK_InsertSubvector, VecTy, {}, Kind, Op1.size(), Op2VecTy);
+ FixedVectorType *SubVecTy =
+ getWidenedType(ScalarTy, std::max(Op1.size(), Op2.size()));
+ InstructionCost NewShuffleCost =
+ ::getShuffleCost(*TTI, TTI::SK_PermuteTwoSrc, SubVecTy, Mask, Kind);
+ if (NumParts <= 1 && (Mask.empty() || InsertCost >= NewShuffleCost))
+ return false;
+ InstructionCost OriginalVecOpsCost =
+ TTI->getArithmeticInstrCost(Opcode0, VecTy, Kind) +
+ TTI->getArithmeticInstrCost(Opcode1, VecTy, Kind);
+ SmallVector<int> OriginalMask(VL.size(), PoisonMaskElem);
+ for (unsigned Idx : seq<unsigned>(VL.size())) {
+ if (isa<PoisonValue>(VL[Idx]))
+ continue;
+ OriginalMask[Idx] = Idx + (Op1Indices.test(Idx) ? 0 : VL.size());
+ }
+ InstructionCost OriginalCost =
+ OriginalVecOpsCost + ::getShuffleCost(*TTI, TTI::SK_PermuteTwoSrc,
+ VecTy, OriginalMask, Kind);
+ InstructionCost NewVecOpsCost =
+ TTI->getArithmeticInstrCost(Opcode0, Op1VecTy, Kind) +
+ TTI->getArithmeticInstrCost(Opcode1, Op2VecTy, Kind);
+ InstructionCost NewCost =
+ NewVecOpsCost + InsertCost +
+ (!VectorizableTree.empty() && VectorizableTree.front()->hasState() &&
+ VectorizableTree.front()->getOpcode() == Instruction::Store
+ ? NewShuffleCost
+ : 0);
+ // If not profitable to split - exit.
+ if (NewCost >= OriginalCost)
+ return false;
+ }
+ return true;
+}
+
void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
const EdgeInfo &UserTreeIdx,
unsigned InterleaveFactor) {
@@ -9265,11 +9388,10 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
}
// Tries to build split node.
- constexpr unsigned SmallNodeSize = 4;
- auto TrySplitNode = [&, &TTI = *TTI](unsigned SmallNodeSize,
- const InstructionsState &LocalState) {
- if (VL.size() <= SmallNodeSize ||
- TTI.preferAlternateOpcodeVectorization() || !SplitAlternateInstructions)
+ auto TrySplitNode = [&](const InstructionsState &LocalState) {
+ SmallVector<Value *> Op1, Op2;
+ OrdersType ReorderIndices;
+ if (!canBuildSplitNode(VL, LocalState, Op1, Op2, ReorderIndices))
return false;
// Any value is used in split node already - just gather.
@@ -9283,105 +9405,6 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
}
return true;
}
- SmallVector<Value *> Op1, Op2;
- OrdersType ReorderIndices(VL.size(), VL.size());
- SmallBitVector Op1Indices(VL.size());
- for (auto [Idx, V] : enumerate(VL)) {
- auto *I = dyn_cast<Instruction>(V);
- if (!I) {
- Op1.push_back(V);
- Op1Indices.set(Idx);
- continue;
- }
- if ((LocalState.getAltOpcode() != LocalState.getOpcode() &&
- I->getOpcode() == LocalState.getOpcode()) ||
- (LocalState.getAltOpcode() == LocalState.getOpcode() &&
- !isAlternateInstruction(I, LocalState.getMainOp(),
- LocalState.getAltOp(), *TLI))) {
- Op1.push_back(V);
- Op1Indices.set(Idx);
- continue;
- }
- Op2.push_back(V);
- }
- Type *ScalarTy = getValueType(VL.front());
- VectorType *VecTy = getWidenedType(ScalarTy, VL.size());
- unsigned Opcode0 = LocalState.getOpcode();
- unsigned Opcode1 = LocalState.getAltOpcode();
- SmallBitVector OpcodeMask(getAltInstrMask(VL, ScalarTy, Opcode0, Opcode1));
- // Enable split node, only if all nodes do not form legal alternate
- // instruction (like X86 addsub).
- SmallPtrSet<Value *, 4> UOp1(llvm::from_range, Op1);
- SmallPtrSet<Value *, 4> UOp2(llvm::from_range, Op2);
- if (UOp1.size() <= 1 || UOp2.size() <= 1 ||
- TTI.isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask) ||
- !hasFullVectorsOrPowerOf2(TTI, Op1.front()->getType(), Op1.size()) ||
- !hasFullVectorsOrPowerOf2(TTI, Op2.front()->getType(), Op2.size()))
- return false;
- // Enable split node, only if all nodes are power-of-2/full registers.
- unsigned Op1Cnt = 0, Op2Cnt = Op1.size();
- for (unsigned Idx : seq<unsigned>(VL.size())) {
- if (Op1Indices.test(Idx)) {
- ReorderIndices[Op1Cnt] = Idx;
- ++Op1Cnt;
- } else {
- ReorderIndices[Op2Cnt] = Idx;
- ++Op2Cnt;
- }
- }
- if (isIdentityOrder(ReorderIndices))
- ReorderIndices.clear();
- SmallVector<int> Mask;
- if (!ReorderIndices.empty())
- inversePermutation(ReorderIndices, Mask);
- unsigned NumParts = TTI.getNumberOfParts(VecTy);
- VectorType *Op1VecTy = getWidenedType(ScalarTy, Op1.size());
- VectorType *Op2VecTy = getWidenedType(ScalarTy, Op2.size());
- // Check non-profitable single register ops, which better to be represented
- // as alternate ops.
- if (NumParts >= VL.size())
- return false;
- if ((LocalState.getMainOp()->isBinaryOp() &&
- LocalState.getAltOp()->isBinaryOp() &&
- (LocalState.isShiftOp() || LocalState.isBitwiseLogicOp() ||
- LocalState.isAddSubLikeOp() || LocalState.isMulDivLikeOp())) ||
- (LocalState.getMainOp()->isCast() && LocalState.getAltOp()->isCast()) ||
- (LocalState.getMainOp()->isUnaryOp() &&
- LocalState.getAltOp()->isUnaryOp())) {
- constexpr TTI::TargetCostKind Kind = TTI::TCK_RecipThroughput;
- InstructionCost InsertCost = ::getShuffleCost(
- TTI, TTI::SK_InsertSubvector, VecTy, {}, Kind, Op1.size(), Op2VecTy);
- FixedVectorType *SubVecTy =
- getWidenedType(ScalarTy, std::max(Op1.size(), Op2.size()));
- InstructionCost NewShuffleCost =
- ::getShuffleCost(TTI, TTI::SK_PermuteTwoSrc, SubVecTy, Mask, Kind);
- if (NumParts <= 1 && (Mask.empty() || InsertCost >= NewShuffleCost))
- return false;
- InstructionCost OriginalVecOpsCost =
- TTI.getArithmeticInstrCost(Opcode0, VecTy, Kind) +
- TTI.getArithmeticInstrCost(Opcode1, VecTy, Kind);
- SmallVector<int> OriginalMask(VL.size(), PoisonMaskElem);
- for (unsigned Idx : seq<unsigned>(VL.size())) {
- if (isa<PoisonValue>(VL[Idx]))
- continue;
- OriginalMask[Idx] = Idx + (Op1Indices.test(Idx) ? 0 : VL.size());
- }
- InstructionCost OriginalCost =
- OriginalVecOpsCost + ::getShuffleCost(TTI, TTI::SK_PermuteTwoSrc,
- VecTy, OriginalMask, Kind);
- InstructionCost NewVecOpsCost =
- TTI.getArithmeticInstrCost(Opcode0, Op1VecTy, Kind) +
- TTI.getArithmeticInstrCost(Opcode1, Op2VecTy, Kind);
- InstructionCost NewCost =
- NewVecOpsCost + InsertCost +
- (!VectorizableTree.empty() && VectorizableTree.front()->hasState() &&
- VectorizableTree.front()->getOpcode() == Instruction::Store
- ? NewShuffleCost
- : 0);
- // If not profitable to split - exit.
- if (NewCost >= OriginalCost)
- return false;
- }
SmallVector<Value *> NewVL(VL.size());
copy(Op1, NewVL.begin());
@@ -9497,8 +9520,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
if (!S) {
auto [MainOp, AltOp] = getMainAltOpsNoStateVL(VL);
// Last chance to try to vectorize alternate node.
- if (MainOp && AltOp &&
- TrySplitNode(SmallNodeSize, InstructionsState(MainOp, AltOp)))
+ if (MainOp && AltOp && TrySplitNode(InstructionsState(MainOp, AltOp)))
return;
}
LLVM_DEBUG(dbgs() << "SLP: Gathering due to C,S,B,O, small shuffle. \n");
@@ -9622,7 +9644,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
}
// FIXME: investigate if there are profitable cases for VL.size() <= 4.
- if (S.isAltShuffle() && TrySplitNode(SmallNodeSize, S))
+ if (S.isAltShuffle() && TrySplitNode(S))
return;
// Check that every instruction appears once in this bundle.
@@ -9657,8 +9679,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
if (!BundlePtr || (*BundlePtr && !*BundlePtr.value())) {
LLVM_DEBUG(dbgs() << "SLP: We are not able to schedule this bundle!\n");
// Last chance to try to vectorize alternate node.
- if (S.isAltShuffle() && ReuseShuffleIndices.empty() &&
- TrySplitNode(SmallNodeSize, S))
+ if (S.isAltShuffle() && ReuseShuffleIndices.empty() && TrySplitNode(S))
return;
auto Invalid = ScheduleBundle::invalid();
newTreeEntry(VL, Invalid /*not vectorized*/, S, UserTreeIdx,
``````````
</details>
https://github.com/llvm/llvm-project/pull/134896
More information about the llvm-commits
mailing list