[llvm] [LV] Add support for uniform parameters on vectorized function variants (PR #68879)
Maciej Gabka via llvm-commits
llvm-commits at lists.llvm.org
Wed Nov 15 01:35:21 PST 2023
================
@@ -7009,39 +7009,60 @@ void LoopVectorizationCostModel::setVectorizedCallDecision(ElementCount VF) {
// Find the cost of vectorizing the call, if we can find a suitable
// vector variant of the function.
- InstructionCost MaskCost = 0;
- VFShape Shape = VFShape::get(*CI, VF, MaskRequired);
- bool UsesMask = MaskRequired;
- Function *VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape);
- // If we want an unmasked vector function but can't find one matching the
- // VF, maybe we can find vector function that does use a mask and
- // synthesize an all-true mask.
- if (!VecFunc && !MaskRequired) {
- Shape = VFShape::get(*CI, VF, /*HasGlobalPred=*/true);
- VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape);
- // If we found one, add in the cost of creating a mask
- if (VecFunc) {
- UsesMask = true;
- MaskCost = TTI.getShuffleCost(
- TargetTransformInfo::SK_Broadcast,
- VectorType::get(IntegerType::getInt1Ty(
- VecFunc->getFunctionType()->getContext()),
- VF));
- }
- }
+ bool UsesMask = false;
+ VFInfo FuncInfo;
+ Function *VecFunc = nullptr;
+ // Search through any available variants for one we can use at this VF.
+ for (VFInfo &Info : VFDatabase::getMappings(*CI)) {
+ // Must match requested VF.
+ if (Info.Shape.VF != VF)
+ continue;
- std::optional<unsigned> MaskPos = std::nullopt;
- if (VecFunc && UsesMask) {
- for (const VFInfo &Info : VFDatabase::getMappings(*CI))
- if (Info.Shape == Shape) {
- assert(Info.isMasked() && "Vector function info shape mismatch");
- MaskPos = Info.getParamIndexForOptionalMask().value();
+ // Must take a mask argument if one is required
+ if (MaskRequired && !Info.isMasked())
+ continue;
+
+ // Check that all parameter kinds are supported
+ bool ParamsOk = true;
+ for (VFParameter Param : Info.Shape.Parameters) {
+ switch (Param.ParamKind) {
+ case VFParamKind::Vector:
+ break;
+ case VFParamKind::OMP_Uniform: {
+ Value *ScalarParam = CI->getArgOperand(Param.ParamPos);
+ // Make sure the scalar parameter in the loop is invariant.
+ if (!PSE.getSE()->isLoopInvariant(PSE.getSCEV(ScalarParam),
+ TheLoop))
+ ParamsOk = false;
+ break;
+ }
----------------
mgabka wrote:
To me it looks like the piece of code handling `case VFParamKind::OMP_Uniform:` is the actual new functionality, so you could split this patch in two: an NFC refactoring change, followed by a patch that adds the new functionality together with tests demonstrating that it works. What do you think?
https://github.com/llvm/llvm-project/pull/68879
More information about the llvm-commits mailing list