[llvm] [SVE] Don't require lookup when demangling vector function mappings (PR #72260)

Graham Hunter via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 23 03:52:52 PST 2023


https://github.com/huntergr-arm updated https://github.com/llvm/llvm-project/pull/72260

>From 4ce19105db7f840286ced7ca4af98a37a8e5f8b2 Mon Sep 17 00:00:00 2001
From: Graham Hunter <graham.hunter at arm.com>
Date: Mon, 13 Nov 2023 16:03:47 +0000
Subject: [PATCH 1/3] [SVE] Don't require lookup when demangling vector
 function mappings

We can determine the VF from a combination of the mangled name (which
indicates which arguments take vectors) and the element sizes of the
arguments of the scalar function for which the mapping has been
established.

The assert that fired when demangling a mapping failed has been removed;
an invalid mapping is now simply not added. This prevents the crash seen
in https://github.com/llvm/llvm-project/issues/71892

This patch also stops using _LLVM_ as the ISA in scalable vector tests,
since there are no defined rules for how its vector arguments should be
handled (e.g. packed vs. unpacked representation).
---
 llvm/include/llvm/Analysis/VectorUtils.h      |   4 +-
 llvm/lib/Analysis/VFABIDemangling.cpp         | 177 +++++++++++-------
 llvm/lib/Analysis/VectorUtils.cpp             |  15 +-
 llvm/lib/Transforms/Utils/ModuleUtils.cpp     |   2 +-
 .../LoopVectorize/AArch64/masked-call.ll      |   8 +-
 .../LoopVectorize/AArch64/scalable-call.ll    |   6 +-
 .../LoopVectorize/AArch64/sve-vfabi.ll        | 110 +++++++++++
 .../AArch64/wider-VF-for-callinst.ll          |   2 +-
 .../vfabi-demangler-fuzzer.cpp                |  24 ++-
 .../Analysis/VectorFunctionABITest.cpp        |  55 +++---
 10 files changed, 279 insertions(+), 124 deletions(-)
 create mode 100644 llvm/test/Transforms/LoopVectorize/AArch64/sve-vfabi.ll

diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index 7947648aaddd4ea..99b31fd0ca4b659 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -180,7 +180,7 @@ static constexpr char const *_LLVM_Scalarize_ = "_LLVM_Scalarize_";
 /// Vectorization Factor of scalable vector functions from their
 /// respective IR declarations.
 std::optional<VFInfo> tryDemangleForVFABI(StringRef MangledName,
-                                          const Module &M);
+                                          const CallInst &CI);
 
 /// Retrieve the `VFParamKind` from a string token.
 VFParamKind getVFParamKindFromString(const StringRef Token);
