[llvm] [ArgPromotion] Handle pointer arguments of recursive calls (PR #78735)
Vedant Paranjape via llvm-commits
llvm-commits at lists.llvm.org
Fri May 31 05:58:01 PDT 2024
https://github.com/vedantparanjape-amd updated https://github.com/llvm/llvm-project/pull/78735
>From b8b89bb93c8e253a9577e2beb204f293479ea1b6 Mon Sep 17 00:00:00 2001
From: Vedant Paranjape <vedant.paranjape at amd.com>
Date: Thu, 18 Jan 2024 19:52:06 +0000
Subject: [PATCH] [ArgPromotion] Handle pointer arguments of recursive calls
---
llvm/lib/Transforms/IPO/ArgumentPromotion.cpp | 60 +++++++++++++++++--
1 file changed, 56 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
index 3aa8ea3f51471..05211adf9f1a1 100644
--- a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
+++ b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
@@ -445,7 +445,7 @@ static bool allCallersPassValidPointerForArgument(Argument *Arg,
/// Determine that this argument is safe to promote, and find the argument
/// parts it can be promoted into.
static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
- unsigned MaxElements, bool IsRecursive,
+ unsigned MaxElements, bool IsRecursive, bool IsSelfRecursive,
SmallVectorImpl<OffsetAndArgPart> &ArgPartsVec) {
// Quick exit for unused arguments
if (Arg->use_empty())
@@ -610,13 +610,59 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
// unknown users
}
+ if (IsSelfRecursive && isa<CallBase>(V)) {
+ Value *Ptr = dyn_cast<Value>(U);
+ Type *PtrTy = Ptr->getType();
+ Align PtrAlign = Ptr->getPointerAlignment(DL);
+ APInt Offset(DL.getIndexTypeSizeInBits(Ptr->getType()), 0);
+ Ptr = Ptr->stripAndAccumulateConstantOffsets(DL, Offset,
+ /* AllowNonInbounds */ true);
+ if (Ptr != Arg)
+ return false;
+
+ if (Offset.getSignificantBits() >= 64)
+ return false;
+
+ int64_t Off = Offset.getSExtValue();
+ auto Pair = ArgParts.try_emplace(Off, ArgPart{PtrTy, PtrAlign, nullptr});
+ ArgPart &Part = Pair.first->second;
+
+ // We limit promotion to only promoting up to a fixed number of elements
+ // of the aggregate.
+ if (MaxElements > 0 && ArgParts.size() > MaxElements) {
+ LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
+ << "more than " << MaxElements << " parts\n");
+ return false;
+ }
+
+ Part.Alignment = std::max(Part.Alignment, PtrAlign);
+ continue;
+ }
// Unknown user.
LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
<< "unknown user " << *V << "\n");
return false;
}
- if (NeededDerefBytes || NeededAlign > 1) {
+ // In case of functions with recursive calls, this check will fail when it
+ // tries to look at the first caller of this function. The caller may or may
+ // not have a load; in case it doesn't load the pointer being passed, this
+ // check will fail. So, it's safe to skip the check in case we know that we
+ // are dealing with a recursive call.
+ //
+ // def fun(ptr %a) {
+ // ...
+ // %loadres = load i32, ptr %a, align 4
+ // %res = call i32 @fun(ptr %a)
+ // ...
+ // }
+ //
+ // def bar(ptr %x) {
+ // ...
+ // %resbar = call i32 @fun(ptr %x)
+ // ...
+ // }
+ if (!IsRecursive && (NeededDerefBytes || NeededAlign > 1)) {
// Try to prove a required deref / aligned requirement.
if (!allCallersPassValidPointerForArgument(Arg, NeededAlign,
NeededDerefBytes)) {
@@ -699,6 +745,10 @@ static bool areTypesABICompatible(ArrayRef<Type *> Types, const Function &F,
/// calls the DoPromotion method.
static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
unsigned MaxElements, bool IsRecursive) {
+ // Due to the complexity of handling cases where the SCC has more than one
+ // component, we want to limit argument promotion of recursive calls to
+ // just functions that directly call themselves.
+ bool IsSelfRecursive = false;
// Don't perform argument promotion for naked functions; otherwise we can end
// up removing parameters that are seemingly 'not used' as they are referred
// to in the assembly.
@@ -744,8 +794,10 @@ static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
if (CB->isMustTailCall())
return nullptr;
- if (CB->getFunction() == F)
+ if (CB->getFunction() == F) {
IsRecursive = true;
+ IsSelfRecursive = true;
+ }
}
// Can't change signature of musttail caller
@@ -779,7 +831,7 @@ static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
// If we can promote the pointer to its value.
SmallVector<OffsetAndArgPart, 4> ArgParts;
- if (findArgParts(PtrArg, DL, AAR, MaxElements, IsRecursive, ArgParts)) {
+ if (findArgParts(PtrArg, DL, AAR, MaxElements, IsRecursive, IsSelfRecursive, ArgParts)) {
SmallVector<Type *, 4> Types;
for (const auto &Pair : ArgParts)
Types.push_back(Pair.second.Ty);
More information about the llvm-commits
mailing list