[llvm] [NFC] Replace CallInst with FunctionType in VFABI, VFShape API (PR #74569)

Paschalis Mpeis via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 5 23:30:46 PST 2023


https://github.com/paschalis-mpeis created https://github.com/llvm/llvm-project/pull/74569

This minor simplification replaces the `CallInst` parameter with a `FunctionType` in the VFShape::getScalarShape, VFShape::get, and VFABI::tryDemangleForVFABI methods, and updates their callers accordingly.

>From 033ab19409d69ff4ea6257ae796b9029b14fb9c5 Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Wed, 6 Dec 2023 07:13:50 +0000
Subject: [PATCH] [NFC] Replace CallInst with FunctionType in VFABI, VFShape
 API

This minor simplification was applied to VFShape::getScalarShape,
VFShape::get, and VFABI::tryDemangleForVFABI methods.
---
 llvm/include/llvm/Analysis/VectorUtils.h      | 30 +++++++++----------
 llvm/lib/Analysis/VFABIDemangling.cpp         |  4 +--
 llvm/lib/Analysis/VectorUtils.cpp             |  3 +-
 llvm/lib/Transforms/Utils/ModuleUtils.cpp     |  3 +-
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 15 ++++++----
 .../vfabi-demangler-fuzzer.cpp                |  9 +-----
 .../Analysis/VectorFunctionABITest.cpp        |  3 +-
 llvm/unittests/Analysis/VectorUtilsTest.cpp   |  6 ++--
 8 files changed, 36 insertions(+), 37 deletions(-)

diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index d54b63fd4f532..55a6aa645a86e 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -94,24 +94,24 @@ struct VFShape {
     assert(hasValidParameterList() && "Invalid parameter list");
   }
 
-  // Retrieve the VFShape that can be used to map a (scalar) function to itself,
-  // with VF = 1.
-  static VFShape getScalarShape(const CallInst &CI) {
-    return VFShape::get(CI, ElementCount::getFixed(1),
+  /// Retrieve the VFShape that can be used to map a scalar function to itself,
+  /// with VF = 1.
+  static VFShape getScalarShape(const FunctionType *FTy) {
+    return VFShape::get(FTy, ElementCount::getFixed(1),
                         /*HasGlobalPredicate*/ false);
   }
 
-  // Retrieve the basic vectorization shape of the function, where all
-  // parameters are mapped to VFParamKind::Vector with \p EC
-  // lanes. Specifies whether the function has a Global Predicate
-  // argument via \p HasGlobalPred.
-  static VFShape get(const CallInst &CI, ElementCount EC, bool HasGlobalPred) {
+  /// Retrieve the basic vectorization shape of the function, where all
+  /// parameters are mapped to VFParamKind::Vector with \p EC lanes. Specifies
+  /// whether the function has a Global Predicate argument via \p HasGlobalPred.
+  static VFShape get(const FunctionType *FTy, ElementCount EC,
+                     bool HasGlobalPred) {
     SmallVector<VFParameter, 8> Parameters;
-    for (unsigned I = 0; I < CI.arg_size(); ++I)
+    for (unsigned I = 0; I < FTy->getNumParams(); ++I)
       Parameters.push_back(VFParameter({I, VFParamKind::Vector}));
     if (HasGlobalPred)
       Parameters.push_back(
-          VFParameter({CI.arg_size(), VFParamKind::GlobalPredicate}));
+          VFParameter({FTy->getNumParams(), VFParamKind::GlobalPredicate}));
 
     return {EC, Parameters};
   }
@@ -174,13 +174,13 @@ static constexpr char const *_LLVM_Scalarize_ = "_LLVM_Scalarize_";
 ///
 /// \param MangledName -> input string in the format
 /// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)].
-/// \param CI -> A call to the scalar function which we're trying to find
+/// \param FTy -> FunctionType of the scalar function which we're trying to find
 /// a vectorized variant for. This is required to determine the vectorization
 /// factor for scalable vectors, since the mangled name doesn't encode that;
 /// it needs to be derived from the widest element types of vector arguments
 /// or return values.
 std::optional<VFInfo> tryDemangleForVFABI(StringRef MangledName,
-                                          const CallInst &CI);
+                                          const FunctionType *FTy);
 
 /// Retrieve the `VFParamKind` from a string token.
 VFParamKind getVFParamKindFromString(const StringRef Token);
@@ -227,7 +227,7 @@ class VFDatabase {
       return;
     for (const auto &MangledName : ListOfStrings) {
       const std::optional<VFInfo> Shape =
-          VFABI::tryDemangleForVFABI(MangledName, CI);
+          VFABI::tryDemangleForVFABI(MangledName, CI.getFunctionType());
       // A match is found via scalar and vector names, and also by
       // ensuring that the variant described in the attribute has a
       // corresponding definition or declaration of the vector
@@ -276,7 +276,7 @@ class VFDatabase {
   /// @{
   /// Retrieve the Function with VFShape \p Shape.
   Function *getVectorizedFunction(const VFShape &Shape) const {
-    if (Shape == VFShape::getScalarShape(CI))
+    if (Shape == VFShape::getScalarShape(CI.getFunctionType()))
       return CI.getCalledFunction();
 
     for (const auto &Info : ScalarToVectorMappings)
diff --git a/llvm/lib/Analysis/VFABIDemangling.cpp b/llvm/lib/Analysis/VFABIDemangling.cpp
index 88f61cfeb9ba4..92af314a41caa 100644
--- a/llvm/lib/Analysis/VFABIDemangling.cpp
+++ b/llvm/lib/Analysis/VFABIDemangling.cpp
@@ -369,7 +369,7 @@ getScalableECFromSignature(const FunctionType *Signature, const VFISAKind ISA,
 // Format of the ABI name:
 // _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)]
 std::optional<VFInfo> VFABI::tryDemangleForVFABI(StringRef MangledName,
-                                                 const CallInst &CI) {
+                                                 const FunctionType *FTy) {
   const StringRef OriginalName = MangledName;
   // Assume there is no custom name <redirection>, and therefore the
   // vector name consists of
@@ -434,7 +434,7 @@ std::optional<VFInfo> VFABI::tryDemangleForVFABI(StringRef MangledName,
   // demangled parameter types and the scalar function signature.
   std::optional<ElementCount> EC;
   if (ParsedVF.second) {
-    EC = getScalableECFromSignature(CI.getFunctionType(), ISA, Parameters);
+    EC = getScalableECFromSignature(FTy, ISA, Parameters);
     if (!EC)
       return std::nullopt;
   } else
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index 96f39ff7e409e..91d8c31fa062d 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -1466,7 +1466,8 @@ void VFABI::getVectorVariantNames(
   S.split(ListAttr, ",");
 
   for (const auto &S : SetVector<StringRef>(ListAttr.begin(), ListAttr.end())) {
-    std::optional<VFInfo> Info = VFABI::tryDemangleForVFABI(S, CI);
+    std::optional<VFInfo> Info =
+        VFABI::tryDemangleForVFABI(S, CI.getFunctionType());
     if (Info && CI.getModule()->getFunction(Info->VectorName)) {
       LLVM_DEBUG(dbgs() << "VFABI: Adding mapping '" << S << "' for " << CI
                         << "\n");
diff --git a/llvm/lib/Transforms/Utils/ModuleUtils.cpp b/llvm/lib/Transforms/Utils/ModuleUtils.cpp
index fc42df75875e1..7de0959ca57ef 100644
--- a/llvm/lib/Transforms/Utils/ModuleUtils.cpp
+++ b/llvm/lib/Transforms/Utils/ModuleUtils.cpp
@@ -346,7 +346,8 @@ void VFABI::setVectorVariantNames(CallInst *CI,
 #ifndef NDEBUG
   for (const std::string &VariantMapping : VariantMappings) {
     LLVM_DEBUG(dbgs() << "VFABI: adding mapping '" << VariantMapping << "'\n");
-    std::optional<VFInfo> VI = VFABI::tryDemangleForVFABI(VariantMapping, *CI);
+    std::optional<VFInfo> VI =
+        VFABI::tryDemangleForVFABI(VariantMapping, CI->getFunctionType());
     assert(VI && "Cannot add an invalid VFABI name.");
     assert(M->getNamedValue(VI->VectorName) &&
            "Cannot add variant to attribute: "
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index be48465b8e0e4..21c83fcdac193 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -5466,7 +5466,8 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
     Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
 
     VFShape Shape = VFShape::get(
-        *CI, ElementCount::getFixed(static_cast<unsigned int>(VL.size())),
+        CI->getFunctionType(),
+        ElementCount::getFixed(static_cast<unsigned int>(VL.size())),
         false /*HasGlobalPred*/);
     Function *VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape);
 
@@ -6461,9 +6462,10 @@ getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy,
   auto IntrinsicCost =
     TTI->getIntrinsicInstrCost(CostAttrs, TTI::TCK_RecipThroughput);
 
-  auto Shape = VFShape::get(*CI, ElementCount::getFixed(static_cast<unsigned>(
-                                     VecTy->getNumElements())),
-                            false /*HasGlobalPred*/);
+  auto Shape = VFShape::get(
+      CI->getFunctionType(),
+      ElementCount::getFixed(static_cast<unsigned>(VecTy->getNumElements())),
+      false /*HasGlobalPred*/);
   Function *VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape);
   auto LibCost = IntrinsicCost;
   if (!CI->isNoBuiltin() && VecFunc) {
@@ -11643,8 +11645,9 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       Function *CF;
       if (!UseIntrinsic) {
         VFShape Shape =
-            VFShape::get(*CI, ElementCount::getFixed(static_cast<unsigned>(
-                                  VecTy->getNumElements())),
+            VFShape::get(CI->getFunctionType(),
+                         ElementCount::getFixed(
+                             static_cast<unsigned>(VecTy->getNumElements())),
                          false /*HasGlobalPred*/);
         CF = VFDatabase(*CI).getVectorizedFunction(Shape);
       } else {
diff --git a/llvm/tools/vfabi-demangle-fuzzer/vfabi-demangler-fuzzer.cpp b/llvm/tools/vfabi-demangle-fuzzer/vfabi-demangler-fuzzer.cpp
index 09dc15c9e3666..a6ca3bc4484e3 100644
--- a/llvm/tools/vfabi-demangle-fuzzer/vfabi-demangler-fuzzer.cpp
+++ b/llvm/tools/vfabi-demangle-fuzzer/vfabi-demangler-fuzzer.cpp
@@ -31,14 +31,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *Data, size_t Size) {
   if (!MangledName.empty() && MangledName.find_first_of(0) == StringRef::npos) {
     FunctionType *FTy =
         FunctionType::get(Type::getVoidTy(M->getContext()), false);
-    FunctionCallee F = M->getOrInsertFunction(MangledName, FTy);
-    // Fake the arguments to the CallInst.
-    SmallVector<Value *> Args;
-    for (Type *ParamTy : FTy->params()) {
-      Args.push_back(Constant::getNullValue(ParamTy));
-    }
-    std::unique_ptr<CallInst> CI(CallInst::Create(F, Args));
-    const auto Info = VFABI::tryDemangleForVFABI(MangledName, *(CI.get()));
+    const auto Info = VFABI::tryDemangleForVFABI(MangledName, FTy);
 
     // Do not optimize away the return value. Inspired by
     // https://github.com/google/benchmark/blob/main/include/benchmark/benchmark.h#L307-L345
diff --git a/llvm/unittests/Analysis/VectorFunctionABITest.cpp b/llvm/unittests/Analysis/VectorFunctionABITest.cpp
index e496d87c06de6..81c1807cdcaa8 100644
--- a/llvm/unittests/Analysis/VectorFunctionABITest.cpp
+++ b/llvm/unittests/Analysis/VectorFunctionABITest.cpp
@@ -82,7 +82,8 @@ class VFABIParserTest : public ::testing::Test {
       Args.push_back(Constant::getNullValue(ParamTy->getScalarType()));
     }
     std::unique_ptr<CallInst> CI(CallInst::Create(F, Args));
-    const auto OptInfo = VFABI::tryDemangleForVFABI(MangledName, *(CI.get()));
+    const auto OptInfo =
+        VFABI::tryDemangleForVFABI(MangledName, CI->getFunctionType());
     if (OptInfo) {
       Info = *OptInfo;
       return true;
diff --git a/llvm/unittests/Analysis/VectorUtilsTest.cpp b/llvm/unittests/Analysis/VectorUtilsTest.cpp
index c7419e0321235..1b3a8b0259f01 100644
--- a/llvm/unittests/Analysis/VectorUtilsTest.cpp
+++ b/llvm/unittests/Analysis/VectorUtilsTest.cpp
@@ -580,7 +580,7 @@ class VFShapeAPITest : public testing::Test {
   SmallVector<VFParameter, 8> &ExpectedParams = Expected.Parameters;
 
   void buildShape(ElementCount VF, bool HasGlobalPred) {
-    Shape = VFShape::get(*CI, VF, HasGlobalPred);
+    Shape = VFShape::get(CI->getFunctionType(), VF, HasGlobalPred);
   }
 
   bool validParams(ArrayRef<VFParameter> Parameters) {
@@ -619,11 +619,11 @@ TEST_F(VFShapeAPITest, API_buildVFShape) {
 
 TEST_F(VFShapeAPITest, API_getScalarShape) {
   buildShape(/*VF*/ ElementCount::getFixed(1), /*HasGlobalPred*/ false);
-  EXPECT_EQ(VFShape::getScalarShape(*CI), Shape);
+  EXPECT_EQ(VFShape::getScalarShape(CI->getFunctionType()), Shape);
 }
 
 TEST_F(VFShapeAPITest, API_getVectorizedFunction) {
-  VFShape ScalarShape = VFShape::getScalarShape(*CI);
+  VFShape ScalarShape = VFShape::getScalarShape(CI->getFunctionType());
   EXPECT_EQ(VFDatabase(*CI).getVectorizedFunction(ScalarShape),
             M->getFunction("g"));
 



More information about the llvm-commits mailing list