[llvm] [SVE] Don't require lookup when demangling vector function mappings (PR #72260)

Graham Hunter via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 22 03:22:47 PST 2023

@@ -273,49 +275,88 @@ ParseRet tryParseAlign(StringRef &ParseString, Align &Alignment) {
   return ParseRet::None;
-#ifndef NDEBUG
-// Verify the assumtion that all vectors in the signature of a vector
-// function have the same number of elements.
-bool verifyAllVectorsHaveSameWidth(FunctionType *Signature) {
-  SmallVector<VectorType *, 2> VecTys;
-  if (auto *RetTy = dyn_cast<VectorType>(Signature->getReturnType()))
-    VecTys.push_back(RetTy);
-  for (auto *Ty : Signature->params())
-    if (auto *VTy = dyn_cast<VectorType>(Ty))
-      VecTys.push_back(VTy);
-  if (VecTys.size() <= 1)
-    return true;
-  assert(VecTys.size() > 1 && "Invalid number of elements.");
-  const ElementCount EC = VecTys[0]->getElementCount();
-  return llvm::all_of(llvm::drop_begin(VecTys), [&EC](VectorType *VTy) {
-    return (EC == VTy->getElementCount());
-  });
+// Given a type, return the size in bits if it is a supported element type
+// for vectorized function calls, or nullopt if not.
+std::optional<unsigned> getSizeFromScalarType(Type *Ty) {
+  // The scalar function should only take scalar arguments.
+  if (!Ty->isIntegerTy() && !Ty->isFloatingPointTy() && !Ty->isPointerTy())
+    return std::nullopt;
+  unsigned SizeInBits = Ty->getPrimitiveSizeInBits();
+  switch (SizeInBits) {
+  // Legal power-of-two scalars are supported.
+  case 64:
+  case 32:
+  case 16:
+  case 8:
+    return SizeInBits;
+  case 0:
+    // We're assuming a 64b pointer size here for SVE; if another non-64b
+    // target adds support for scalable vectors, we may need DataLayout to
+    // determine the size.
+    if (Ty->isPointerTy())
+      return 64;
+    break;
+  default:
+    break;
+  }
+  return std::nullopt;
-#endif // NDEBUG
-// Extract the VectorizationFactor from a given function signature,
-// under the assumtion that all vectors have the same number of
-// elements, i.e. same ElementCount.Min.
-ElementCount getECFromSignature(FunctionType *Signature) {
-  assert(verifyAllVectorsHaveSameWidth(Signature) &&
-         "Invalid vector signature.");
-  if (auto *RetTy = dyn_cast<VectorType>(Signature->getReturnType()))
-    return RetTy->getElementCount();
-  for (auto *Ty : Signature->params())
-    if (auto *VTy = dyn_cast<VectorType>(Ty))
-      return VTy->getElementCount();
-  return ElementCount::getFixed(/*Min=*/1);
+// Extract the VectorizationFactor from a given function signature, based
+// on the widest scalar element types that will become vector parameters.
+getScalableECFromSignature(FunctionType *Signature, const VFISAKind ISA,
+                           const SmallVectorImpl<VFParameter> &Params) {
+  // Look up the minimum known register size in order to calculate minimum VF.
+  // Only AArch64 SVE is supported at present.
+  unsigned MinRegSizeInBits;
+  switch (ISA) {
+  case VFISAKind::SVE:
+    MinRegSizeInBits = 128;
+    break;
+  default:
+    return std::nullopt;
+  }
+  unsigned WidestTypeInBits = 0;
+  for (auto &Param : Params) {
+    // Check any parameters that will be widened to vectors. Uniform or linear
+    // parameters may be misleading for determining the VF of a given function.
+    if (Param.ParamKind == VFParamKind::Vector) {
+      // If the scalar function doesn't actually have a corresponding argument,
+      // reject the mapping.
+      if (Param.ParamPos + 1 > Signature->getNumParams())
huntergr-arm wrote:

I tried your suggestion to move the check into the parameter parsing code, and unfortunately discovered that doing so breaks many of the unit tests. VectorFunctionABITest.cpp has a method called invokeParser, which takes a mangled name, a scalar function name, and a scalar function type. However, it's often invoked with the mangled name plus the vector variant name and no type, so the type falls back to the default 'void()'. With that type the scalar function has no arguments, so the check rejects any mapping that has vector parameters.

So I won't make that change in this PR; I'll raise a separate follow-up PR later to fix the tests and move the check.


More information about the llvm-commits mailing list