[llvm] [SVE] Don't require lookup when demangling vector function mappings (PR #72260)

Graham Hunter via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 14 06:26:26 PST 2023


https://github.com/huntergr-arm created https://github.com/llvm/llvm-project/pull/72260

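For scalable vector function variants, we can determine the VF from a
combination of the mangled name (which indicates which arguments take
vectors) and the element sizes of the corresponding arguments of the
scalar function for which the mapping has been established, so the
demangler no longer needs to look up the vector function's declaration.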

The assertion that fired when demangling failed has been removed; invalid
mappings are now simply not added, which prevents the crash seen in
https://github.com/llvm/llvm-project/issues/71892

This patch also stops using _LLVM_ as the ISA in scalable vector tests,
since there are no defined rules for how vector arguments should be
handled for that ISA (e.g. packed vs. unpacked representation).


>From 4ce19105db7f840286ced7ca4af98a37a8e5f8b2 Mon Sep 17 00:00:00 2001
From: Graham Hunter <graham.hunter at arm.com>
Date: Mon, 13 Nov 2023 16:03:47 +0000
Subject: [PATCH] [SVE] Don't require lookup when demangling vector function
 mappings

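For scalable vector function variants, we can determine the VF from a
combination of the mangled name (which indicates which arguments take
vectors) and the element sizes of the corresponding arguments of the
scalar function for which the mapping has been established, so the
demangler no longer needs to look up the vector function's declaration.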

The assertion that fired when demangling failed has been removed; invalid
mappings are now simply not added, which prevents the crash seen in
https://github.com/llvm/llvm-project/issues/71892

This patch also stops using _LLVM_ as the ISA in scalable vector tests,
since there are no defined rules for how vector arguments should be
handled for that ISA (e.g. packed vs. unpacked representation).
---
 llvm/include/llvm/Analysis/VectorUtils.h      |   4 +-
 llvm/lib/Analysis/VFABIDemangling.cpp         | 177 +++++++++++-------
 llvm/lib/Analysis/VectorUtils.cpp             |  15 +-
 llvm/lib/Transforms/Utils/ModuleUtils.cpp     |   2 +-
 .../LoopVectorize/AArch64/masked-call.ll      |   8 +-
 .../LoopVectorize/AArch64/scalable-call.ll    |   6 +-
 .../LoopVectorize/AArch64/sve-vfabi.ll        | 110 +++++++++++
 .../AArch64/wider-VF-for-callinst.ll          |   2 +-
 .../vfabi-demangler-fuzzer.cpp                |  24 ++-
 .../Analysis/VectorFunctionABITest.cpp        |  55 +++---
 10 files changed, 279 insertions(+), 124 deletions(-)
 create mode 100644 llvm/test/Transforms/LoopVectorize/AArch64/sve-vfabi.ll

diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index 7947648aaddd4ea..99b31fd0ca4b659 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -180,7 +180,7 @@ static constexpr char const *_LLVM_Scalarize_ = "_LLVM_Scalarize_";
 /// Vectorization Factor of scalable vector functions from their
 /// respective IR declarations.
 std::optional<VFInfo> tryDemangleForVFABI(StringRef MangledName,
-                                          const Module &M);
+                                          const CallInst &CI);
 
 /// Retrieve the `VFParamKind` from a string token.
 VFParamKind getVFParamKindFromString(const StringRef Token);
@@ -227,7 +227,7 @@ class VFDatabase {
       return;
     for (const auto &MangledName : ListOfStrings) {
       const std::optional<VFInfo> Shape =
-          VFABI::tryDemangleForVFABI(MangledName, *(CI.getModule()));
+          VFABI::tryDemangleForVFABI(MangledName, CI);
       // A match is found via scalar and vector names, and also by
       // ensuring that the variant described in the attribute has a
       // corresponding definition or declaration of the vector
diff --git a/llvm/lib/Analysis/VFABIDemangling.cpp b/llvm/lib/Analysis/VFABIDemangling.cpp
index 1e2d1db4e44b9d6..75e0c2fbb02159a 100644
--- a/llvm/lib/Analysis/VFABIDemangling.cpp
+++ b/llvm/lib/Analysis/VFABIDemangling.cpp
@@ -66,16 +66,18 @@ ParseRet tryParseMask(StringRef &MangledName, bool &IsMasked) {
 /// vector length. On success, the `<vlen>` token is removed from
 /// the input string `ParseString`.
 ///
-ParseRet tryParseVLEN(StringRef &ParseString, unsigned &VF, bool &IsScalable) {
+ParseRet tryParseVLEN(StringRef &ParseString,
+                      std::optional<unsigned> &ParsedVF) {
   if (ParseString.consume_front("x")) {
-    // Set VF to 0, to be later adjusted to a value grater than zero
-    // by looking at the signature of the vector function with
-    // `getECFromSignature`.
-    VF = 0;
-    IsScalable = true;
+    // We can't determine the VF of a scalable vector by looking at the vlen
+    // string (just 'x'), so say we successfully parsed it but return a nullopt
+    // so that the caller knows it must look at the arguments to determine
+    // the minimum VF based on types.
+    ParsedVF = std::nullopt;
     return ParseRet::OK;
   }
 
+  unsigned VF = 0;
   if (ParseString.consumeInteger(10, VF))
     return ParseRet::Error;
 
@@ -83,7 +85,7 @@ ParseRet tryParseVLEN(StringRef &ParseString, unsigned &VF, bool &IsScalable) {
   if (VF == 0)
     return ParseRet::Error;
 
-  IsScalable = false;
+  ParsedVF = VF;
   return ParseRet::OK;
 }
 
@@ -273,49 +275,88 @@ ParseRet tryParseAlign(StringRef &ParseString, Align &Alignment) {
   return ParseRet::None;
 }
 
-#ifndef NDEBUG
-// Verify the assumtion that all vectors in the signature of a vector
-// function have the same number of elements.
-bool verifyAllVectorsHaveSameWidth(FunctionType *Signature) {
-  SmallVector<VectorType *, 2> VecTys;
-  if (auto *RetTy = dyn_cast<VectorType>(Signature->getReturnType()))
-    VecTys.push_back(RetTy);
-  for (auto *Ty : Signature->params())
-    if (auto *VTy = dyn_cast<VectorType>(Ty))
-      VecTys.push_back(VTy);
-
-  if (VecTys.size() <= 1)
-    return true;
-
-  assert(VecTys.size() > 1 && "Invalid number of elements.");
-  const ElementCount EC = VecTys[0]->getElementCount();
-  return llvm::all_of(llvm::drop_begin(VecTys), [&EC](VectorType *VTy) {
-    return (EC == VTy->getElementCount());
-  });
+// Given a type, return the size in bits if it is a supported element type
+// for vectorized function calls, or nullopt if not.
+std::optional<unsigned> getSizeFromScalarType(Type *Ty) {
+  // The scalar function should only take scalar arguments.
+  if (!Ty->isIntegerTy() && !Ty->isFloatingPointTy() && !Ty->isPointerTy())
+    return std::nullopt;
+
+  unsigned SizeInBits = Ty->getPrimitiveSizeInBits();
+  switch (SizeInBits) {
+  // Legal power-of-two scalars are supported.
+  case 64:
+  case 32:
+  case 16:
+  case 8:
+    return SizeInBits;
+  case 0:
+    // We're assuming a 64b pointer size here for SVE; if another non-64b
+    // target adds support for scalable vectors, we may need DataLayout to
+    // determine the size.
+    if (Ty->isPointerTy())
+      return 64;
+    break;
+  default:
+    break;
+  }
+
+  return std::nullopt;
 }
-#endif // NDEBUG
-
-// Extract the VectorizationFactor from a given function signature,
-// under the assumtion that all vectors have the same number of
-// elements, i.e. same ElementCount.Min.
-ElementCount getECFromSignature(FunctionType *Signature) {
-  assert(verifyAllVectorsHaveSameWidth(Signature) &&
-         "Invalid vector signature.");
-
-  if (auto *RetTy = dyn_cast<VectorType>(Signature->getReturnType()))
-    return RetTy->getElementCount();
-  for (auto *Ty : Signature->params())
-    if (auto *VTy = dyn_cast<VectorType>(Ty))
-      return VTy->getElementCount();
-
-  return ElementCount::getFixed(/*Min=*/1);
+
+// Extract the scalable VectorizationFactor from a scalar function signature,
+// based on the widest element types (of vector parameters and a non-void
+// return) that will be widened; nullopt if the ISA or any type is unsupported.
+std::optional<ElementCount>
+getScalableECFromSignature(FunctionType *Signature, const VFISAKind ISA,
+                           const SmallVectorImpl<VFParameter> &Params) {
+  // Only AArch64 SVE (minimum 128-bit registers) is supported at present.
+  unsigned MinRegSizeInBits;
+  switch (ISA) {
+  case VFISAKind::SVE:
+    MinRegSizeInBits = 128;
+    break;
+  default:
+    return std::nullopt;
+  }
+
+  unsigned WidestTypeInBits = 0;
+  for (auto &Param : Params) {
+    // Uniform/linear parameters stay scalar, so only vectors determine the VF.
+    if (Param.ParamKind != VFParamKind::Vector)
+      continue;
+    // Reject a missing scalar argument or an unsupported element type.
+    if (Param.ParamPos >= Signature->getNumParams())
+      return std::nullopt;
+    std::optional<unsigned> SizeInBits =
+        getSizeFromScalarType(Signature->getParamType(Param.ParamPos));
+    if (!SizeInBits)
+      return std::nullopt;
+    WidestTypeInBits = std::max(WidestTypeInBits, *SizeInBits);
+  }
+
+  // A non-void return is also widened, so it must have a supported type too.
+  Type *RetTy = Signature->getReturnType();
+  if (!RetTy->isVoidTy()) {
+    std::optional<unsigned> ReturnSizeInBits = getSizeFromScalarType(RetTy);
+    if (!ReturnSizeInBits)
+      return std::nullopt;
+    WidestTypeInBits = std::max(WidestTypeInBits, *ReturnSizeInBits);
+  }
+
+  // SVE bases the VF on the widest element types present, and vector arguments
+  // containing types of that width are always considered to be packed.
+  // Arguments with narrower elements are considered to be unpacked.
+  if (!WidestTypeInBits)
+    return std::nullopt;
+  return ElementCount::getScalable(MinRegSizeInBits / WidestTypeInBits);
 }
 } // namespace
 
 // Format of the ABI name:
 // _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)]
 std::optional<VFInfo> VFABI::tryDemangleForVFABI(StringRef MangledName,
-                                                 const Module &M) {
+                                                 const CallInst &CI) {
   const StringRef OriginalName = MangledName;
   // Assume there is no custom name <redirection>, and therefore the
   // vector name consists of
@@ -338,9 +379,8 @@ std::optional<VFInfo> VFABI::tryDemangleForVFABI(StringRef MangledName,
     return std::nullopt;
 
   // Parse the variable size, starting from <vlen>.
-  unsigned VF;
-  bool IsScalable;
-  if (tryParseVLEN(MangledName, VF, IsScalable) != ParseRet::OK)
+  std::optional<unsigned> ParsedVF;
+  if (tryParseVLEN(MangledName, ParsedVF) != ParseRet::OK)
     return std::nullopt;
 
   // Parse the <parameters>.
@@ -374,6 +414,24 @@ std::optional<VFInfo> VFABI::tryDemangleForVFABI(StringRef MangledName,
   if (Parameters.empty())
     return std::nullopt;
 
+  // Figure out the number of lanes in vectors for this function variant. This
+  // is easy for fixed length, as the vlen encoding just gives us the value
+  // directly. However, if the vlen mangling indicated that this function
+  // variant expects scalable vectors, then we need to figure out the minimum
+  // based on the widest scalar types in vector arguments.
+  std::optional<ElementCount> EC;
+  if (ParsedVF) {
+    // Fixed length VF
+    EC = ElementCount::getFixed(*ParsedVF);
+  } else {
+    // Scalable VF, need to work out the minimum from the element types
+    // in the scalar function arguments.
+    EC = getScalableECFromSignature(CI.getFunctionType(), ISA, Parameters);
+
+    if (!EC)
+      return std::nullopt;
+  }
+
   // Check for the <scalarname> and the optional <redirection>, which
   // are separated from the prefix with "_"
   if (!MangledName.consume_front("_"))
@@ -426,32 +484,7 @@ std::optional<VFInfo> VFABI::tryDemangleForVFABI(StringRef MangledName,
     assert(Parameters.back().ParamKind == VFParamKind::GlobalPredicate &&
            "The global predicate must be the last parameter");
 
-  // Adjust the VF for scalable signatures. The EC.Min is not encoded
-  // in the name of the function, but it is encoded in the IR
-  // signature of the function. We need to extract this information
-  // because it is needed by the loop vectorizer, which reasons in
-  // terms of VectorizationFactor or ElementCount. In particular, we
-  // need to make sure that the VF field of the VFShape class is never
-  // set to 0.
-  if (IsScalable) {
-    const Function *F = M.getFunction(VectorName);
-    // The declaration of the function must be present in the module
-    // to be able to retrieve its signature.
-    if (!F)
-      return std::nullopt;
-    const ElementCount EC = getECFromSignature(F->getFunctionType());
-    VF = EC.getKnownMinValue();
-  }
-
-  // 1. We don't accept a zero lanes vectorization factor.
-  // 2. We don't accept the demangling if the vector function is not
-  // present in the module.
-  if (VF == 0)
-    return std::nullopt;
-  if (!M.getFunction(VectorName))
-    return std::nullopt;
-
-  const VFShape Shape({ElementCount::get(VF, IsScalable), Parameters});
+  const VFShape Shape({*EC, Parameters});
   return VFInfo({Shape, std::string(ScalarName), std::string(VectorName), ISA});
 }
 
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index 4f5db8b7aaf746f..9103bf3c8b82467 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -1466,15 +1466,12 @@ void VFABI::getVectorVariantNames(
   S.split(ListAttr, ",");
 
   for (const auto &S : SetVector<StringRef>(ListAttr.begin(), ListAttr.end())) {
-#ifndef NDEBUG
-    LLVM_DEBUG(dbgs() << "VFABI: adding mapping '" << S << "'\n");
-    std::optional<VFInfo> Info =
-        VFABI::tryDemangleForVFABI(S, *(CI.getModule()));
-    assert(Info && "Invalid name for a VFABI variant.");
-    assert(CI.getModule()->getFunction(Info->VectorName) &&
-           "Vector function is missing.");
-#endif
-    VariantMappings.push_back(std::string(S));
+    std::optional<VFInfo> Info = VFABI::tryDemangleForVFABI(S, CI);
+    if (Info && CI.getModule()->getFunction(Info->VectorName)) {
+      LLVM_DEBUG(dbgs() << "VFABI: Adding mapping '" << S << "'\n");
+      VariantMappings.push_back(std::string(S));
+    } else
+      LLVM_DEBUG(dbgs() << "VFABI: Invalid mapping '" << S << "'\n");
   }
 }
 
diff --git a/llvm/lib/Transforms/Utils/ModuleUtils.cpp b/llvm/lib/Transforms/Utils/ModuleUtils.cpp
index a3737d428a00b5f..b75e60a3553cf60 100644
--- a/llvm/lib/Transforms/Utils/ModuleUtils.cpp
+++ b/llvm/lib/Transforms/Utils/ModuleUtils.cpp
@@ -346,7 +346,7 @@ void VFABI::setVectorVariantNames(CallInst *CI,
 #ifndef NDEBUG
   for (const std::string &VariantMapping : VariantMappings) {
     LLVM_DEBUG(dbgs() << "VFABI: adding mapping '" << VariantMapping << "'\n");
-    std::optional<VFInfo> VI = VFABI::tryDemangleForVFABI(VariantMapping, *M);
+    std::optional<VFInfo> VI = VFABI::tryDemangleForVFABI(VariantMapping, *CI);
     assert(VI && "Cannot add an invalid VFABI name.");
     assert(M->getNamedValue(VI->VectorName) &&
            "Cannot add variant to attribute: "
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/masked-call.ll b/llvm/test/Transforms/LoopVectorize/AArch64/masked-call.ll
index 19a970fb0716f0f..28962dfba89248a 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/masked-call.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/masked-call.ll
@@ -744,8 +744,8 @@ declare <vscale x 2 x i64> @foo_uniform(i64, <vscale x 2 x i1>)
 declare <vscale x 2 x i64> @foo_vector(<vscale x 2 x i64>, <vscale x 2 x i1>)
 declare <vscale x 2 x i64> @foo_vector_nomask(<vscale x 2 x i64>)
 
-attributes #0 = { nounwind "vector-function-abi-variant"="_ZGV_LLVM_Mxv_foo(foo_vector),_ZGV_LLVM_Mxu_foo(foo_uniform)" }
-attributes #1 = { nounwind "vector-function-abi-variant"="_ZGV_LLVM_Mxv_foo(foo_vector)" }
-attributes #2 = { nounwind "vector-function-abi-variant"="_ZGV_LLVM_Nxv_foo(foo_vector_nomask)" }
-attributes #3 = { nounwind "vector-function-abi-variant"="_ZGV_LLVM_Nxv_foo(foo_vector_nomask),_ZGV_LLVM_Mxv_foo(foo_vector)" }
+attributes #0 = { nounwind "vector-function-abi-variant"="_ZGVsMxv_foo(foo_vector),_ZGVsMxu_foo(foo_uniform)" }
+attributes #1 = { nounwind "vector-function-abi-variant"="_ZGVsMxv_foo(foo_vector)" }
+attributes #2 = { nounwind "vector-function-abi-variant"="_ZGVsNxv_foo(foo_vector_nomask)" }
+attributes #3 = { nounwind "vector-function-abi-variant"="_ZGVsNxv_foo(foo_vector_nomask),_ZGVsMxv_foo(foo_vector)" }
 attributes #4 = { "target-features"="+sve" vscale_range(2,16) "no-trapping-math"="false" }
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/scalable-call.ll b/llvm/test/Transforms/LoopVectorize/AArch64/scalable-call.ll
index 69c6f84aac53dd6..84f67310021bf1d 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/scalable-call.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/scalable-call.ll
@@ -224,9 +224,9 @@ declare <vscale x 2 x i64> @bar_vec(<vscale x 2 x ptr>)
 declare <vscale x 2 x double> @sin_vec_nxv2f64(<vscale x 2 x double>)
 declare <2 x double> @sin_vec_v2f64(<2 x double>)
 
-attributes #0 = { "vector-function-abi-variant"="_ZGV_LLVM_Nxv_foo(foo_vec)" }
-attributes #1 = { "vector-function-abi-variant"="_ZGV_LLVM_Nxv_bar(bar_vec)" }
-attributes #2 = { "vector-function-abi-variant"="_ZGV_LLVM_Nxv_llvm.sin.f64(sin_vec_nxv2f64)" }
+attributes #0 = { "vector-function-abi-variant"="_ZGVsNxv_foo(foo_vec)" }
+attributes #1 = { "vector-function-abi-variant"="_ZGVsNxv_bar(bar_vec)" }
+attributes #2 = { "vector-function-abi-variant"="_ZGVsNxv_llvm.sin.f64(sin_vec_nxv2f64)" }
 attributes #3 = { "vector-function-abi-variant"="_ZGV_LLVM_N2v_llvm.sin.f64(sin_vec_v2f64)" }
 
 !1 = distinct !{!1, !2, !3}
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/sve-vfabi.ll b/llvm/test/Transforms/LoopVectorize/AArch64/sve-vfabi.ll
new file mode 100644
index 000000000000000..31cf2c4e2db4cdf
--- /dev/null
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/sve-vfabi.ll
@@ -0,0 +1,110 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
+; RUN: opt < %s -passes=loop-vectorize,simplifycfg,instcombine -force-vector-interleave=1 -prefer-predicate-over-epilogue=predicate-dont-vectorize -S | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define void @test_big_little_params(ptr readonly %a, ptr readonly %b, ptr noalias %c) #0 {
+; CHECK-LABEL: define void @test_big_little_params
+; CHECK-SAME: (ptr readonly [[A:%.*]], ptr readonly [[B:%.*]], ptr noalias [[C:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_ENTRY:%.*]] = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 0, i64 1025)
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = phi <vscale x 4 x i1> [ [[ACTIVE_LANE_MASK_ENTRY]], [[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr i32, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <vscale x 4 x i32> @llvm.masked.load.nxv4i32.p0(ptr [[TMP0]], i32 4, <vscale x 4 x i1> [[ACTIVE_LANE_MASK]], <vscale x 4 x i32> poison)
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <vscale x 4 x i8> @llvm.masked.load.nxv4i8.p0(ptr [[TMP1]], i32 1, <vscale x 4 x i1> [[ACTIVE_LANE_MASK]], <vscale x 4 x i8> poison)
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 4 x i32> @foo_vector(<vscale x 4 x i32> [[WIDE_MASKED_LOAD]], <vscale x 4 x i8> [[WIDE_MASKED_LOAD1]], <vscale x 4 x i1> [[ACTIVE_LANE_MASK]])
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i32, ptr [[C]], i64 [[INDEX]]
+; CHECK-NEXT:    call void @llvm.masked.store.nxv4i32.p0(<vscale x 4 x i32> [[TMP2]], ptr [[TMP3]], i32 4, <vscale x 4 x i1> [[ACTIVE_LANE_MASK]])
+; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP5:%.*]] = shl nuw nsw i64 [[TMP4]], 2
+; CHECK-NEXT:    [[INDEX_NEXT]] = add i64 [[INDEX]], [[TMP5]]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_NEXT]] = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 [[INDEX_NEXT]], i64 1025)
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <vscale x 4 x i1> [[ACTIVE_LANE_MASK_NEXT]], i64 0
+; CHECK-NEXT:    br i1 [[TMP6]], label [[VECTOR_BODY]], label [[EXIT:%.*]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %for.body
+
+for.body:
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %gep_s = getelementptr i32, ptr %a, i64 %iv
+  %load_s = load i32, ptr %gep_s
+  %gep_b = getelementptr i8, ptr %b, i64 %iv
+  %load_b = load i8, ptr %gep_b
+  %call = call i32 @foo_big_little(i32 %load_s, i8 %load_b) #1
+  %arrayidx = getelementptr inbounds i32, ptr %c, i64 %iv
+  store i32 %call, ptr %arrayidx
+  %iv.next = add nuw nsw i64 %iv, 1
+  %exitcond = icmp eq i64 %iv.next, 1025
+  br i1 %exitcond, label %exit, label %for.body
+
+exit:
+  ret void
+}
+
+define void @test_little_big_params(ptr readonly %a, ptr readonly %b, ptr noalias %c) #0 {
+; CHECK-LABEL: define void @test_little_big_params
+; CHECK-SAME: (ptr readonly [[A:%.*]], ptr readonly [[B:%.*]], ptr noalias [[C:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_ENTRY:%.*]] = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i64(i64 0, i64 1025)
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = phi <vscale x 2 x i1> [ [[ACTIVE_LANE_MASK_ENTRY]], [[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr float, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <vscale x 2 x float> @llvm.masked.load.nxv2f32.p0(ptr [[TMP0]], i32 4, <vscale x 2 x i1> [[ACTIVE_LANE_MASK]], <vscale x 2 x float> poison)
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr double, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <vscale x 2 x double> @llvm.masked.load.nxv2f64.p0(ptr [[TMP1]], i32 8, <vscale x 2 x i1> [[ACTIVE_LANE_MASK]], <vscale x 2 x double> poison)
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 2 x double> @bar_vector(<vscale x 2 x float> [[WIDE_MASKED_LOAD]], <vscale x 2 x double> [[WIDE_MASKED_LOAD1]], <vscale x 2 x i1> [[ACTIVE_LANE_MASK]])
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds double, ptr [[C]], i64 [[INDEX]]
+; CHECK-NEXT:    call void @llvm.masked.store.nxv2f64.p0(<vscale x 2 x double> [[TMP2]], ptr [[TMP3]], i32 8, <vscale x 2 x i1> [[ACTIVE_LANE_MASK]])
+; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP5:%.*]] = shl nuw nsw i64 [[TMP4]], 1
+; CHECK-NEXT:    [[INDEX_NEXT]] = add i64 [[INDEX]], [[TMP5]]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_NEXT]] = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i64(i64 [[INDEX_NEXT]], i64 1025)
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <vscale x 2 x i1> [[ACTIVE_LANE_MASK_NEXT]], i64 0
+; CHECK-NEXT:    br i1 [[TMP6]], label [[VECTOR_BODY]], label [[FOR_COND_CLEANUP:%.*]], !llvm.loop [[LOOP3:![0-9]+]]
+; CHECK:       for.cond.cleanup:
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %for.body
+
+for.body:
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %gep_f = getelementptr float, ptr %a, i64 %iv
+  %load_f = load float, ptr %gep_f
+  %gep_d = getelementptr double, ptr %b, i64 %iv
+  %load_d = load double, ptr %gep_d
+  %call = call double @bar_little_big(float %load_f, double %load_d) #2
+  %arrayidx = getelementptr inbounds double, ptr %c, i64 %iv
+  store double %call, ptr %arrayidx
+  %iv.next = add nuw nsw i64 %iv, 1
+  %exitcond = icmp eq i64 %iv.next, 1025
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup:
+  ret void
+}
+
+;; TODO: Test uniform and linear parameters when they are properly supported,
+;;       especially a variant with no vector parameters so the return type
+;;       must be used to find the VF.
+
+;; Scalar functions
+declare i32 @foo_big_little(i32, i8)
+declare double @bar_little_big(float, double)
+
+;; Vector function variants
+declare <vscale x 4 x i32> @foo_vector(<vscale x 4 x i32>, <vscale x 4 x i8>, <vscale x 4 x i1>)
+declare <vscale x 2 x double> @bar_vector(<vscale x 2 x float>, <vscale x 2 x double>, <vscale x 2 x i1>)
+
+attributes #0 = { "target-features"="+sve" vscale_range(1,16) }
+attributes #1 = { nounwind "vector-function-abi-variant"="_ZGVsMxvv_foo_big_little(foo_vector)" }
+attributes #2 = { nounwind "vector-function-abi-variant"="_ZGVsMxvv_bar_little_big(bar_vector)" }
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/wider-VF-for-callinst.ll b/llvm/test/Transforms/LoopVectorize/AArch64/wider-VF-for-callinst.ll
index 00e8881426fd8f9..7d095206c6062b4 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/wider-VF-for-callinst.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/wider-VF-for-callinst.ll
@@ -111,5 +111,5 @@ for.cond.cleanup:
 declare float @foo(float)
 declare <vscale x 4 x float> @foo_vector(<vscale x 4 x float>, <vscale x 4 x i1>)
 
-attributes #0 = { nounwind "vector-function-abi-variant"="_ZGV_LLVM_Mxv_foo(foo_vector)" }
+attributes #0 = { nounwind "vector-function-abi-variant"="_ZGVsMxv_foo(foo_vector)" }
 attributes #1 = { "target-features"="+sve" vscale_range(1,16) "no-trapping-math"="false" }
diff --git a/llvm/tools/vfabi-demangle-fuzzer/vfabi-demangler-fuzzer.cpp b/llvm/tools/vfabi-demangle-fuzzer/vfabi-demangler-fuzzer.cpp
index b0b80131bf48f48..09dc15c9e36667b 100644
--- a/llvm/tools/vfabi-demangle-fuzzer/vfabi-demangler-fuzzer.cpp
+++ b/llvm/tools/vfabi-demangle-fuzzer/vfabi-demangler-fuzzer.cpp
@@ -27,15 +27,23 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *Data, size_t Size) {
   // present. We need to make sure we can even invoke
   // `getOrInsertFunction` because such method asserts on strings with
   // zeroes.
-  if (!MangledName.empty() && MangledName.find_first_of(0) == StringRef::npos)
-    M->getOrInsertFunction(
-        MangledName,
-        FunctionType::get(Type::getVoidTy(M->getContext()), false));
-  const auto Info = VFABI::tryDemangleForVFABI(MangledName, *M);
+  // TODO: What is this actually testing? That we don't crash?
+  if (!MangledName.empty() && MangledName.find_first_of(0) == StringRef::npos) {
+    FunctionType *FTy =
+        FunctionType::get(Type::getVoidTy(M->getContext()), false);
+    FunctionCallee F = M->getOrInsertFunction(MangledName, FTy);
+    // Fake the arguments to the CallInst.
+    SmallVector<Value *> Args;
+    for (Type *ParamTy : FTy->params()) {
+      Args.push_back(Constant::getNullValue(ParamTy));
+    }
+    std::unique_ptr<CallInst> CI(CallInst::Create(F, Args));
+    const auto Info = VFABI::tryDemangleForVFABI(MangledName, *(CI.get()));
 
-  // Do not optimize away the return value. Inspired by
-  // https://github.com/google/benchmark/blob/main/include/benchmark/benchmark.h#L307-L345
-  asm volatile("" : : "r,m"(Info) : "memory");
+    // Do not optimize away the return value. Inspired by
+    // https://github.com/google/benchmark/blob/main/include/benchmark/benchmark.h#L307-L345
+    asm volatile("" : : "r,m"(Info) : "memory");
+  }
 
   return 0;
 }
diff --git a/llvm/unittests/Analysis/VectorFunctionABITest.cpp b/llvm/unittests/Analysis/VectorFunctionABITest.cpp
index a4c6b2143fc662c..b233124ebfa6ed1 100644
--- a/llvm/unittests/Analysis/VectorFunctionABITest.cpp
+++ b/llvm/unittests/Analysis/VectorFunctionABITest.cpp
@@ -25,11 +25,11 @@ class VFABIParserTest : public ::testing::Test {
     EXPECT_NE(M.get(), nullptr) << "Loading an invalid module.\n "
                                 << Err.getMessage() << "\n";
     Type *Ty = parseType(IRType, Err, *(M.get()));
-    FunctionType *FTy = dyn_cast<FunctionType>(Ty);
+    FTy = dyn_cast<FunctionType>(Ty);
     EXPECT_NE(FTy, nullptr) << "Invalid function type string: " << IRType
                             << "\n"
                             << Err.getMessage() << "\n";
-    FunctionCallee F = M->getOrInsertFunction(Name, FTy);
+    F = M->getOrInsertFunction(Name, FTy);
     EXPECT_NE(F.getCallee(), nullptr)
         << "The function must be present in the module\n";
     // Reset the VFInfo
@@ -40,7 +40,9 @@ class VFABIParserTest : public ::testing::Test {
   LLVMContext Ctx;
   SMDiagnostic Err;
   std::unique_ptr<Module> M;
-  //  CallInst *CI;
+  FunctionType *FTy;
+  FunctionCallee F;
+
 protected:
   // Referencies to the parser output field.
   ElementCount &VF = Info.Shape.VF;
@@ -65,16 +67,22 @@ class VFABIParserTest : public ::testing::Test {
   // generic fixed-length case can use as signature `void()`.
   //
   bool invokeParser(const StringRef MangledName,
-                    const StringRef VectorName = "",
+                    const StringRef ScalarName = "",
                     const StringRef IRType = "void()") {
     StringRef Name = MangledName;
-    if (!VectorName.empty())
-      Name = VectorName;
+    if (!ScalarName.empty())
+      Name = ScalarName;
     // Reset the VFInfo and the Module to be able to invoke
     // `invokeParser` multiple times in the same test.
     reset(Name, IRType);
 
-    const auto OptInfo = VFABI::tryDemangleForVFABI(MangledName, *(M.get()));
+    // Fake a call to the scalar function; args must match its param types.
+    SmallVector<Value *> Args;
+    for (Type *ParamTy : FTy->params()) {
+      Args.push_back(Constant::getNullValue(ParamTy));
+    }
+    std::unique_ptr<CallInst> CI(CallInst::Create(F, Args));
+    const auto OptInfo = VFABI::tryDemangleForVFABI(MangledName, *(CI.get()));
     if (OptInfo) {
       Info = *OptInfo;
       return true;
@@ -204,10 +212,8 @@ TEST_F(VFABIParserTest, LinearWithCompileTimeNegativeStep) {
 }
 
 TEST_F(VFABIParserTest, ParseScalableSVE) {
-  EXPECT_TRUE(invokeParser(
-      "_ZGVsMxv_sin(custom_vg)", "custom_vg",
-      "<vscale x 2 x i32>(<vscale x 2 x i32>, <vscale x 2 x i1>)"));
-  EXPECT_EQ(VF, ElementCount::getScalable(2));
+  EXPECT_TRUE(invokeParser("_ZGVsMxv_sin(custom_vg)", "sin", "i32(i32)"));
+  EXPECT_EQ(VF, ElementCount::getScalable(4));
   EXPECT_TRUE(IsMasked());
   EXPECT_EQ(ISA, VFISAKind::SVE);
   EXPECT_EQ(ScalarName, "sin");
@@ -495,12 +501,16 @@ TEST_F(VFABIParserTest, ParseMaskingLLVM) {
 }
 
 TEST_F(VFABIParserTest, ParseScalableMaskingLLVM) {
-  EXPECT_TRUE(invokeParser(
-      "_ZGV_LLVM_Mxv_sin(custom_vector_sin)", "custom_vector_sin",
-      "<vscale x 2 x i32> (<vscale x 2 x i32>, <vscale x 2 x i1>)"));
+  EXPECT_FALSE(
+      invokeParser("_ZGV_LLVM_Mxv_sin(custom_vector_sin)", "sin", "i32(i32)"));
+}
+
+TEST_F(VFABIParserTest, ParseScalableMaskingSVE) {
+  EXPECT_TRUE(
+      invokeParser("_ZGVsMxv_sin(custom_vector_sin)", "sin", "i32(i32)"));
   EXPECT_TRUE(IsMasked());
-  EXPECT_EQ(VF, ElementCount::getScalable(2));
-  EXPECT_EQ(ISA, VFISAKind::LLVM);
+  EXPECT_EQ(VF, ElementCount::getScalable(4));
+  EXPECT_EQ(ISA, VFISAKind::SVE);
   EXPECT_EQ(Parameters.size(), (unsigned)2);
   EXPECT_EQ(Parameters[0], VFParameter({0, VFParamKind::Vector}));
   EXPECT_EQ(Parameters[1], VFParameter({1, VFParamKind::GlobalPredicate}));
@@ -508,13 +518,12 @@ TEST_F(VFABIParserTest, ParseScalableMaskingLLVM) {
   EXPECT_EQ(VectorName, "custom_vector_sin");
 }
 
-TEST_F(VFABIParserTest, ParseScalableMaskingLLVMSincos) {
-  EXPECT_TRUE(invokeParser("_ZGV_LLVM_Mxvl8l8_sincos(custom_vector_sincos)",
-                           "custom_vector_sincos",
-                           "void(<vscale x 2 x double>, double *, double *)"));
+TEST_F(VFABIParserTest, ParseScalableMaskingSVESincos) {
+  EXPECT_TRUE(invokeParser("_ZGVsMxvl8l8_sincos(custom_vector_sincos)",
+                           "sincos", "void(double, double *, double *)"));
   EXPECT_EQ(VF, ElementCount::getScalable(2));
   EXPECT_TRUE(IsMasked());
-  EXPECT_EQ(ISA, VFISAKind::LLVM);
+  EXPECT_EQ(ISA, VFISAKind::SVE);
   EXPECT_EQ(Parameters.size(), (unsigned)4);
   EXPECT_EQ(Parameters[0], VFParameter({0, VFParamKind::Vector}));
   EXPECT_EQ(Parameters[1], VFParameter({1, VFParamKind::OMP_Linear, 8}));
@@ -581,9 +590,7 @@ TEST_F(VFABIParserTest, ParseScalableRequiresDeclaration) {
   // The parser succeds only when the correct function definition of
   // `custom_vg` is added to the module.
   EXPECT_FALSE(invokeParser(MangledName));
-  EXPECT_TRUE(invokeParser(
-      MangledName, "custom_vg",
-      "<vscale x 4 x double>(<vscale x 4 x double>, <vscale x 4 x i1>)"));
+  EXPECT_TRUE(invokeParser(MangledName, "sin", "double(double)"));
 }
 
 TEST_F(VFABIParserTest, ZeroIsInvalidVLEN) {



More information about the llvm-commits mailing list