[llvm] [NFC] Replace CallInst with FunctionType in VFABI, VFShape API (PR #74569)
Paschalis Mpeis via llvm-commits
llvm-commits at lists.llvm.org
Tue Dec 5 23:30:46 PST 2023
https://github.com/paschalis-mpeis created https://github.com/llvm/llvm-project/pull/74569
This minor simplification was applied to the VFShape::getScalarShape, VFShape::get, and VFABI::tryDemangleForVFABI methods.
>From 033ab19409d69ff4ea6257ae796b9029b14fb9c5 Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Wed, 6 Dec 2023 07:13:50 +0000
Subject: [PATCH] [NFC] Replace CallInst with FunctionType in VFABI, VFShape
API
This minor simplification was applied to the VFShape::getScalarShape,
VFShape::get, and VFABI::tryDemangleForVFABI methods.
---
llvm/include/llvm/Analysis/VectorUtils.h | 30 +++++++++----------
llvm/lib/Analysis/VFABIDemangling.cpp | 4 +--
llvm/lib/Analysis/VectorUtils.cpp | 3 +-
llvm/lib/Transforms/Utils/ModuleUtils.cpp | 3 +-
.../Transforms/Vectorize/SLPVectorizer.cpp | 15 ++++++----
.../vfabi-demangler-fuzzer.cpp | 9 +-----
.../Analysis/VectorFunctionABITest.cpp | 3 +-
llvm/unittests/Analysis/VectorUtilsTest.cpp | 6 ++--
8 files changed, 36 insertions(+), 37 deletions(-)
diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index d54b63fd4f532..55a6aa645a86e 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -94,24 +94,24 @@ struct VFShape {
assert(hasValidParameterList() && "Invalid parameter list");
}
- // Retrieve the VFShape that can be used to map a (scalar) function to itself,
- // with VF = 1.
- static VFShape getScalarShape(const CallInst &CI) {
- return VFShape::get(CI, ElementCount::getFixed(1),
+ /// Retrieve the VFShape that can be used to map a scalar function to itself,
+ /// with VF = 1.
+ static VFShape getScalarShape(const FunctionType *FTy) {
+ return VFShape::get(FTy, ElementCount::getFixed(1),
/*HasGlobalPredicate*/ false);
}
- // Retrieve the basic vectorization shape of the function, where all
- // parameters are mapped to VFParamKind::Vector with \p EC
- // lanes. Specifies whether the function has a Global Predicate
- // argument via \p HasGlobalPred.
- static VFShape get(const CallInst &CI, ElementCount EC, bool HasGlobalPred) {
+ /// Retrieve the basic vectorization shape of the function, where all
+ /// parameters are mapped to VFParamKind::Vector with \p EC lanes. Specifies
+ /// whether the function has a Global Predicate argument via \p HasGlobalPred.
+ static VFShape get(const FunctionType *FTy, ElementCount EC,
+ bool HasGlobalPred) {
SmallVector<VFParameter, 8> Parameters;
- for (unsigned I = 0; I < CI.arg_size(); ++I)
+ for (unsigned I = 0; I < FTy->getNumParams(); ++I)
Parameters.push_back(VFParameter({I, VFParamKind::Vector}));
if (HasGlobalPred)
Parameters.push_back(
- VFParameter({CI.arg_size(), VFParamKind::GlobalPredicate}));
+ VFParameter({FTy->getNumParams(), VFParamKind::GlobalPredicate}));
return {EC, Parameters};
}
@@ -174,13 +174,13 @@ static constexpr char const *_LLVM_Scalarize_ = "_LLVM_Scalarize_";
///
/// \param MangledName -> input string in the format
/// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)].
-/// \param CI -> A call to the scalar function which we're trying to find
+/// \param FTy -> FunctionType of the scalar function which we're trying to find
/// a vectorized variant for. This is required to determine the vectorization
/// factor for scalable vectors, since the mangled name doesn't encode that;
/// it needs to be derived from the widest element types of vector arguments
/// or return values.
std::optional<VFInfo> tryDemangleForVFABI(StringRef MangledName,
- const CallInst &CI);
+ const FunctionType *FTy);
/// Retrieve the `VFParamKind` from a string token.
VFParamKind getVFParamKindFromString(const StringRef Token);
@@ -227,7 +227,7 @@ class VFDatabase {
return;
for (const auto &MangledName : ListOfStrings) {
const std::optional<VFInfo> Shape =
- VFABI::tryDemangleForVFABI(MangledName, CI);
+ VFABI::tryDemangleForVFABI(MangledName, CI.getFunctionType());
// A match is found via scalar and vector names, and also by
// ensuring that the variant described in the attribute has a
// corresponding definition or declaration of the vector
@@ -276,7 +276,7 @@ class VFDatabase {
/// @{
/// Retrieve the Function with VFShape \p Shape.
Function *getVectorizedFunction(const VFShape &Shape) const {
- if (Shape == VFShape::getScalarShape(CI))
+ if (Shape == VFShape::getScalarShape(CI.getFunctionType()))
return CI.getCalledFunction();
for (const auto &Info : ScalarToVectorMappings)
diff --git a/llvm/lib/Analysis/VFABIDemangling.cpp b/llvm/lib/Analysis/VFABIDemangling.cpp
index 88f61cfeb9ba4..92af314a41caa 100644
--- a/llvm/lib/Analysis/VFABIDemangling.cpp
+++ b/llvm/lib/Analysis/VFABIDemangling.cpp
@@ -369,7 +369,7 @@ getScalableECFromSignature(const FunctionType *Signature, const VFISAKind ISA,
// Format of the ABI name:
// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)]
std::optional<VFInfo> VFABI::tryDemangleForVFABI(StringRef MangledName,
- const CallInst &CI) {
+ const FunctionType *FTy) {
const StringRef OriginalName = MangledName;
// Assume there is no custom name <redirection>, and therefore the
// vector name consists of
@@ -434,7 +434,7 @@ std::optional<VFInfo> VFABI::tryDemangleForVFABI(StringRef MangledName,
// demangled parameter types and the scalar function signature.
std::optional<ElementCount> EC;
if (ParsedVF.second) {
- EC = getScalableECFromSignature(CI.getFunctionType(), ISA, Parameters);
+ EC = getScalableECFromSignature(FTy, ISA, Parameters);
if (!EC)
return std::nullopt;
} else
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index 96f39ff7e409e..91d8c31fa062d 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -1466,7 +1466,8 @@ void VFABI::getVectorVariantNames(
S.split(ListAttr, ",");
for (const auto &S : SetVector<StringRef>(ListAttr.begin(), ListAttr.end())) {
- std::optional<VFInfo> Info = VFABI::tryDemangleForVFABI(S, CI);
+ std::optional<VFInfo> Info =
+ VFABI::tryDemangleForVFABI(S, CI.getFunctionType());
if (Info && CI.getModule()->getFunction(Info->VectorName)) {
LLVM_DEBUG(dbgs() << "VFABI: Adding mapping '" << S << "' for " << CI
<< "\n");
diff --git a/llvm/lib/Transforms/Utils/ModuleUtils.cpp b/llvm/lib/Transforms/Utils/ModuleUtils.cpp
index fc42df75875e1..7de0959ca57ef 100644
--- a/llvm/lib/Transforms/Utils/ModuleUtils.cpp
+++ b/llvm/lib/Transforms/Utils/ModuleUtils.cpp
@@ -346,7 +346,8 @@ void VFABI::setVectorVariantNames(CallInst *CI,
#ifndef NDEBUG
for (const std::string &VariantMapping : VariantMappings) {
LLVM_DEBUG(dbgs() << "VFABI: adding mapping '" << VariantMapping << "'\n");
- std::optional<VFInfo> VI = VFABI::tryDemangleForVFABI(VariantMapping, *CI);
+ std::optional<VFInfo> VI =
+ VFABI::tryDemangleForVFABI(VariantMapping, CI->getFunctionType());
assert(VI && "Cannot add an invalid VFABI name.");
assert(M->getNamedValue(VI->VectorName) &&
"Cannot add variant to attribute: "
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index be48465b8e0e4..21c83fcdac193 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -5466,7 +5466,8 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
VFShape Shape = VFShape::get(
- *CI, ElementCount::getFixed(static_cast<unsigned int>(VL.size())),
+ CI->getFunctionType(),
+ ElementCount::getFixed(static_cast<unsigned int>(VL.size())),
false /*HasGlobalPred*/);
Function *VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape);
@@ -6461,9 +6462,10 @@ getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy,
auto IntrinsicCost =
TTI->getIntrinsicInstrCost(CostAttrs, TTI::TCK_RecipThroughput);
- auto Shape = VFShape::get(*CI, ElementCount::getFixed(static_cast<unsigned>(
- VecTy->getNumElements())),
- false /*HasGlobalPred*/);
+ auto Shape = VFShape::get(
+ CI->getFunctionType(),
+ ElementCount::getFixed(static_cast<unsigned>(VecTy->getNumElements())),
+ false /*HasGlobalPred*/);
Function *VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape);
auto LibCost = IntrinsicCost;
if (!CI->isNoBuiltin() && VecFunc) {
@@ -11643,8 +11645,9 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
Function *CF;
if (!UseIntrinsic) {
VFShape Shape =
- VFShape::get(*CI, ElementCount::getFixed(static_cast<unsigned>(
- VecTy->getNumElements())),
+ VFShape::get(CI->getFunctionType(),
+ ElementCount::getFixed(
+ static_cast<unsigned>(VecTy->getNumElements())),
false /*HasGlobalPred*/);
CF = VFDatabase(*CI).getVectorizedFunction(Shape);
} else {
diff --git a/llvm/tools/vfabi-demangle-fuzzer/vfabi-demangler-fuzzer.cpp b/llvm/tools/vfabi-demangle-fuzzer/vfabi-demangler-fuzzer.cpp
index 09dc15c9e3666..a6ca3bc4484e3 100644
--- a/llvm/tools/vfabi-demangle-fuzzer/vfabi-demangler-fuzzer.cpp
+++ b/llvm/tools/vfabi-demangle-fuzzer/vfabi-demangler-fuzzer.cpp
@@ -31,14 +31,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *Data, size_t Size) {
if (!MangledName.empty() && MangledName.find_first_of(0) == StringRef::npos) {
FunctionType *FTy =
FunctionType::get(Type::getVoidTy(M->getContext()), false);
- FunctionCallee F = M->getOrInsertFunction(MangledName, FTy);
- // Fake the arguments to the CallInst.
- SmallVector<Value *> Args;
- for (Type *ParamTy : FTy->params()) {
- Args.push_back(Constant::getNullValue(ParamTy));
- }
- std::unique_ptr<CallInst> CI(CallInst::Create(F, Args));
- const auto Info = VFABI::tryDemangleForVFABI(MangledName, *(CI.get()));
+ const auto Info = VFABI::tryDemangleForVFABI(MangledName, FTy);
// Do not optimize away the return value. Inspired by
// https://github.com/google/benchmark/blob/main/include/benchmark/benchmark.h#L307-L345
diff --git a/llvm/unittests/Analysis/VectorFunctionABITest.cpp b/llvm/unittests/Analysis/VectorFunctionABITest.cpp
index e496d87c06de6..81c1807cdcaa8 100644
--- a/llvm/unittests/Analysis/VectorFunctionABITest.cpp
+++ b/llvm/unittests/Analysis/VectorFunctionABITest.cpp
@@ -82,7 +82,8 @@ class VFABIParserTest : public ::testing::Test {
Args.push_back(Constant::getNullValue(ParamTy->getScalarType()));
}
std::unique_ptr<CallInst> CI(CallInst::Create(F, Args));
- const auto OptInfo = VFABI::tryDemangleForVFABI(MangledName, *(CI.get()));
+ const auto OptInfo =
+ VFABI::tryDemangleForVFABI(MangledName, CI->getFunctionType());
if (OptInfo) {
Info = *OptInfo;
return true;
diff --git a/llvm/unittests/Analysis/VectorUtilsTest.cpp b/llvm/unittests/Analysis/VectorUtilsTest.cpp
index c7419e0321235..1b3a8b0259f01 100644
--- a/llvm/unittests/Analysis/VectorUtilsTest.cpp
+++ b/llvm/unittests/Analysis/VectorUtilsTest.cpp
@@ -580,7 +580,7 @@ class VFShapeAPITest : public testing::Test {
SmallVector<VFParameter, 8> &ExpectedParams = Expected.Parameters;
void buildShape(ElementCount VF, bool HasGlobalPred) {
- Shape = VFShape::get(*CI, VF, HasGlobalPred);
+ Shape = VFShape::get(CI->getFunctionType(), VF, HasGlobalPred);
}
bool validParams(ArrayRef<VFParameter> Parameters) {
@@ -619,11 +619,11 @@ TEST_F(VFShapeAPITest, API_buildVFShape) {
TEST_F(VFShapeAPITest, API_getScalarShape) {
buildShape(/*VF*/ ElementCount::getFixed(1), /*HasGlobalPred*/ false);
- EXPECT_EQ(VFShape::getScalarShape(*CI), Shape);
+ EXPECT_EQ(VFShape::getScalarShape(CI->getFunctionType()), Shape);
}
TEST_F(VFShapeAPITest, API_getVectorizedFunction) {
- VFShape ScalarShape = VFShape::getScalarShape(*CI);
+ VFShape ScalarShape = VFShape::getScalarShape(CI->getFunctionType());
EXPECT_EQ(VFDatabase(*CI).getVectorizedFunction(ScalarShape),
M->getFunction("g"));
More information about the llvm-commits
mailing list