[llvm] [SLP] Check for extracts, being replaced by original scalars, for user nodes (PR #149572)
Gaëtan Bossu via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 28 04:45:48 PDT 2025
================
@@ -9149,6 +9163,81 @@ getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy,
return {IntrinsicCost, LibCost};
}
+bool BoUpSLP::isProfitableToVectorizeWithNonVecUsers(
+ const InstructionsState &S, const EdgeInfo &UserTreeIdx,
+ ArrayRef<Value *> Scalars, ArrayRef<int> ScalarsMask) {
+ assert(S && "Expected valid instructions state.");
+ // Loads, extracts and geps are immediately scalarizable, so no need to check.
+ if (S.getOpcode() == Instruction::Load ||
+ S.getOpcode() == Instruction::ExtractElement ||
+ S.getOpcode() == Instruction::GetElementPtr)
+ return true;
+ // Check only vectorized users, others scalarized (potentially, at least)
+ // already.
+ if (!UserTreeIdx.UserTE || UserTreeIdx.UserTE->isGather() ||
+ UserTreeIdx.UserTE->State == TreeEntry::SplitVectorize)
+ return true;
+ // PHI nodes may have cyclic deps, so cannot check here.
+ if (UserTreeIdx.UserTE->getOpcode() == Instruction::PHI)
+ return true;
+ // Do not check root reduction nodes, they do not have non-vectorized users.
+ if (UserIgnoreList && UserTreeIdx.UserTE->Idx == 0)
+ return true;
+ constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+ ArrayRef<Value *> VL = UserTreeIdx.UserTE->Scalars;
+ Type *UserScalarTy = getValueType(VL.front());
+ if (!isValidElementType(UserScalarTy))
+ return true;
+ Type *ScalarTy = getValueType(Scalars.front());
+ if (!isValidElementType(ScalarTy))
+ return true;
+ // Ignore subvectors extracts.
+ if (UserScalarTy->isVectorTy())
+ return true;
+ auto *UserVecTy =
+ getWidenedType(UserScalarTy, UserTreeIdx.UserTE->getVectorFactor());
+ APInt DemandedElts = APInt::getZero(UserTreeIdx.UserTE->getVectorFactor());
+ // Check the external uses and check, if vector node + extracts is not
+ // profitable for the vectorization.
+ InstructionCost UserScalarsCost = 0;
+ for (Value *V : VL) {
+ auto *I = dyn_cast<Instruction>(V);
+ if (!I)
+ continue;
+ if (areAllUsersVectorized(I, UserIgnoreList))
+ continue;
+ DemandedElts.setBit(UserTreeIdx.UserTE->findLaneForValue(V));
+ UserScalarsCost += TTI->getInstructionCost(I, CostKind);
+ }
+ // No non-vectorized users - success.
+ if (DemandedElts.isZero())
+ return true;
+ // If extracts are cheaper than the original scalars - success.
+ InstructionCost ExtractCost =
+ ::getScalarizationOverhead(*TTI, UserScalarTy, UserVecTy, DemandedElts,
+ /*Insert=*/false, /*Extract=*/true, CostKind);
+ if (ExtractCost <= UserScalarsCost)
+ return true;
----------------
gbossu wrote:
The lambda might be small, but it is still one more closure the reader has to understand before they can understand the whole function.
Something like the snippet below hides implementation details that aren't necessary for understanding `isProfitableToVectorizeWithNonVecUsers`:
```
ScalarUsersCosts UsersCosts = UserTreeIdx.UserTE->getScalarUsersCost(...);
return UsersCosts.ExtractCost < UsersCost.ScalarizationCost;
```
That is clearer than exposing the whole implementation inline, especially through a "capture-all" lambda like this one:
```
auto AreExtractsCheaperThanScalars = [&]() {
// If extracts are cheaper than the original scalars - success.
InstructionCost ExtractCost = ::getScalarizationOverhead(
*TTI, UserScalarTy, UserVecTy, DemandedElts,
/*Insert=*/false, /*Extract=*/true, CostKind);
if (ExtractCost <= UserScalarsCost)
return true;
SmallPtrSet<Value *, 4> CheckedExtracts;
InstructionCost NodeCost =
getEntryCost(UserTreeIdx.UserTE, {}, CheckedExtracts);
// The node is profitable for vectorization - success.
if (ExtractCost <= NodeCost)
return true;
auto *VecTy = getWidenedType(ScalarTy, VL.size());
InstructionCost ScalarsCost = ::getScalarizationOverhead(
*TTI, ScalarTy, VecTy, APInt::getAllOnes(VL.size()),
/*Insert=*/true, /*Extract=*/false, CostKind);
if (!Mask.empty())
ScalarsCost +=
getShuffleCost(*TTI, TTI::SK_PermuteSingleSrc, VecTy, Mask, CostKind);
return ExtractCost < UserScalarsCost + ScalarsCost;
};
// User extracts are cheaper than user scalars + immediate scalars - success.
return AreExtractsCheaperThanScalars();
```
I think spending a bit of time on this refactoring is worth the effort, as it makes the code base more accessible.
https://github.com/llvm/llvm-project/pull/149572
More information about the llvm-commits mailing list