[llvm] [ArgPromotion] Handle pointer arguments of recursive calls (PR #78735)
Vedant Paranjape via llvm-commits
llvm-commits at lists.llvm.org
Fri May 31 06:53:46 PDT 2024
https://github.com/vedantparanjape-amd updated https://github.com/llvm/llvm-project/pull/78735
From fbc9cad366bef315890e9c83f2ba31055f6b17ee Mon Sep 17 00:00:00 2001
From: Vedant Paranjape <vedant.paranjape at amd.com>
Date: Thu, 18 Jan 2024 19:52:06 +0000
Subject: [PATCH] [ArgPromotion] Handle pointer arguments of recursive calls
Argument promotion does not handle pointer arguments that are passed to
recursive function calls. This patch adds support for promoting such
arguments when the function is self-recursive, i.e. its SCC size is 1.
Due to the complexity of value tracking in recursive calls with an SCC
size greater than 1, we bail out in such cases.
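
As a minimal sketch (hypothetical IR, separate from the test added
below), a self-recursive function such as

  define internal i32 @sum(ptr %p, i32 %n) {
  entry:
    %v = load i32, ptr %p, align 4
    %cmp = icmp eq i32 %n, 0
    br i1 %cmp, label %done, label %recurse

  recurse:
    %n.dec = sub i32 %n, 1
    %rec = call i32 @sum(ptr %p, i32 %n.dec)
    %acc = add i32 %v, %rec
    br label %done

  done:
    %res = phi i32 [ %v, %entry ], [ %acc, %recurse ]
    ret i32 %res
  }

previously kept %p unpromoted, because the recursive call site was
treated as an unknown user of the pointer. With this patch the function
can be rewritten as @sum(i32 %p.0.val, i32 %n), forwarding the loaded
value through the recursive call site as well.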
---
llvm/lib/Transforms/IPO/ArgumentPromotion.cpp | 62 +++++++++++++++++-
.../argpromotion-recursion-pr1259.ll | 65 +++++++++++++++++++
2 files changed, 124 insertions(+), 3 deletions(-)
create mode 100644 llvm/test/Transforms/ArgumentPromotion/argpromotion-recursion-pr1259.ll
diff --git a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
index 3aa8ea3f51471..8f5d0ea4af397 100644
--- a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
+++ b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
@@ -446,6 +446,7 @@ static bool allCallersPassValidPointerForArgument(Argument *Arg,
/// parts it can be promoted into.
static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
unsigned MaxElements, bool IsRecursive,
+ bool IsSelfRecursive,
SmallVectorImpl<OffsetAndArgPart> &ArgPartsVec) {
// Quick exit for unused arguments
if (Arg->use_empty())
@@ -610,13 +611,61 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
// unknown users
}
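+  // A self-recursive call that passes this argument (possibly through
+  // constant offsets) is not an unknown user: record the passed pointer
+  // as a promotable part instead of bailing out.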
+ auto *CB = dyn_cast<CallBase>(V);
+ Value *PtrArg = dyn_cast<Value>(U);
+ if (IsSelfRecursive && CB && PtrArg) {
+ Type *PtrTy = PtrArg->getType();
+ Align PtrAlign = PtrArg->getPointerAlignment(DL);
+ APInt Offset(DL.getIndexTypeSizeInBits(PtrArg->getType()), 0);
+ PtrArg = PtrArg->stripAndAccumulateConstantOffsets(
+ DL, Offset,
+ /* AllowNonInbounds= */ true);
+ if (PtrArg != Arg)
+ return false;
+
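+    // Offsets needing 64 or more significant bits cannot be handled by
+    // getSExtValue() below; give up on such offsets.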
+ if (Offset.getSignificantBits() >= 64)
+ return false;
+
+ int64_t Off = Offset.getSExtValue();
+ auto Pair = ArgParts.try_emplace(Off, ArgPart{PtrTy, PtrAlign, nullptr});
+ ArgPart &Part = Pair.first->second;
+
+    // We limit promotion to promoting at most a fixed number of elements
+    // of the aggregate.
+ if (MaxElements > 0 && ArgParts.size() > MaxElements) {
+ LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
+ << "more than " << MaxElements << " parts\n");
+ return false;
+ }
+
+ Part.Alignment = std::max(Part.Alignment, PtrAlign);
+ continue;
+ }
// Unknown user.
LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
<< "unknown user " << *V << "\n");
return false;
}
- if (NeededDerefBytes || NeededAlign > 1) {
+  // In the case of recursive functions, this check can fail when it looks
+  // at the first caller of this function: that caller may or may not load
+  // the pointer being passed, and if it does not, the check fails. It is
+  // therefore safe to skip the check when we know we are dealing with a
+  // recursive call. In the example below, @fun loads through %a before
+  // recursing, while its caller @bar never loads through %x:
+ //
+ // def fun(ptr %a) {
+ // ...
+ // %loadres = load i32, ptr %a, align 4
+ // %res = call i32 @fun(ptr %a)
+ // ...
+ // }
+ //
+ // def bar(ptr %x) {
+ // ...
+ // %resbar = call i32 @fun(ptr %x)
+ // ...
+ // }
+ if (!IsRecursive && (NeededDerefBytes || NeededAlign > 1)) {
// Try to prove a required deref / aligned requirement.
if (!allCallersPassValidPointerForArgument(Arg, NeededAlign,
NeededDerefBytes)) {
@@ -699,6 +748,10 @@ static bool areTypesABICompatible(ArrayRef<Type *> Types, const Function &F,
/// calls the DoPromotion method.
static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
unsigned MaxElements, bool IsRecursive) {
+  // Due to the complexity of handling cases where the SCC has more than
+  // one component, we limit argument promotion of recursive calls to
+  // functions that directly call themselves.
+ bool IsSelfRecursive = false;
// Don't perform argument promotion for naked functions; otherwise we can end
// up removing parameters that are seemingly 'not used' as they are referred
// to in the assembly.
@@ -744,8 +797,10 @@ static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
if (CB->isMustTailCall())
return nullptr;
- if (CB->getFunction() == F)
+ if (CB->getFunction() == F) {
IsRecursive = true;
+ IsSelfRecursive = true;
+ }
}
// Can't change signature of musttail caller
@@ -779,7 +834,8 @@ static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
// If we can promote the pointer to its value.
SmallVector<OffsetAndArgPart, 4> ArgParts;
- if (findArgParts(PtrArg, DL, AAR, MaxElements, IsRecursive, ArgParts)) {
+ if (findArgParts(PtrArg, DL, AAR, MaxElements, IsRecursive, IsSelfRecursive,
+ ArgParts)) {
SmallVector<Type *, 4> Types;
for (const auto &Pair : ArgParts)
Types.push_back(Pair.second.Ty);
diff --git a/llvm/test/Transforms/ArgumentPromotion/argpromotion-recursion-pr1259.ll b/llvm/test/Transforms/ArgumentPromotion/argpromotion-recursion-pr1259.ll
new file mode 100644
index 0000000000000..19bb4492171fc
--- /dev/null
+++ b/llvm/test/Transforms/ArgumentPromotion/argpromotion-recursion-pr1259.ll
@@ -0,0 +1,65 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes=argpromotion < %s | FileCheck %s
+define internal i32 @foo(ptr %x, i32 %n, i32 %m) {
+; CHECK-LABEL: define internal i32 @foo(
+; CHECK-SAME: i32 [[X_0_VAL:%.*]], i32 [[N:%.*]], i32 [[M:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[CMP:%.*]] = icmp ne i32 [[N]], 0
+; CHECK-NEXT: br i1 [[CMP]], label %[[COND_TRUE:.*]], label %[[COND_FALSE:.*]]
+; CHECK: [[COND_TRUE]]:
+; CHECK-NEXT: br label %[[RETURN:.*]]
+; CHECK: [[COND_FALSE]]:
+; CHECK-NEXT: [[SUBVAL:%.*]] = sub i32 [[N]], 1
+; CHECK-NEXT: [[CALLRET:%.*]] = call i32 @foo(i32 [[X_0_VAL]], i32 [[SUBVAL]], i32 [[X_0_VAL]])
+; CHECK-NEXT: [[SUBVAL2:%.*]] = sub i32 [[N]], 2
+; CHECK-NEXT: [[CALLRET2:%.*]] = call i32 @foo(i32 [[X_0_VAL]], i32 [[SUBVAL2]], i32 [[M]])
+; CHECK-NEXT: [[CMP2:%.*]] = add i32 [[CALLRET]], [[CALLRET2]]
+; CHECK-NEXT: br label %[[RETURN]]
+; CHECK: [[COND_NEXT:.*]]:
+; CHECK-NEXT: br label %[[RETURN]]
+; CHECK: [[RETURN]]:
+; CHECK-NEXT: [[RETVAL_0:%.*]] = phi i32 [ [[X_0_VAL]], %[[COND_TRUE]] ], [ [[CMP2]], %[[COND_FALSE]] ], [ undef, %[[COND_NEXT]] ]
+; CHECK-NEXT: ret i32 [[RETVAL_0]]
+;
+entry:
+ %cmp = icmp ne i32 %n, 0
+ br i1 %cmp, label %cond_true, label %cond_false
+
+cond_true: ; preds = %entry
+ %val = load i32, ptr %x, align 4
+ br label %return
+
+cond_false: ; preds = %entry
+ %val2 = load i32, ptr %x, align 4
+ %subval = sub i32 %n, 1
+ %callret = call i32 @foo(ptr %x, i32 %subval, i32 %val2)
+ %subval2 = sub i32 %n, 2
+ %callret2 = call i32 @foo(ptr %x, i32 %subval2, i32 %m)
+ %cmp2 = add i32 %callret, %callret2
+ br label %return
+
+cond_next: ; No predecessors!
+ br label %return
+
+return: ; preds = %cond_next, %cond_false, %cond_true
+ %retval.0 = phi i32 [ %val, %cond_true ], [ %cmp2, %cond_false ], [ undef, %cond_next ]
+ ret i32 %retval.0
+}
+
+define i32 @bar(ptr %x, i32 %n, i32 %m) {
+; CHECK-LABEL: define i32 @bar(
+; CHECK-SAME: ptr [[X:%.*]], i32 [[N:%.*]], i32 [[M:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[X_VAL:%.*]] = load i32, ptr [[X]], align 4
+; CHECK-NEXT: [[CALLRET3:%.*]] = call i32 @foo(i32 [[X_VAL]], i32 [[N]], i32 [[M]])
+; CHECK-NEXT: br label %[[RETURN:.*]]
+; CHECK: [[RETURN]]:
+; CHECK-NEXT: ret i32 [[CALLRET3]]
+;
+entry:
+ %callret3 = call i32 @foo(ptr %x, i32 %n, i32 %m)
+ br label %return
+
+return: ; preds = %entry
+ ret i32 %callret3
+}
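
For contrast with the SCC limitation described in the patch, here is a
hypothetical mutually recursive pair (an SCC of size 2, not part of the
test above) that this patch deliberately leaves unpromoted:

  define internal i32 @ping(ptr %p, i32 %n) {
  entry:
    %cmp = icmp eq i32 %n, 0
    br i1 %cmp, label %base, label %recurse

  base:
    %v = load i32, ptr %p, align 4
    ret i32 %v

  recurse:
    %n.dec = sub i32 %n, 1
    %rec = call i32 @pong(ptr %p, i32 %n.dec)
    ret i32 %rec
  }

  define internal i32 @pong(ptr %p, i32 %n) {
  entry:
    %rec = call i32 @ping(ptr %p, i32 %n)
    ret i32 %rec
  }

Here the call to @pong inside @ping is recursive but not self-recursive,
so findArgParts still treats it as an unknown user and %p stays
unpromoted.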