[llvm] 4d64a2b - [LV] Refactor vector function variant selection to prepare for uniform args (#68879)

via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 20 05:30:08 PST 2023


Author: Graham Hunter
Date: 2023-11-20T13:30:03Z
New Revision: 4d64a2bcd31818aa0aed0ce9e6b64898a6f0eb55

URL: https://github.com/llvm/llvm-project/commit/4d64a2bcd31818aa0aed0ce9e6b64898a6f0eb55
DIFF: https://github.com/llvm/llvm-project/commit/4d64a2bcd31818aa0aed0ce9e6b64898a6f0eb55.diff

LOG: [LV] Refactor vector function variant selection to prepare for uniform args (#68879)

Parameters marked as uniform take a scalar value, on the assumption that the
value is invariant in the scalar loop. To support this, we need to stop
requesting a vector function variant using a default shape that assumes every
argument becomes a vector argument, and instead consider all available
variants together with their parameter types.

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
    llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 834fa9ea8f57fea..d52afc5508be947 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7012,39 +7012,52 @@ void LoopVectorizationCostModel::setVectorizedCallDecision(ElementCount VF) {
 
       // Find the cost of vectorizing the call, if we can find a suitable
       // vector variant of the function.
-      InstructionCost MaskCost = 0;
-      VFShape Shape = VFShape::get(*CI, VF, MaskRequired);
-      bool UsesMask = MaskRequired;
-      Function *VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape);
-      // If we want an unmasked vector function but can't find one matching the
-      // VF, maybe we can find vector function that does use a mask and
-      // synthesize an all-true mask.
-      if (!VecFunc && !MaskRequired) {
-        Shape = VFShape::get(*CI, VF, /*HasGlobalPred=*/true);
-        VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape);
-        // If we found one, add in the cost of creating a mask
-        if (VecFunc) {
-          UsesMask = true;
-          MaskCost = TTI.getShuffleCost(
-              TargetTransformInfo::SK_Broadcast,
-              VectorType::get(IntegerType::getInt1Ty(
-                                  VecFunc->getFunctionType()->getContext()),
-                              VF));
-        }
-      }
+      bool UsesMask = false;
+      VFInfo FuncInfo;
+      Function *VecFunc = nullptr;
+      // Search through any available variants for one we can use at this VF.
+      for (VFInfo &Info : VFDatabase::getMappings(*CI)) {
+        // Must match requested VF.
+        if (Info.Shape.VF != VF)
+          continue;
 
-      std::optional<unsigned> MaskPos = std::nullopt;
-      if (VecFunc && UsesMask) {
-        for (const VFInfo &Info : VFDatabase::getMappings(*CI))
-          if (Info.Shape == Shape) {
-            assert(Info.isMasked() && "Vector function info shape mismatch");
-            MaskPos = Info.getParamIndexForOptionalMask().value();
+        // Must take a mask argument if one is required
+        if (MaskRequired && !Info.isMasked())
+          continue;
+
+        // Check that all parameter kinds are supported
+        bool ParamsOk = true;
+        for (VFParameter Param : Info.Shape.Parameters) {
+          switch (Param.ParamKind) {
+          case VFParamKind::Vector:
+            break;
+          case VFParamKind::GlobalPredicate:
+            UsesMask = true;
+            break;
+          default:
+            ParamsOk = false;
             break;
           }
+        }
+
+        if (!ParamsOk)
+          continue;
 
-        assert(MaskPos.has_value() && "Unable to find mask parameter index");
+        // Found a suitable candidate, stop here.
+        VecFunc = CI->getModule()->getFunction(Info.VectorName);
+        FuncInfo = Info;
+        break;
       }
 
+      // Add in the cost of synthesizing a mask if one wasn't required.
+      InstructionCost MaskCost = 0;
+      if (VecFunc && UsesMask && !MaskRequired)
+        MaskCost = TTI.getShuffleCost(
+            TargetTransformInfo::SK_Broadcast,
+            VectorType::get(IntegerType::getInt1Ty(
+                                VecFunc->getFunctionType()->getContext()),
+                            VF));
+
       if (TLI && VecFunc && !CI->isNoBuiltin())
         VectorCost =
             TTI.getCallInstrCost(nullptr, RetTy, Tys, CostKind) + MaskCost;
@@ -7068,7 +7081,8 @@ void LoopVectorizationCostModel::setVectorizedCallDecision(ElementCount VF) {
         Decision = CM_IntrinsicCall;
       }
 
-      setCallWideningDecision(CI, VF, Decision, VecFunc, IID, MaskPos, Cost);
+      setCallWideningDecision(CI, VF, Decision, VecFunc, IID,
+                              FuncInfo.getParamIndexForOptionalMask(), Cost);
     }
   }
 }

diff  --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 6b3218dca1b18b0..2c3fabdbc85619d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -504,6 +504,9 @@ void VPWidenCallRecipe::execute(VPTransformState &State) {
          "DbgInfoIntrinsic should have been dropped during VPlan construction");
   State.setDebugLocFrom(CI.getDebugLoc());
 
+  FunctionType *VFTy = nullptr;
+  if (Variant)
+    VFTy = Variant->getFunctionType();
   for (unsigned Part = 0; Part < State.UF; ++Part) {
     SmallVector<Type *, 2> TysForDecl;
     // Add return type if intrinsic is overloaded on it.
@@ -515,12 +518,15 @@ void VPWidenCallRecipe::execute(VPTransformState &State) {
     for (const auto &I : enumerate(operands())) {
       // Some intrinsics have a scalar argument - don't replace it with a
       // vector.
+      // Some vectorized function variants may also take a scalar argument,
+      // e.g. linear parameters for pointers.
       Value *Arg;
-      if (VectorIntrinsicID == Intrinsic::not_intrinsic ||
-          !isVectorIntrinsicWithScalarOpAtArg(VectorIntrinsicID, I.index()))
-        Arg = State.get(I.value(), Part);
-      else
+      if ((VFTy && !VFTy->getParamType(I.index())->isVectorTy()) ||
+          (VectorIntrinsicID != Intrinsic::not_intrinsic &&
+           isVectorIntrinsicWithScalarOpAtArg(VectorIntrinsicID, I.index())))
         Arg = State.get(I.value(), VPIteration(0, 0));
+      else
+        Arg = State.get(I.value(), Part);
       if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, I.index()))
         TysForDecl.push_back(Arg->getType());
       Args.push_back(Arg);


        


More information about the llvm-commits mailing list