[llvm] [ArgPromotion] Handle pointer arguments of recursive calls (PR #78735)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Jan 19 07:53:57 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Vedant Paranjape (vedantparanjape-amd)
<details>
<summary>Changes</summary>
Tries to fix llvm#<!-- -->1259. This implementation fails pr42028-recursion.ll: the compiler crashes because it is not able to handle a call graph like the following:
```
A -> B -> C
↑____|
```
---
Full diff: https://github.com/llvm/llvm-project/pull/78735.diff
1 Files Affected:
- (modified) llvm/lib/Transforms/IPO/ArgumentPromotion.cpp (+54-1)
``````````diff
diff --git a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
index 8058282c422503..072a8c46f1e54c 100644
--- a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
+++ b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
@@ -609,13 +609,66 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
// unknown users
}
+ auto *CI = dyn_cast<CallInst>(V);
+ if (IsRecursive && CI && CI->getFunction() == Arg->getParent()) {
+ dbgs() << "Found recursive call\n";
+ Type *Ty = CI->getType();
+ Value *Ptr = CI->getArgOperand(U->getOperandNo());
+ Align PtrAlign = Ptr->getPointerAlignment(DL);
+ APInt Offset(DL.getIndexTypeSizeInBits(Ptr->getType()), 0);
+ Ptr = Ptr->stripAndAccumulateConstantOffsets(DL, Offset, /* AllowNonInbounds */ true);
+ if (Ptr != Arg)
+ return false;
+
+ if (Offset.getSignificantBits() >= 64)
+ return false;
+
+ TypeSize Size = DL.getTypeStoreSize(Ty);
+ // Don't try to promote scalable types.
+ if (Size.isScalable())
+ return false;
+
+ int64_t Off = Offset.getSExtValue();
+ auto Pair = ArgParts.try_emplace(
+ Off, ArgPart{Ty, PtrAlign, nullptr});
+ ArgPart &Part = Pair.first->second;
+
+ // We limit promotion to only promoting up to a fixed number of elements of
+ // the aggregate.
+ if (MaxElements > 0 && ArgParts.size() > MaxElements) {
+ LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
+ << "more than " << MaxElements << " parts\n");
+ return false;
+ }
+
+ Part.Alignment = std::max(Part.Alignment, PtrAlign);
+ continue;
+ }
// Unknown user.
LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
<< "unknown user " << *V << "\n");
return false;
}
- if (NeededDerefBytes || NeededAlign > 1) {
+ // In case of functions with recursive calls, this check will fail when it
+ // tries to look at the first caller of this function. The caller may or may
+ // not have a load; in case it doesn't load the pointer being passed, this
+ // check will fail. So, it's safe to skip the check in case we know that we
+ // are dealing with a recursive call.
+ //
+ // def fun(ptr %a) {
+ // ...
+ // %loadres = load i32, ptr %a, align 4
+ // %res = call i32 @fun(ptr %a)
+ // ...
+ // }
+ //
+ // def bar(ptr %x) {
+ // ...
+ // %resbar = call i32 @fun(ptr %x)
+ // ...
+ // }
+ if (!IsRecursive && (NeededDerefBytes || NeededAlign > 1)) {
// Try to prove a required deref / aligned requirement.
if (!allCallersPassValidPointerForArgument(Arg, NeededAlign,
NeededDerefBytes)) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/78735
More information about the llvm-commits
mailing list