[llvm] [ArgPromotion] Handle pointer arguments of recursive calls (PR #78735)
Vedant Paranjape via llvm-commits
llvm-commits at lists.llvm.org
Fri Jan 19 07:57:18 PST 2024
https://github.com/vedantparanjape-amd updated https://github.com/llvm/llvm-project/pull/78735
>From 9a6065d730833677a4ac574416b2b08e076a790e Mon Sep 17 00:00:00 2001
From: Vedant Paranjape <vedant.paranjape at amd.com>
Date: Thu, 18 Jan 2024 19:52:06 +0000
Subject: [PATCH] [ArgPromotion] Handle pointer arguments of recursive calls
---
llvm/lib/Transforms/IPO/ArgumentPromotion.cpp | 55 ++++++++++++++++++-
1 file changed, 54 insertions(+), 1 deletion(-)
diff --git a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
index 8058282c422503..3891afbe019de9 100644
--- a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
+++ b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
@@ -609,13 +609,66 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
// unknown users
}
+ auto *CI = dyn_cast<CallInst>(V);
+ if (IsRecursive && CI && CI->getFunction() == Arg->getParent()) {
+ dbgs() << "Found recursive call\n";
+ Type *Ty = CI->getType();
+ Value *Ptr = CI->getArgOperand(U->getOperandNo());
+ Align PtrAlign = Ptr->getPointerAlignment(DL);
+ APInt Offset(DL.getIndexTypeSizeInBits(Ptr->getType()), 0);
+ Ptr = Ptr->stripAndAccumulateConstantOffsets(DL, Offset,
+ /* AllowNonInbounds */ true);
+ if (Ptr != Arg)
+ return false;
+
+ if (Offset.getSignificantBits() >= 64)
+ return false;
+
+ TypeSize Size = DL.getTypeStoreSize(Ty);
+ // Don't try to promote scalable types.
+ if (Size.isScalable())
+ return false;
+
+ int64_t Off = Offset.getSExtValue();
+ auto Pair = ArgParts.try_emplace(Off, ArgPart{Ty, PtrAlign, nullptr});
+ ArgPart &Part = Pair.first->second;
+
+ // We limit promotion to only promoting up to a fixed number of elements
+ // of the aggregate.
+ if (MaxElements > 0 && ArgParts.size() > MaxElements) {
+ LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
+ << "more than " << MaxElements << " parts\n");
+ return false;
+ }
+
+ Part.Alignment = std::max(Part.Alignment, PtrAlign);
+ continue;
+ }
// Unknown user.
LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
<< "unknown user " << *V << "\n");
return false;
}
- if (NeededDerefBytes || NeededAlign > 1) {
+ // Incase of functions with recursive calls, this check will fail when it
+ // tries to look at the first caller of this function. The caller may or may
+ // not have a load, incase it doesn't load the pointer being passed, this
+ // check will fail. So, it's safe to skip the check incase we know that we
+ // are dealing with a recursive call.
+ //
+ // def fun(ptr %a) {
+ // ...
+ // %loadres = load i32, ptr %a, align 4
+ // %res = call i32 @fun(ptr %a)
+ // ...
+ // }
+ //
+ // def bar(ptr %x) {
+ // ...
+ // %resbar = call i32 @fun(ptr %x)
+ // ...
+ // }
+ if (!IsRecursive && (NeededDerefBytes || NeededAlign > 1)) {
// Try to prove a required deref / aligned requirement.
if (!allCallersPassValidPointerForArgument(Arg, NeededAlign,
NeededDerefBytes)) {
More information about the llvm-commits
mailing list