@@ -227,7 +227,7 @@ class VFDatabase {
       return;
     for (const auto &MangledName : ListOfStrings) {
       const std::optional<VFInfo> Shape =
-          VFABI::tryDemangleForVFABI(MangledName, *(CI.getModule()));
+          VFABI::tryDemangleForVFABI(MangledName, CI);
       // A match is found via scalar and vector names, and also by
       // ensuring that the variant described in the attribute has a
       // corresponding definition or declaration of the vector
diff --git a/llvm/lib/Analysis/VFABIDemangling.cpp b/llvm/lib/Analysis/VFABIDemangling.cpp
index 1e2d1db4e44b9d6..75e0c2fbb02159a 100644
--- a/llvm/lib/Analysis/VFABIDemangling.cpp
+++ b/llvm/lib/Analysis/VFABIDemangling.cpp
@@ -66,16 +66,18 @@ ParseRet tryParseMask(StringRef &MangledName, bool &IsMasked) {
 /// vector length. On success, the `<vlen>` token is removed from
 /// the input string `ParseString`.
 ///
-ParseRet tryParseVLEN(StringRef &ParseString, unsigned &VF, bool &IsScalable) {
+ParseRet tryParseVLEN(StringRef &ParseString,
+                      std::optional<unsigned> &ParsedVF) {
   if (ParseString.consume_front("x")) {
-    // Set VF to 0, to be later adjusted to a value grater than zero
-    // by looking at the signature of the vector function with
-    // `getECFromSignature`.
-    VF = 0;
-    IsScalable = true;
+    // We can't determine the VF of a scalable vector by looking at the vlen
+    // string (just 'x'), so say we successfully parsed it but return a nullopt
+    // so that the caller knows it must look at the arguments to determine
+    // the minimum VF based on types.
+    ParsedVF = std::nullopt;
     return ParseRet::OK;
   }
 
+  unsigned VF = 0;
   if (ParseString.consumeInteger(10, VF))
     return ParseRet::Error;
 
@@ -83,7 +85,7 @@ ParseRet tryParseVLEN(StringRef &ParseString, unsigned &VF, bool &IsScalable) {
   if (VF == 0)
     return ParseRet::Error;
 
-  IsScalable = false;
+  ParsedVF = VF;
   return ParseRet::OK;
 }
 
@@ -273,49 +275,88 @@ ParseRet tryParseAlign(StringRef &ParseString, Align &Alignment) {
   return ParseRet::None;
 }
 
-#ifndef NDEBUG
-// Verify the assumtion that all vectors in the signature of a vector
-// function have the same number of elements.
-bool verifyAllVectorsHaveSameWidth(FunctionType *Signature) {
-  SmallVector<VectorType *, 2> VecTys;
-  if (auto *RetTy = dyn_cast<VectorType>(Signature->getReturnType()))
-    VecTys.push_back(RetTy);
-  for (auto *Ty : Signature->params())
-    if (auto *VTy = dyn_cast<VectorType>(Ty))
-      VecTys.push_back(VTy);
-
-  if (VecTys.size() <= 1)
-    return true;
-
-  assert(VecTys.size() > 1 && "Invalid number of elements.");
-  const ElementCount EC = VecTys[0]->getElementCount();
-  return llvm::all_of(llvm::drop_begin(VecTys), [&EC](VectorType *VTy) {
-    return (EC == VTy->getElementCount());
-  });
+// Given a type, return the size in bits if it is a supported element type
+// for vectorized function calls, or nullopt if not.
+std::optional<unsigned> getSizeFromScalarType(Type *Ty) {
+  // The scalar function should only take scalar arguments.
+  if (!Ty->isIntegerTy() && !Ty->isFloatingPointTy() && !Ty->isPointerTy())
+    return std::nullopt;
+
+  unsigned SizeInBits = Ty->getPrimitiveSizeInBits();
+  switch (SizeInBits) {
+  // Legal power-of-two scalars are supported.
+  case 64:
+  case 32:
+  case 16:
+  case 8:
+    return SizeInBits;
+  case 0:
+    // We're assuming a 64b pointer size here for SVE; if another non-64b
+    // target adds support for scalable vectors, we may need DataLayout to
+    // determine the size.
+    if (Ty->isPointerTy())
+      return 64;
+    break;
+  default:
+    break;
+  }
+
+  return std::nullopt;
 }
-#endif // NDEBUG
-
-// Extract the VectorizationFactor from a given function signature,
-// under the assumtion that all vectors have the same number of
-// elements, i.e. same ElementCount.Min.
-ElementCount getECFromSignature(FunctionType *Signature) {
-  assert(verifyAllVectorsHaveSameWidth(Signature) &&
-         "Invalid vector signature.");
-
-  if (auto *RetTy = dyn_cast<VectorType>(Signature->getReturnType()))
-    return RetTy->getElementCount();
-  for (auto *Ty : Signature->params())
-    if (auto *VTy = dyn_cast<VectorType>(Ty))
-      return VTy->getElementCount();
-
-  return ElementCount::getFixed(/*Min=*/1);
+
+// Extract the VectorizationFactor from a given function signature, based
+// on the widest scalar element types that will become vector parameters.
+std::optional<ElementCount>
+getScalableECFromSignature(FunctionType *Signature, const VFISAKind ISA,
+                           const SmallVectorImpl<VFParameter> &Params) {
+  // Look up the minimum known register size in order to calculate minimum VF.
+  // Only AArch64 SVE is supported at present.
+  unsigned MinRegSizeInBits;
+  switch (ISA) {
+  case VFISAKind::SVE:
+    MinRegSizeInBits = 128;
+    break;
+  default:
+    return std::nullopt;
+  }
+
+  unsigned WidestTypeInBits = 0;
+  for (auto &Param : Params) {
+    // Check any parameters that will be widened to vectors. Uniform or linear
+    // parameters may be misleading for determining the VF of a given function.
+    if (Param.ParamKind == VFParamKind::Vector) {
+      // If the scalar function doesn't actually have a corresponding argument,
+      // reject the mapping.
+      if (Param.ParamPos + 1 > Signature->getNumParams())
+        return std::nullopt;
+      Type *PTy = Signature->getParamType(Param.ParamPos);
+
+      std::optional<unsigned> SizeInBits = getSizeFromScalarType(PTy);
+      if (SizeInBits)
+        WidestTypeInBits = std::max(WidestTypeInBits, *SizeInBits);
+    }
+  }
+
+  // Also check the return type.
+  std::optional<unsigned> ReturnSizeInBits =
+      getSizeFromScalarType(Signature->getReturnType());
+  if (ReturnSizeInBits)
+    WidestTypeInBits = std::max(WidestTypeInBits, *ReturnSizeInBits);
+
+  // SVE bases the VF on the widest element types present, and vector arguments
+  // containing types of that width are always considered to be packed.
+  // Arguments with narrower elements are considered to be unpacked.
+  if (WidestTypeInBits)
+    return ElementCount::getScalable(MinRegSizeInBits / WidestTypeInBits);
+
+  return std::nullopt;
 }
 } // namespace
 
 // Format of the ABI name:
 // _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)]
 std::optional<VFInfo> VFABI::tryDemangleForVFABI(StringRef MangledName,
-                                                 const Module &M) {
+                                                 const CallInst &CI) {
   const StringRef OriginalName = MangledName;
   // Assume there is no custom name <redirection>, and therefore the
   // vector name consists of
@@ -338,9 +379,8 @@ std::optional<VFInfo> VFABI::tryDemangleForVFABI(StringRef MangledName,
     return std::nullopt;
 
   // Parse the variable size, starting from <vlen>.
-  unsigned VF;
-  bool IsScalable;
-  if (tryParseVLEN(MangledName, VF, IsScalable) != ParseRet::OK)
+  std::optional<unsigned> ParsedVF;
+  if (tryParseVLEN(MangledName, ParsedVF) != ParseRet::OK)
     return std::nullopt;
 
   // Parse the <parameters>.
@@ -374,6 +414,24 @@ std::optional<VFInfo> VFABI::tryDemangleForVFABI(StringRef MangledName,
   if (Parameters.empty())
     return std::nullopt;
 
+  // Figure out the number of lanes in vectors for this function variant. This
+  // is easy for fixed length, as the vlen encoding just gives us the value
+  // directly. However, if the vlen mangling indicated that this function
+  // variant expects scalable vectors, then we need to figure out the minimum
+  // based on the widest scalar types in vector arguments.
+  std::optional<ElementCount> EC;
+  if (ParsedVF) {
+    // Fixed length VF
+    EC = ElementCount::getFixed(*ParsedVF);
+  } else {
+    // Scalable VF, need to work out the minimum from the element types
+    // in the scalar function arguments.
+    EC = getScalableECFromSignature(CI.getFunctionType(), ISA, Parameters);
+
+    if (!EC)
+      return std::nullopt;
+  }
+
   // Check for the <scalarname> and the optional <redirection>, which
   // are separated from the prefix with "_"
   if (!MangledName.consume_front("_"))
@@ -426,32 +484,7 @@ std::optional<VFInfo> VFABI::tryDemangleForVFABI(StringRef MangledName,
     assert(Parameters.back().ParamKind == VFParamKind::GlobalPredicate &&
            "The global predicate must be the last parameter");
 
-  // Adjust the VF for scalable signatures. The EC.Min is not encoded
-  // in the name of the function, but it is encoded in the IR
-  // signature of the function. We need to extract this information
-  // because it is needed by the loop vectorizer, which reasons in
-  // terms of VectorizationFactor or ElementCount. In particular, we
-  // need to make sure that the VF field of the VFShape class is never
-  // set to 0.
-  if (IsScalable) {
-    const Function *F = M.getFunction(VectorName);
-    // The declaration of the function must be present in the module
-    // to be able to retrieve its signature.
-    if (!F)
-      return std::nullopt;
-    const ElementCount EC = getECFromSignature(F->getFunctionType());
-    VF = EC.getKnownMinValue();
-  }
-
-  // 1. We don't accept a zero lanes vectorization factor.
-  // 2. We don't accept the demangling if the vector function is not
-  // present in the module.
-  if (VF == 0)
-    return std::nullopt;
-  if (!M.getFunction(VectorName))
-    return std::nullopt;
-
-  const VFShape Shape({ElementCount::get(VF, IsScalable), Parameters});
+  const VFShape Shape({*EC, Parameters});
   return VFInfo({Shape, std::string(ScalarName), std::string(VectorName), ISA});
 }
 
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index 4f5db8b7aaf746f..9103bf3c8b82467 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -1466,15 +1466,12 @@ void VFABI::getVectorVariantNames(
   S.split(ListAttr, ",");
 
   for (const auto &S : SetVector<StringRef>(ListAttr.begin(), ListAttr.end())) {
-#ifndef NDEBUG
-    LLVM_DEBUG(dbgs() << "VFABI: adding mapping '" << S << "'\n");
-    std::optional<VFInfo> Info =
-        VFABI::tryDemangleForVFABI(S, *(CI.getModule()));
-    assert(Info && "Invalid name for a VFABI variant.");
-    assert(CI.getModule()->getFunction(Info->VectorName) &&
-           "Vector function is missing.");
-#endif
-    VariantMappings.push_back(std::string(S));
+    std::optional<VFInfo> Info = VFABI::tryDemangleForVFABI(S, CI);
+    if (Info && CI.getModule()->getFunction(Info->VectorName)) {
+      LLVM_DEBUG(dbgs() << "VFABI: Adding mapping '" << S << "'\n");
+      VariantMappings.push_back(std::string(S));
+    } else
+      LLVM_DEBUG(dbgs() << "VFABI: Invalid mapping '" << S << "'\n");
   }
 }
 
diff --git a/llvm/lib/Transforms/Utils/ModuleUtils.cpp b/llvm/lib/Transforms/Utils/ModuleUtils.cpp
index a3737d428a00b5f..b75e60a3553cf60 100644
--- a/llvm/lib/Transforms/Utils/ModuleUtils.cpp
+++ b/llvm/lib/Transforms/Utils/ModuleUtils.cpp
@@ -346,7 +346,7 @@ void VFABI::setVectorVariantNames(CallInst *CI,
 #ifndef NDEBUG
   for (const std::string &VariantMapping : VariantMappings) {
     LLVM_DEBUG(dbgs() << "VFABI: adding mapping '" << VariantMapping << "'\n");
-    std::optional<VFInfo> VI = VFABI::tryDemangleForVFABI(VariantMapping, *M);
+    std::optional<VFInfo> VI = VFABI::tryDemangleForVFABI(VariantMapping, *CI);
     assert(VI && "Cannot add an invalid VFABI name.");
     assert(M->getNamedValue(VI->VectorName) &&
            "Cannot add variant to attribute: "
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/masked-call.ll b/llvm/test/Transforms/LoopVectorize/AArch64/masked-call.ll
index 19a970fb0716f0f..28962dfba89248a 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/masked-call.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/masked-call.ll
@@ -744,8 +744,8 @@ declare <vscale x 2 x i64> @foo_uniform(i64, <vscale x 2 x i1>)
 declare <vscale x 2 x i64> @foo_vector(<vscale x 2 x i64>, <vscale x 2 x i1>)
 declare <vscale x 2 x i64> @foo_vector_nomask(<vscale x 2 x i64>)
 
-attributes #0 = { nounwind "vector-function-abi-variant"="_ZGV_LLVM_Mxv_foo(foo_vector),_ZGV_LLVM_Mxu_foo(foo_uniform)" }
-attributes #1 = { nounwind "vector-function-abi-variant"="_ZGV_LLVM_Mxv_foo(foo_vector)" }
-attributes #2 = { nounwind "vector-function-abi-variant"="_ZGV_LLVM_Nxv_foo(foo_vector_nomask)" }
-attributes #3 = { nounwind "vector-function-abi-variant"="_ZGV_LLVM_Nxv_foo(foo_vector_nomask),_ZGV_LLVM_Mxv_foo(foo_vector)" }
+attributes #0 = { nounwind "vector-function-abi-variant"="_ZGVsMxv_foo(foo_vector),_ZGVsMxu_foo(foo_uniform)" }
+attributes #1 = { nounwind "vector-function-abi-variant"="_ZGVsMxv_foo(foo_vector)" }
+attributes #2 = { nounwind "vector-function-abi-variant"="_ZGVsNxv_foo(foo_vector_nomask)" }
+attributes #3 = { nounwind "vector-function-abi-variant"="_ZGVsNxv_foo(foo_vector_nomask),_ZGVsMxv_foo(foo_vector)" }
 attributes #4 = { "target-features"="+sve" vscale_range(2,16) "no-trapping-math"="false" }
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/scalable-call.ll b/llvm/test/Transforms/LoopVectorize/AArch64/scalable-call.ll
index 69c6f84aac53dd6..84f67310021bf1d 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/scalable-call.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/scalable-call.ll
@@ -224,9 +224,9 @@ declare <vscale x 2 x i64> @bar_vec(<vscale x 2 x ptr>)
 declare <vscale x 2 x double> @sin_vec_nxv2f64(<vscale x 2 x double>)
 declare <2 x double> @sin_vec_v2f64(<2 x double>)
 
-attributes #0 = { "vector-function-abi-variant"="_ZGV_LLVM_Nxv_foo(foo_vec)" }
-attributes #1 = { "vector-function-abi-variant"="_ZGV_LLVM_Nxv_bar(bar_vec)" }
-attributes #2 = { "vector-function-abi-variant"="_ZGV_LLVM_Nxv_llvm.sin.f64(sin_vec_nxv2f64)" }
+attributes #0 = { "vector-function-abi-variant"="_ZGVsNxv_foo(foo_vec)" }
+attributes #1 = { "vector-function-abi-variant"="_ZGVsNxv_bar(bar_vec)" }
+attributes #2 = { "vector-function-abi-variant"="_ZGVsNxv_llvm.sin.f64(sin_vec_nxv2f64)" }
 attributes #3 = { "vector-function-abi-variant"="_ZGV_LLVM_N2v_llvm.sin.f64(sin_vec_v2f64)" }
 
 !1 = distinct !{!1, !2, !3}
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/sve-vfabi.ll b/llvm/test/Transforms/LoopVectorize/AArch64/sve-vfabi.ll
new file mode 100644
index 000000000000000..31cf2c4e2db4cdf
--- /dev/null
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/sve-vfabi.ll
@@ -0,0 +1,110 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
+; RUN: opt < %s -passes=loop-vectorize,simplifycfg,instcombine -force-vector-interleave=1 -prefer-predicate-over-epilogue=predicate-dont-vectorize -S | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define void @test_big_little_params(ptr readonly %a, ptr readonly %b, ptr noalias %c) #0 {
+; CHECK-LABEL: define void @test_big_little_params
+; CHECK-SAME: (ptr readonly [[A:%.*]], ptr readonly [[B:%.*]], ptr noalias [[C:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_ENTRY:%.*]] = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 0, i64 1025)
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = phi <vscale x 4 x i1> [ [[ACTIVE_LANE_MASK_ENTRY]], [[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr i32, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <vscale x 4 x i32> @llvm.masked.load.nxv4i32.p0(ptr [[TMP0]], i32 4, <vscale x 4 x i1> [[ACTIVE_LANE_MASK]], <vscale x 4 x i32> poison)
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <vscale x 4 x i8> @llvm.masked.load.nxv4i8.p0(ptr [[TMP1]], i32 1, <vscale x 4 x i1> [[ACTIVE_LANE_MASK]], <vscale x 4 x i8> poison)
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 4 x i32> @foo_vector(<vscale x 4 x i32> [[WIDE_MASKED_LOAD]], <vscale x 4 x i8> [[WIDE_MASKED_LOAD1]], <vscale x 4 x i1> [[ACTIVE_LANE_MASK]])
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i32, ptr [[C]], i64 [[INDEX]]
+; CHECK-NEXT:    call void @llvm.masked.store.nxv4i32.p0(<vscale x 4 x i32> [[TMP2]], ptr [[TMP3]], i32 4, <vscale x 4 x i1> [[ACTIVE_LANE_MASK]])
+; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP5:%.*]] = shl nuw nsw i64 [[TMP4]], 2
+; CHECK-NEXT:    [[INDEX_NEXT]] = add i64 [[INDEX]], [[TMP5]]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_NEXT]] = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 [[INDEX_NEXT]], i64 1025)
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <vscale x 4 x i1> [[ACTIVE_LANE_MASK_NEXT]], i64 0
+; CHECK-NEXT:    br i1 [[TMP6]], label [[VECTOR_BODY]], label [[EXIT:%.*]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %for.body
+
+for.body:
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %gep_s = getelementptr i32, ptr %a, i64 %iv
+  %load_s = load i32, ptr %gep_s
+  %gep_b = getelementptr i8, ptr %b, i64 %iv
+  %load_b = load i8, ptr %gep_b
+  %call = call i32 @foo_big_little(i32 %load_s, i8 %load_b) #1
+  %arrayidx = getelementptr inbounds i32, ptr %c, i64 %iv
+  store i32 %call, ptr %arrayidx
+  %iv.next = add nuw nsw i64 %iv, 1
+  %exitcond = icmp eq i64 %iv.next, 1025
+  br i1 %exitcond, label %exit, label %for.body
+
+exit:
+  ret void
+}
+
+define void @test_little_big_params(ptr readonly %a, ptr readonly %b, ptr noalias %c) #0 {
+; CHECK-LABEL: define void @test_little_big_params
+; CHECK-SAME: (ptr readonly [[A:%.*]], ptr readonly [[B:%.*]], ptr noalias [[C:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_ENTRY:%.*]] = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i64(i64 0, i64 1025)
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = phi <vscale x 2 x i1> [ [[ACTIVE_LANE_MASK_ENTRY]], [[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr float, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <vscale x 2 x float> @llvm.masked.load.nxv2f32.p0(ptr [[TMP0]], i32 4, <vscale x 2 x i1> [[ACTIVE_LANE_MASK]], <vscale x 2 x float> poison)
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr double, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <vscale x 2 x double> @llvm.masked.load.nxv2f64.p0(ptr [[TMP1]], i32 8, <vscale x 2 x i1> [[ACTIVE_LANE_MASK]], <vscale x 2 x double> poison)
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 2 x double> @bar_vector(<vscale x 2 x float> [[WIDE_MASKED_LOAD]], <vscale x 2 x double> [[WIDE_MASKED_LOAD1]], <vscale x 2 x i1> [[ACTIVE_LANE_MASK]])
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds double, ptr [[C]], i64 [[INDEX]]
+; CHECK-NEXT:    call void @llvm.masked.store.nxv2f64.p0(<vscale x 2 x double> [[TMP2]], ptr [[TMP3]], i32 8, <vscale x 2 x i1> [[ACTIVE_LANE_MASK]])
+; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP5:%.*]] = shl nuw nsw i64 [[TMP4]], 1
+; CHECK-NEXT:    [[INDEX_NEXT]] = add i64 [[INDEX]], [[TMP5]]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_NEXT]] = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i64(i64 [[INDEX_NEXT]], i64 1025)
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <vscale x 2 x i1> [[ACTIVE_LANE_MASK_NEXT]], i64 0
+; CHECK-NEXT:    br i1 [[TMP6]], label [[VECTOR_BODY]], label [[FOR_COND_CLEANUP:%.*]], !llvm.loop [[LOOP3:![0-9]+]]
+; CHECK:       for.cond.cleanup:
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %for.body
+
+for.body:
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %gep_f = getelementptr float, ptr %a, i64 %iv
+  %load_f = load float, ptr %gep_f
+  %gep_d = getelementptr double, ptr %b, i64 %iv
+  %load_d = load double, ptr %gep_d
+  %call = call double @bar_little_big(float %load_f, double %load_d) #2
+  %arrayidx = getelementptr inbounds double, ptr %c, i64 %iv
+  store double %call, ptr %arrayidx
+  %iv.next = add nuw nsw i64 %iv, 1
+  %exitcond = icmp eq i64 %iv.next, 1025
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup:
+  ret void
+}
+
+;; TODO: Test uniform and linear parameters when they are properly supported,
+;;       especially a variant with no vector parameters so the return type
+;;       must be used to find the VF.
+
+;; Scalar functions
+declare i32 @foo_big_little(i32, i8)
+declare double @bar_little_big(float, double)
+
+;; Vector function variants
+declare <vscale x 4 x i32> @foo_vector(<vscale x 4 x i32>, <vscale x 4 x i8>, <vscale x 4 x i1>)
+declare <vscale x 2 x double> @bar_vector(<vscale x 2 x float>, <vscale x 2 x double>, <vscale x 2 x i1>)
+
+attributes #0 = { "target-features"="+sve" vscale_range(1,16) }
+attributes #1 = { nounwind "vector-function-abi-variant"="_ZGVsMxvv_foo_big_little(foo_vector)" }
+attributes #2 = { nounwind "vector-function-abi-variant"="_ZGVsMxvv_bar_little_big(bar_vector)" }
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/wider-VF-for-callinst.ll b/llvm/test/Transforms/LoopVectorize/AArch64/wider-VF-for-callinst.ll
index 00e8881426fd8f9..7d095206c6062b4 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/wider-VF-for-callinst.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/wider-VF-for-callinst.ll
@@ -111,5 +111,5 @@ for.cond.cleanup:
 declare float @foo(float)
 declare <vscale x 4 x float> @foo_vector(<vscale x 4 x float>, <vscale x 4 x i1>)
 
-attributes #0 = { nounwind "vector-function-abi-variant"="_ZGV_LLVM_Mxv_foo(foo_vector)" }
+attributes #0 = { nounwind "vector-function-abi-variant"="_ZGVsMxv_foo(foo_vector)" }
 attributes #1 = { "target-features"="+sve" vscale_range(1,16) "no-trapping-math"="false" }
diff --git a/llvm/tools/vfabi-demangle-fuzzer/vfabi-demangler-fuzzer.cpp b/llvm/tools/vfabi-demangle-fuzzer/vfabi-demangler-fuzzer.cpp
index b0b80131bf48f48..09dc15c9e36667b 100644
--- a/llvm/tools/vfabi-demangle-fuzzer/vfabi-demangler-fuzzer.cpp
+++ b/llvm/tools/vfabi-demangle-fuzzer/vfabi-demangler-fuzzer.cpp
@@ -27,15 +27,23 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *Data, size_t Size) {
   // present. We need to make sure we can even invoke
   // `getOrInsertFunction` because such method asserts on strings with
   // zeroes.
-  if (!MangledName.empty() && MangledName.find_first_of(0) == StringRef::npos)
-    M->getOrInsertFunction(
-        MangledName,
-        FunctionType::get(Type::getVoidTy(M->getContext()), false));
-  const auto Info = VFABI::tryDemangleForVFABI(MangledName, *M);
+  // TODO: What is this actually testing? That we don't crash?
+  if (!MangledName.empty() && MangledName.find_first_of(0) == StringRef::npos) {
+    FunctionType *FTy =
+        FunctionType::get(Type::getVoidTy(M->getContext()), false);
+    FunctionCallee F = M->getOrInsertFunction(MangledName, FTy);
+    // Fake the arguments to the CallInst.
+    SmallVector<Value *> Args;
+    for (Type *ParamTy : FTy->params()) {
+      Args.push_back(Constant::getNullValue(ParamTy));
+    }
+    std::unique_ptr<CallInst> CI(CallInst::Create(F, Args));
+    const auto Info = VFABI::tryDemangleForVFABI(MangledName, *(CI.get()));
 
-  // Do not optimize away the return value. Inspired by
-  // https://github.com/google/benchmark/blob/main/include/benchmark/benchmark.h#L307-L345
-  asm volatile("" : : "r,m"(Info) : "memory");
+    // Do not optimize away the return value. Inspired by
+    // https://github.com/google/benchmark/blob/main/include/benchmark/benchmark.h#L307-L345
+    asm volatile("" : : "r,m"(Info) : "memory");
+  }
 
   return 0;
 }
diff --git a/llvm/unittests/Analysis/VectorFunctionABITest.cpp b/llvm/unittests/Analysis/VectorFunctionABITest.cpp
index a4c6b2143fc662c..b233124ebfa6ed1 100644
--- a/llvm/unittests/Analysis/VectorFunctionABITest.cpp
+++ b/llvm/unittests/Analysis/VectorFunctionABITest.cpp
@@ -25,11 +25,11 @@ class VFABIParserTest : public ::testing::Test {
     EXPECT_NE(M.get(), nullptr) << "Loading an invalid module.\n "
                                 << Err.getMessage() << "\n";
     Type *Ty = parseType(IRType, Err, *(M.get()));
-    FunctionType *FTy = dyn_cast<FunctionType>(Ty);
+    FTy = dyn_cast<FunctionType>(Ty);
     EXPECT_NE(FTy, nullptr) << "Invalid function type string: " << IRType
                             << "\n"
                             << Err.getMessage() << "\n";
-    FunctionCallee F = M->getOrInsertFunction(Name, FTy);
+    F = M->getOrInsertFunction(Name, FTy);
     EXPECT_NE(F.getCallee(), nullptr)
         << "The function must be present in the module\n";
     // Reset the VFInfo
@@ -40,7 +40,9 @@ class VFABIParserTest : public ::testing::Test {
   LLVMContext Ctx;
   SMDiagnostic Err;
   std::unique_ptr<Module> M;
-  //  CallInst *CI;
+  FunctionType *FTy = nullptr;
+  FunctionCallee F;
+
 protected:
   // Referencies to the parser output field.
   ElementCount &VF = Info.Shape.VF;
@@ -65,16 +67,22 @@ class VFABIParserTest : public ::testing::Test {
   // generic fixed-length case can use as signature `void()`.
   //
   bool invokeParser(const StringRef MangledName,
-                    const StringRef VectorName = "",
+                    const StringRef ScalarName = "",
                     const StringRef IRType = "void()") {
     StringRef Name = MangledName;
-    if (!VectorName.empty())
-      Name = VectorName;
+    if (!ScalarName.empty())
+      Name = ScalarName;
     // Reset the VFInfo and the Module to be able to invoke
     // `invokeParser` multiple times in the same test.
     reset(Name, IRType);
 
-    const auto OptInfo = VFABI::tryDemangleForVFABI(MangledName, *(M.get()));
+    // Fake the arguments; their types must match FTy's params exactly.
+    SmallVector<Value *> Args;
+    for (Type *ParamTy : FTy->params()) {
+      Args.push_back(Constant::getNullValue(ParamTy));
+    }
+    std::unique_ptr<CallInst> CI(CallInst::Create(F, Args));
+    const auto OptInfo = VFABI::tryDemangleForVFABI(MangledName, *(CI.get()));
     if (OptInfo) {
       Info = *OptInfo;
       return true;
@@ -204,10 +212,8 @@ TEST_F(VFABIParserTest, LinearWithCompileTimeNegativeStep) {
 }
 
 TEST_F(VFABIParserTest, ParseScalableSVE) {
-  EXPECT_TRUE(invokeParser(
-      "_ZGVsMxv_sin(custom_vg)", "custom_vg",
-      "<vscale x 2 x i32>(<vscale x 2 x i32>, <vscale x 2 x i1>)"));
-  EXPECT_EQ(VF, ElementCount::getScalable(2));
+  EXPECT_TRUE(invokeParser("_ZGVsMxv_sin(custom_vg)", "sin", "i32(i32)"));
+  EXPECT_EQ(VF, ElementCount::getScalable(4));
   EXPECT_TRUE(IsMasked());
   EXPECT_EQ(ISA, VFISAKind::SVE);
   EXPECT_EQ(ScalarName, "sin");
@@ -495,12 +501,16 @@ TEST_F(VFABIParserTest, ParseMaskingLLVM) {
 }
 
 TEST_F(VFABIParserTest, ParseScalableMaskingLLVM) {
-  EXPECT_TRUE(invokeParser(
-      "_ZGV_LLVM_Mxv_sin(custom_vector_sin)", "custom_vector_sin",
-      "<vscale x 2 x i32> (<vscale x 2 x i32>, <vscale x 2 x i1>)"));
+  EXPECT_FALSE(
+      invokeParser("_ZGV_LLVM_Mxv_sin(custom_vector_sin)", "sin", "i32(i32)"));
+}
+
+TEST_F(VFABIParserTest, ParseScalableMaskingSVE) {
+  EXPECT_TRUE(
+      invokeParser("_ZGVsMxv_sin(custom_vector_sin)", "sin", "i32(i32)"));
   EXPECT_TRUE(IsMasked());
-  EXPECT_EQ(VF, ElementCount::getScalable(2));
-  EXPECT_EQ(ISA, VFISAKind::LLVM);
+  EXPECT_EQ(VF, ElementCount::getScalable(4));
+  EXPECT_EQ(ISA, VFISAKind::SVE);
   EXPECT_EQ(Parameters.size(), (unsigned)2);
   EXPECT_EQ(Parameters[0], VFParameter({0, VFParamKind::Vector}));
   EXPECT_EQ(Parameters[1], VFParameter({1, VFParamKind::GlobalPredicate}));
@@ -508,13 +518,12 @@ TEST_F(VFABIParserTest, ParseScalableMaskingLLVM) {
   EXPECT_EQ(VectorName, "custom_vector_sin");
 }
 
-TEST_F(VFABIParserTest, ParseScalableMaskingLLVMSincos) {
-  EXPECT_TRUE(invokeParser("_ZGV_LLVM_Mxvl8l8_sincos(custom_vector_sincos)",
-                           "custom_vector_sincos",
-                           "void(<vscale x 2 x double>, double *, double *)"));
+TEST_F(VFABIParserTest, ParseScalableMaskingSVESincos) {
+  EXPECT_TRUE(invokeParser("_ZGVsMxvl8l8_sincos(custom_vector_sincos)",
+                           "sincos", "void(double, double *, double *)"));
   EXPECT_EQ(VF, ElementCount::getScalable(2));
   EXPECT_TRUE(IsMasked());
-  EXPECT_EQ(ISA, VFISAKind::LLVM);
+  EXPECT_EQ(ISA, VFISAKind::SVE);
   EXPECT_EQ(Parameters.size(), (unsigned)4);
   EXPECT_EQ(Parameters[0], VFParameter({0, VFParamKind::Vector}));
   EXPECT_EQ(Parameters[1], VFParameter({1, VFParamKind::OMP_Linear, 8}));
@@ -581,9 +590,7 @@ TEST_F(VFABIParserTest, ParseScalableRequiresDeclaration) {
   // The parser succeds only when the correct function definition of
   // `custom_vg` is added to the module.
   EXPECT_FALSE(invokeParser(MangledName));
-  EXPECT_TRUE(invokeParser(
-      MangledName, "custom_vg",
-      "<vscale x 4 x double>(<vscale x 4 x double>, <vscale x 4 x i1>)"));
+  EXPECT_TRUE(invokeParser(MangledName, "sin", "double(double)"));
 }
 
 TEST_F(VFABIParserTest, ZeroIsInvalidVLEN) {

>From 6336196759d99f5bab80b2be2584e81ec8b6dbd5 Mon Sep 17 00:00:00 2001
From: Graham Hunter <graham.hunter at arm.com>
Date: Wed, 22 Nov 2023 12:04:22 +0000
Subject: [PATCH 2/3] Address comments, make local functions static, add more
 tests

---
 llvm/include/llvm/Analysis/VectorUtils.h      |  10 +-
 llvm/lib/Analysis/VFABIDemangling.cpp         | 184 +++++++++---------
 llvm/lib/Analysis/VectorUtils.cpp             |   3 +-
 .../Analysis/VectorFunctionABITest.cpp        |  24 +++
 4 files changed, 128 insertions(+), 93 deletions(-)

diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index 99b31fd0ca4b659..d54b63fd4f5328f 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -174,11 +174,11 @@ static constexpr char const *_LLVM_Scalarize_ = "_LLVM_Scalarize_";
 ///
 /// \param MangledName -> input string in the format
 /// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)].
-/// \param M -> Module used to retrieve informations about the vector
-/// function that are not possible to retrieve from the mangled
-/// name. At the moment, this parameter is needed only to retrieve the
-/// Vectorization Factor of scalable vector functions from their
-/// respective IR declarations.
+/// \param CI -> A call to the scalar function which we're trying to find
+/// a vectorized variant for. This is required to determine the vectorization
+/// factor for scalable vectors, since the mangled name doesn't encode that;
+/// it needs to be derived from the widest element types of vector arguments
+/// or return values.
 std::optional<VFInfo> tryDemangleForVFABI(StringRef MangledName,
                                           const CallInst &CI);
 
diff --git a/llvm/lib/Analysis/VFABIDemangling.cpp b/llvm/lib/Analysis/VFABIDemangling.cpp
index 75e0c2fbb02159a..9d86d3411399887 100644
--- a/llvm/lib/Analysis/VFABIDemangling.cpp
+++ b/llvm/lib/Analysis/VFABIDemangling.cpp
@@ -7,9 +7,14 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Analysis/VectorUtils.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <limits>
 
 using namespace llvm;
 
+#define DEBUG_TYPE "vfabi-demangling"
+
 namespace {
 /// Utilities for the Vector Function ABI name parser.
 
@@ -21,8 +26,9 @@ enum class ParseRet {
 };
 
 /// Extracts the `<isa>` information from the mangled string, and
-/// sets the `ISA` accordingly.
-ParseRet tryParseISA(StringRef &MangledName, VFISAKind &ISA) {
+/// sets the `ISA` accordingly. If successful, the <isa> token is removed
+/// from the input string `MangledName`.
+static ParseRet tryParseISA(StringRef &MangledName, VFISAKind &ISA) {
   if (MangledName.empty())
     return ParseRet::Error;
 
@@ -45,9 +51,9 @@ ParseRet tryParseISA(StringRef &MangledName, VFISAKind &ISA) {
 }
 
 /// Extracts the `<mask>` information from the mangled string, and
-/// sets `IsMasked` accordingly. The input string `MangledName` is
-/// left unmodified.
-ParseRet tryParseMask(StringRef &MangledName, bool &IsMasked) {
+/// sets `IsMasked` accordingly. If successful, the <mask> token is removed
+/// from the input string `MangledName`.
+static ParseRet tryParseMask(StringRef &MangledName, bool &IsMasked) {
   if (MangledName.consume_front("M")) {
     IsMasked = true;
     return ParseRet::OK;
@@ -62,18 +68,24 @@ ParseRet tryParseMask(StringRef &MangledName, bool &IsMasked) {
 }
 
 /// Extract the `<vlen>` information from the mangled string, and
-/// sets `VF` accordingly. A `<vlen> == "x"` token is interpreted as a scalable
-/// vector length. On success, the `<vlen>` token is removed from
-/// the input string `ParseString`.
-///
-ParseRet tryParseVLEN(StringRef &ParseString,
-                      std::optional<unsigned> &ParsedVF) {
+/// sets `ParsedVF` accordingly. A `<vlen> == "x"` token is interpreted as a
+/// scalable vector length and the boolean is set to true, otherwise a nonzero
+/// unsigned integer will be directly used as a VF. On success, the `<vlen>`
+/// token is removed from the input string `ParseString`.
+static ParseRet tryParseVLEN(StringRef &ParseString, VFISAKind ISA,
+                             std::pair<unsigned, bool> &ParsedVF) {
   if (ParseString.consume_front("x")) {
+    // SVE is the only scalable ISA currently supported.
+    if (ISA != VFISAKind::SVE) {
+      LLVM_DEBUG(dbgs() << "Vector function variant declared with scalable VF "
+                        << "but ISA is not SVE\n");
+      return ParseRet::Error;
+    }
     // We can't determine the VF of a scalable vector by looking at the vlen
-    // string (just 'x'), so say we successfully parsed it but return a nullopt
-    // so that the caller knows it must look at the arguments to determine
-    // the minimum VF based on types.
-    ParsedVF = std::nullopt;
+    // string (just 'x'), so report success with the scalable flag set and a
+    // placeholder VF of 0; the caller then derives the actual VF from the
+    // element types of the vector parameters and the (non-void) return type.
+    ParsedVF = {0, true};
     return ParseRet::OK;
   }
 
@@ -85,7 +97,7 @@ ParseRet tryParseVLEN(StringRef &ParseString,
   if (VF == 0)
     return ParseRet::Error;
 
-  ParsedVF = VF;
+  ParsedVF = {VF, false};
   return ParseRet::OK;
 }
 
@@ -101,9 +113,9 @@ ParseRet tryParseVLEN(StringRef &ParseString,
 ///
 /// The function expects <token> to be one of "ls", "Rs", "Us" or
 /// "Ls".
-ParseRet tryParseLinearTokenWithRuntimeStep(StringRef &ParseString,
-                                            VFParamKind &PKind, int &Pos,
-                                            const StringRef Token) {
+static ParseRet tryParseLinearTokenWithRuntimeStep(StringRef &ParseString,
+                                                   VFParamKind &PKind, int &Pos,
+                                                   const StringRef Token) {
   if (ParseString.consume_front(Token)) {
     PKind = VFABI::getVFParamKindFromString(Token);
     if (ParseString.consumeInteger(10, Pos))
@@ -125,8 +137,9 @@ ParseRet tryParseLinearTokenWithRuntimeStep(StringRef &ParseString,
 /// sets `PKind` to the correspondent enum value, sets `StepOrPos` to
 /// <number>, and return success.  On a syntax error, it return a
 /// parsing error. If nothing is parsed, it returns std::nullopt.
-ParseRet tryParseLinearWithRuntimeStep(StringRef &ParseString,
-                                       VFParamKind &PKind, int &StepOrPos) {
+static ParseRet tryParseLinearWithRuntimeStep(StringRef &ParseString,
+                                              VFParamKind &PKind,
+                                              int &StepOrPos) {
   ParseRet Ret;
 
   // "ls" <RuntimeStepPos>
@@ -164,9 +177,10 @@ ParseRet tryParseLinearWithRuntimeStep(StringRef &ParseString,
 ///
 /// The function expects <token> to be one of "l", "R", "U" or
 /// "L".
-ParseRet tryParseCompileTimeLinearToken(StringRef &ParseString,
-                                        VFParamKind &PKind, int &LinearStep,
-                                        const StringRef Token) {
+static ParseRet tryParseCompileTimeLinearToken(StringRef &ParseString,
+                                               VFParamKind &PKind,
+                                               int &LinearStep,
+                                               const StringRef Token) {
   if (ParseString.consume_front(Token)) {
     PKind = VFABI::getVFParamKindFromString(Token);
     const bool Negate = ParseString.consume_front("n");
@@ -189,8 +203,9 @@ ParseRet tryParseCompileTimeLinearToken(StringRef &ParseString,
 /// sets `PKind` to the correspondent enum value, sets `LinearStep` to
 /// <number>, and return success.  On a syntax error, it return a
 /// parsing error. If nothing is parsed, it returns std::nullopt.
-ParseRet tryParseLinearWithCompileTimeStep(StringRef &ParseString,
-                                           VFParamKind &PKind, int &StepOrPos) {
+static ParseRet tryParseLinearWithCompileTimeStep(StringRef &ParseString,
+                                                  VFParamKind &PKind,
+                                                  int &StepOrPos) {
   // "l" {"n"} <CompileTimeStep>
   if (tryParseCompileTimeLinearToken(ParseString, PKind, StepOrPos, "l") ==
       ParseRet::OK)
@@ -222,8 +237,8 @@ ParseRet tryParseLinearWithCompileTimeStep(StringRef &ParseString,
 /// sets `PKind` to the correspondent enum value, sets `StepOrPos`
 /// accordingly, and return success.  On a syntax error, it return a
 /// parsing error. If nothing is parsed, it returns std::nullopt.
-ParseRet tryParseParameter(StringRef &ParseString, VFParamKind &PKind,
-                           int &StepOrPos) {
+static ParseRet tryParseParameter(StringRef &ParseString, VFParamKind &PKind,
+                                  int &StepOrPos) {
   if (ParseString.consume_front("v")) {
     PKind = VFParamKind::Vector;
     StepOrPos = 0;
@@ -257,7 +272,7 @@ ParseRet tryParseParameter(StringRef &ParseString, VFParamKind &PKind,
 /// sets `PKind` to the correspondent enum value, sets `StepOrPos`
 /// accordingly, and return success.  On a syntax error, it return a
 /// parsing error. If nothing is parsed, it returns std::nullopt.
-ParseRet tryParseAlign(StringRef &ParseString, Align &Alignment) {
+static ParseRet tryParseAlign(StringRef &ParseString, Align &Alignment) {
   uint64_t Val;
   //    "a" <number>
   if (ParseString.consume_front("a")) {
@@ -275,79 +290,74 @@ ParseRet tryParseAlign(StringRef &ParseString, Align &Alignment) {
   return ParseRet::None;
 }
 
-// Given a type, return the size in bits if it is a supported element type
-// for vectorized function calls, or nullopt if not.
-std::optional<unsigned> getSizeFromScalarType(Type *Ty) {
-  // The scalar function should only take scalar arguments.
-  if (!Ty->isIntegerTy() && !Ty->isFloatingPointTy() && !Ty->isPointerTy())
-    return std::nullopt;
-
-  unsigned SizeInBits = Ty->getPrimitiveSizeInBits();
-  switch (SizeInBits) {
-  // Legal power-of-two scalars are supported.
-  case 64:
-  case 32:
-  case 16:
-  case 8:
-    return SizeInBits;
-  case 0:
-    // We're assuming a 64b pointer size here for SVE; if another non-64b
-    // target adds support for scalable vectors, we may need DataLayout to
-    // determine the size.
-    if (Ty->isPointerTy())
-      return 64;
-    break;
-  default:
-    break;
-  }
+// Returns the 'natural' VF for a given scalar element type, assuming a minimum
+// vector size of 128b. This matches AArch64 SVE, which is currently the only
+// scalable architecture with a defined vector function variant name mangling.
+static std::optional<ElementCount> getElementCountForTy(const Type *Ty) {
+  if (Ty->isIntegerTy(64) || Ty->isDoubleTy() || Ty->isPointerTy())
+    return ElementCount::getScalable(2);
+  if (Ty->isIntegerTy(32) || Ty->isFloatTy())
+    return ElementCount::getScalable(4);
+  if (Ty->isIntegerTy(16) || Ty->is16bitFPTy())
+    return ElementCount::getScalable(8);
+  if (Ty->isIntegerTy(8))
+    return ElementCount::getScalable(16);
 
   return std::nullopt;
 }
 
 // Extract the VectorizationFactor from a given function signature, based
 // on the widest scalar element types that will become vector parameters.
-std::optional<ElementCount>
-getScalableECFromSignature(FunctionType *Signature, const VFISAKind ISA,
+static std::optional<ElementCount>
+getScalableECFromSignature(const FunctionType *Signature, const VFISAKind ISA,
                            const SmallVectorImpl<VFParameter> &Params) {
   // Look up the minimum known register size in order to calculate minimum VF.
   // Only AArch64 SVE is supported at present.
-  unsigned MinRegSizeInBits;
-  switch (ISA) {
-  case VFISAKind::SVE:
-    MinRegSizeInBits = 128;
-    break;
-  default:
-    return std::nullopt;
-  }
+  assert(ISA == VFISAKind::SVE &&
+         "Scalable VF decoding only implemented for SVE\n");
 
-  unsigned WidestTypeInBits = 0;
+  // Start with a very wide EC and drop when we find smaller ECs based on type.
+  ElementCount MinEC =
+        ElementCount::getScalable(std::numeric_limits<unsigned int>::max());
   for (auto &Param : Params) {
-    // Check any parameters that will be widened to vectors. Uniform or linear
-    // parameters may be misleading for determining the VF of a given function.
+    // Only vector parameters are used when determining the VF; uniform or
+    // linear are left as scalars, so do not affect VF.
     if (Param.ParamKind == VFParamKind::Vector) {
       // If the scalar function doesn't actually have a corresponding argument,
       // reject the mapping.
-      if (Param.ParamPos + 1 > Signature->getNumParams())
+      if (Param.ParamPos >= Signature->getNumParams())
         return std::nullopt;
       Type *PTy = Signature->getParamType(Param.ParamPos);
 
-      std::optional<unsigned> SizeInBits = getSizeFromScalarType(PTy);
-      if (SizeInBits)
-        WidestTypeInBits = std::max(WidestTypeInBits, *SizeInBits);
+      std::optional<ElementCount> EC = getElementCountForTy(PTy);
+      // If we have an unknown scalar element type we can't find a reasonable
+      // VF.
+      if (!EC)
+        return std::nullopt;
+
+      // Find the smallest VF, based on the widest scalar type.
+      if (ElementCount::isKnownLT(*EC, MinEC))
+        MinEC = *EC;
     }
   }
 
-  // Also check the return type.
-  std::optional<unsigned> ReturnSizeInBits =
-      getSizeFromScalarType(Signature->getReturnType());
-  if (ReturnSizeInBits)
-    WidestTypeInBits = std::max(WidestTypeInBits, *ReturnSizeInBits);
+  // Also check the return type if not void.
+  Type *RetTy = Signature->getReturnType();
+  if (!RetTy->isVoidTy()) {
+    std::optional<ElementCount> ReturnEC = getElementCountForTy(RetTy);
+    // If we have an unknown scalar element type we can't find a reasonable VF.
+    if (!ReturnEC)
+      return std::nullopt;
+    if (ElementCount::isKnownLT(*ReturnEC, MinEC))
+      MinEC = *ReturnEC;
+  }
 
-  // SVE bases the VF on the widest element types present, and vector arguments
-  // containing types of that width are always considered to be packed.
-  // Arguments with narrower elements are considered to be unpacked.
-  if (WidestTypeInBits)
-    return ElementCount::getScalable(MinRegSizeInBits / WidestTypeInBits);
+  // The SVE Vector function call ABI bases the VF on the widest element types
+  // present, and vector arguments containing types of that width are always
+  // considered to be packed. Arguments with narrower elements are considered
+  // to be unpacked.
+  if (MinEC.getKnownMinValue() < std::numeric_limits<unsigned int>::max())
+    return MinEC;
 
   return std::nullopt;
 }
@@ -379,8 +389,8 @@ std::optional<VFInfo> VFABI::tryDemangleForVFABI(StringRef MangledName,
     return std::nullopt;
 
   // Parse the variable size, starting from <vlen>.
-  std::optional<unsigned> ParsedVF;
-  if (tryParseVLEN(MangledName, ParsedVF) != ParseRet::OK)
+  std::pair<unsigned, bool> ParsedVF;
+  if (tryParseVLEN(MangledName, ISA, ParsedVF) != ParseRet::OK)
     return std::nullopt;
 
   // Parse the <parameters>.
@@ -420,16 +430,16 @@ std::optional<VFInfo> VFABI::tryDemangleForVFABI(StringRef MangledName,
   // variant expects scalable vectors, then we need to figure out the minimum
   // based on the widest scalar types in vector arguments.
   std::optional<ElementCount> EC;
-  if (ParsedVF) {
-    // Fixed length VF
-    EC = ElementCount::getFixed(*ParsedVF);
-  } else {
+  if (ParsedVF.second) {
     // Scalable VF, need to work out the minimum from the element types
     // in the scalar function arguments.
     EC = getScalableECFromSignature(CI.getFunctionType(), ISA, Parameters);
 
     if (!EC)
       return std::nullopt;
+  } else {
+    // Fixed length VF
+    EC = ElementCount::getFixed(ParsedVF.first);
   }
 
   // Check for the <scalarname> and the optional <redirection>, which
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index 9103bf3c8b82467..926eca80eacf4a5 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -1468,7 +1468,8 @@ void VFABI::getVectorVariantNames(
   for (const auto &S : SetVector<StringRef>(ListAttr.begin(), ListAttr.end())) {
     std::optional<VFInfo> Info = VFABI::tryDemangleForVFABI(S, CI);
     if (Info && CI.getModule()->getFunction(Info->VectorName)) {
-      LLVM_DEBUG(dbgs() << "VFABI: Adding mapping '" << S << "'\n");
+      LLVM_DEBUG(dbgs() << "VFABI: Adding mapping '" << S << "' for " <<
+                 CI << "\n");
       VariantMappings.push_back(std::string(S));
     } else
       LLVM_DEBUG(dbgs() << "VFABI: Invalid mapping '" << S << "'\n");
diff --git a/llvm/unittests/Analysis/VectorFunctionABITest.cpp b/llvm/unittests/Analysis/VectorFunctionABITest.cpp
index b233124ebfa6ed1..e2cb7129c7b365a 100644
--- a/llvm/unittests/Analysis/VectorFunctionABITest.cpp
+++ b/llvm/unittests/Analysis/VectorFunctionABITest.cpp
@@ -533,6 +533,30 @@ TEST_F(VFABIParserTest, ParseScalableMaskingSVESincos) {
   EXPECT_EQ(VectorName, "custom_vector_sincos");
 }
 
+// Make sure that we get the correct VF if the return type is wider than any
+// parameter type.
+TEST_F(VFABIParserTest, ParseWiderReturnTypeSVE) {
+  EXPECT_TRUE(
+    invokeParser("_ZGVsMxvv_foo(vector_foo)", "foo", "i64(i32, i32)"));
+  EXPECT_EQ(VF, ElementCount::getScalable(2));
+}
+
+// Make sure a void return type is ignored when computing the VF.
+TEST_F(VFABIParserTest, ParseVoidReturnTypeSVE) {
+  EXPECT_TRUE(invokeParser("_ZGVsMxv_foo(vector_foo)", "foo", "void(i16)"));
+  EXPECT_EQ(VF, ElementCount::getScalable(8));
+}
+
+// Make sure we reject unsupported parameter types.
+TEST_F(VFABIParserTest, ParseUnsupportedElementTypeSVE) {
+  EXPECT_FALSE(invokeParser("_ZGVsMxv_foo(vector_foo)", "foo", "void(i128)"));
+}
+
+// Make sure we reject unsupported return types.
+TEST_F(VFABIParserTest, ParseUnsupportedReturnTypeSVE) {
+  EXPECT_FALSE(invokeParser("_ZGVsMxv_foo(vector_foo)", "foo", "fp128(float)"));
+}
+
 class VFABIAttrTest : public testing::Test {
 protected:
   void SetUp() override {

>From 979678e708ec5aa3c3f42533024a707d4147ca7e Mon Sep 17 00:00:00 2001
From: Graham Hunter <graham.hunter at arm.com>
Date: Thu, 23 Nov 2023 11:39:29 +0000
Subject: [PATCH 3/3] Move assert, fix comments, reformat

---
 llvm/lib/Analysis/VFABIDemangling.cpp         | 38 +++++++++----------
 llvm/lib/Analysis/VectorUtils.cpp             |  4 +-
 .../Analysis/VectorFunctionABITest.cpp        |  2 +-
 3 files changed, 21 insertions(+), 23 deletions(-)

diff --git a/llvm/lib/Analysis/VFABIDemangling.cpp b/llvm/lib/Analysis/VFABIDemangling.cpp
index 9d86d3411399887..88f61cfeb9ba4e5 100644
--- a/llvm/lib/Analysis/VFABIDemangling.cpp
+++ b/llvm/lib/Analysis/VFABIDemangling.cpp
@@ -290,10 +290,18 @@ static ParseRet tryParseAlign(StringRef &ParseString, Align &Alignment) {
   return ParseRet::None;
 }
 
-// Returns the 'natural' VF for a given scalar element type, assuming a minimum
-// vector size of 128b. This matches AArch64 SVE, which is currently the only
-// scalable architecture with a defined vector function variant name mangling.
-static std::optional<ElementCount> getElementCountForTy(const Type *Ty) {
+// Returns the 'natural' VF for a given scalar element type, based on the
+// current architecture.
+//
+// For SVE (currently the only scalable architecture with a defined name
+// mangling), we assume a minimum vector size of 128b and 64b pointers, and
+// return a VF based on how many elements of the given type fit in 128b.
+static std::optional<ElementCount> getElementCountForTy(const VFISAKind ISA,
+                                                        const Type *Ty) {
+  // Only AArch64 SVE is supported at present.
+  assert(ISA == VFISAKind::SVE &&
+         "Scalable VF decoding only implemented for SVE");
+
   if (Ty->isIntegerTy(64) || Ty->isDoubleTy() || Ty->isPointerTy())
     return ElementCount::getScalable(2);
   if (Ty->isIntegerTy(32) || Ty->isFloatTy())
@@ -311,14 +319,9 @@ static std::optional<ElementCount> getElementCountForTy(const Type *Ty) {
 static std::optional<ElementCount>
 getScalableECFromSignature(const FunctionType *Signature, const VFISAKind ISA,
                            const SmallVectorImpl<VFParameter> &Params) {
-  // Look up the minimum known register size in order to calculate minimum VF.
-  // Only AArch64 SVE is supported at present.
-  assert(ISA == VFISAKind::SVE &&
-         "Scalable VF decoding only implemented for SVE\n");
-
   // Start with a very wide EC and drop when we find smaller ECs based on type.
   ElementCount MinEC =
-        ElementCount::getScalable(std::numeric_limits<unsigned int>::max());
+      ElementCount::getScalable(std::numeric_limits<unsigned int>::max());
   for (auto &Param : Params) {
     // Only vector parameters are used when determining the VF; uniform or
     // linear are left as scalars, so do not affect VF.
@@ -329,7 +332,7 @@ getScalableECFromSignature(const FunctionType *Signature, const VFISAKind ISA,
         return std::nullopt;
       Type *PTy = Signature->getParamType(Param.ParamPos);
 
-      std::optional<ElementCount> EC = getElementCountForTy(PTy);
+      std::optional<ElementCount> EC = getElementCountForTy(ISA, PTy);
       // If we have an unknown scalar element type we can't find a reasonable
       // VF.
       if (!EC)
@@ -344,7 +347,7 @@ getScalableECFromSignature(const FunctionType *Signature, const VFISAKind ISA,
   // Also check the return type if not void.
   Type *RetTy = Signature->getReturnType();
   if (!RetTy->isVoidTy()) {
-    std::optional<ElementCount> ReturnEC = getElementCountForTy(RetTy);
+    std::optional<ElementCount> ReturnEC = getElementCountForTy(ISA, RetTy);
     // If we have an unknown scalar element type we can't find a reasonable VF.
     if (!ReturnEC)
       return std::nullopt;
@@ -427,20 +430,15 @@ std::optional<VFInfo> VFABI::tryDemangleForVFABI(StringRef MangledName,
   // Figure out the number of lanes in vectors for this function variant. This
   // is easy for fixed length, as the vlen encoding just gives us the value
   // directly. However, if the vlen mangling indicated that this function
-  // variant expects scalable vectors, then we need to figure out the minimum
-  // based on the widest scalar types in vector arguments.
+  // variant expects scalable vectors, we need to work it out based on the
+  // demangled parameter types and the scalar function signature.
   std::optional<ElementCount> EC;
   if (ParsedVF.second) {
-    // Scalable VF, need to work out the minimum from the element types
-    // in the scalar function arguments.
     EC = getScalableECFromSignature(CI.getFunctionType(), ISA, Parameters);
-
     if (!EC)
       return std::nullopt;
-  } else {
-    // Fixed length VF
+  } else
     EC = ElementCount::getFixed(ParsedVF.first);
-  }
 
   // Check for the <scalarname> and the optional <redirection>, which
   // are separated from the prefix with "_"
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index 926eca80eacf4a5..96f39ff7e409ede 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -1468,8 +1468,8 @@ void VFABI::getVectorVariantNames(
   for (const auto &S : SetVector<StringRef>(ListAttr.begin(), ListAttr.end())) {
     std::optional<VFInfo> Info = VFABI::tryDemangleForVFABI(S, CI);
     if (Info && CI.getModule()->getFunction(Info->VectorName)) {
-      LLVM_DEBUG(dbgs() << "VFABI: Adding mapping '" << S << "' for " <<
-                 CI << "\n");
+      LLVM_DEBUG(dbgs() << "VFABI: Adding mapping '" << S << "' for " << CI
+                        << "\n");
       VariantMappings.push_back(std::string(S));
     } else
       LLVM_DEBUG(dbgs() << "VFABI: Invalid mapping '" << S << "'\n");
diff --git a/llvm/unittests/Analysis/VectorFunctionABITest.cpp b/llvm/unittests/Analysis/VectorFunctionABITest.cpp
index e2cb7129c7b365a..e496d87c06de6bc 100644
--- a/llvm/unittests/Analysis/VectorFunctionABITest.cpp
+++ b/llvm/unittests/Analysis/VectorFunctionABITest.cpp
@@ -537,7 +537,7 @@ TEST_F(VFABIParserTest, ParseScalableMaskingSVESincos) {
 // parameter type.
 TEST_F(VFABIParserTest, ParseWiderReturnTypeSVE) {
   EXPECT_TRUE(
-    invokeParser("_ZGVsMxvv_foo(vector_foo)", "foo", "i64(i32, i32)"));
+      invokeParser("_ZGVsMxvv_foo(vector_foo)", "foo", "i64(i32, i32)"));
   EXPECT_EQ(VF, ElementCount::getScalable(2));
 }
 



More information about the llvm-commits mailing list