[llvm] [LV] Add support for uniform parameters on vectorized function variants (PR #68879)

Graham Hunter via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 15 03:13:30 PST 2023


================
@@ -7009,39 +7009,60 @@ void LoopVectorizationCostModel::setVectorizedCallDecision(ElementCount VF) {
 
       // Find the cost of vectorizing the call, if we can find a suitable
       // vector variant of the function.
-      InstructionCost MaskCost = 0;
-      VFShape Shape = VFShape::get(*CI, VF, MaskRequired);
-      bool UsesMask = MaskRequired;
-      Function *VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape);
-      // If we want an unmasked vector function but can't find one matching the
-      // VF, maybe we can find vector function that does use a mask and
-      // synthesize an all-true mask.
-      if (!VecFunc && !MaskRequired) {
-        Shape = VFShape::get(*CI, VF, /*HasGlobalPred=*/true);
-        VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape);
-        // If we found one, add in the cost of creating a mask
-        if (VecFunc) {
-          UsesMask = true;
-          MaskCost = TTI.getShuffleCost(
-              TargetTransformInfo::SK_Broadcast,
-              VectorType::get(IntegerType::getInt1Ty(
-                                  VecFunc->getFunctionType()->getContext()),
-                              VF));
-        }
-      }
+      bool UsesMask = false;
+      VFInfo FuncInfo;
+      Function *VecFunc = nullptr;
+      // Search through any available variants for one we can use at this VF.
+      for (VFInfo &Info : VFDatabase::getMappings(*CI)) {
+        // Must match requested VF.
+        if (Info.Shape.VF != VF)
+          continue;
 
-      std::optional<unsigned> MaskPos = std::nullopt;
-      if (VecFunc && UsesMask) {
-        for (const VFInfo &Info : VFDatabase::getMappings(*CI))
-          if (Info.Shape == Shape) {
-            assert(Info.isMasked() && "Vector function info shape mismatch");
-            MaskPos = Info.getParamIndexForOptionalMask().value();
+        // Must take a mask argument if one is required
+        if (MaskRequired && !Info.isMasked())
+          continue;
+
+        // Check that all parameter kinds are supported
+        bool ParamsOk = true;
+        for (VFParameter Param : Info.Shape.Parameters) {
+          switch (Param.ParamKind) {
+          case VFParamKind::Vector:
+            break;
+          case VFParamKind::OMP_Uniform: {
+            Value *ScalarParam = CI->getArgOperand(Param.ParamPos);
+            // Make sure the scalar parameter in the loop is invariant.
+            if (!PSE.getSE()->isLoopInvariant(PSE.getSCEV(ScalarParam),
+                                              TheLoop))
+              ParamsOk = false;
+            break;
+          }
+          case VFParamKind::GlobalPredicate:
+            UsesMask = true;
+            break;
+          default:
+            ParamsOk = false;
             break;
           }
+        }
+
+        if (!ParamsOk)
+          continue;
 
-        assert(MaskPos.has_value() && "Unable to find mask parameter index");
+        // Found a suitable candidate, stop here.
+        VecFunc = CI->getModule()->getFunction(Info.VectorName);
+        FuncInfo = Info;
+        break;
       }
 
+      // Add in the cost of synthesizing a mask if one wasn't required.
+      InstructionCost MaskCost = 0;
+      if (VecFunc && UsesMask && !MaskRequired)
----------------
huntergr-arm wrote:

That's just refactoring of existing code, not new functionality.

https://github.com/llvm/llvm-project/pull/68879


More information about the llvm-commits mailing list