[llvm] [Attributor] Keep track of reached returns in AAPointerInfo (PR #107479)
Johannes Doerfert via llvm-commits
llvm-commits at lists.llvm.org
Thu Sep 5 16:02:28 PDT 2024
https://github.com/jdoerfert updated https://github.com/llvm/llvm-project/pull/107479
>From a211d88a21543cc9ed11aebab3a27bc73b7454fb Mon Sep 17 00:00:00 2001
From: Johannes Doerfert <johannes at jdoerfert.de>
Date: Thu, 8 Aug 2024 14:01:43 -0700
Subject: [PATCH] [Attributor] Keep track of reached returns in AAPointerInfo
Instead of visiting call sites in Attributor::checkForAllUses, we now
keep track of returns in AAPointerInfo and use the call site return
information as required. This way, the user of
AAPointerInfo(CallSite)Argument can determine if the call return should
be visited. We do not collect them as "may accesses" in the
AAPointerInfo(CallSite)Argument itself in case a return user is found.
---
llvm/include/llvm/Transforms/IPO/Attributor.h | 1 +
llvm/lib/Transforms/IPO/Attributor.cpp | 16 -----
.../Transforms/IPO/AttributorAttributes.cpp | 59 +++++++++++++++----
.../Attributor/IPConstantProp/pthreads.ll | 8 +--
.../Attributor/value-simplify-pointer-info.ll | 15 +----
5 files changed, 57 insertions(+), 42 deletions(-)
diff --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h
index 5844fb8b0f8938..6ab63ba582c546 100644
--- a/llvm/include/llvm/Transforms/IPO/Attributor.h
+++ b/llvm/include/llvm/Transforms/IPO/Attributor.h
@@ -6119,6 +6119,7 @@ struct AAPointerInfo : public AbstractAttribute {
virtual const_bin_iterator begin() const = 0;
virtual const_bin_iterator end() const = 0;
virtual int64_t numOffsetBins() const = 0;
+ virtual bool reachesReturn() const = 0;
/// Call \p CB on all accesses that might interfere with \p Range and return
/// true if all such accesses were known and the callback returned true for
diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp
index 38b61b6a88357c..56d1133b25549a 100644
--- a/llvm/lib/Transforms/IPO/Attributor.cpp
+++ b/llvm/lib/Transforms/IPO/Attributor.cpp
@@ -1852,22 +1852,6 @@ bool Attributor::checkForAllUses(
User &Usr = *U->getUser();
AddUsers(Usr, /* OldUse */ nullptr);
-
- auto *RI = dyn_cast<ReturnInst>(&Usr);
- if (!RI)
- continue;
-
- Function &F = *RI->getFunction();
- auto CallSitePred = [&](AbstractCallSite ACS) {
- return AddUsers(*ACS.getInstruction(), U);
- };
- if (!checkForAllCallSites(CallSitePred, F, /* RequireAllCallSites */ true,
- &QueryingAA, UsedAssumedInformation)) {
- LLVM_DEBUG(dbgs() << "[Attributor] Could not follow return instruction "
- "to all call sites: "
- << *RI << "\n");
- return false;
- }
}
return true;
diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
index 1fe8e6515fe0e8..899ec3d264e541 100644
--- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
+++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
@@ -827,6 +827,7 @@ struct AA::PointerInfo::State : public AbstractState {
AccessList = R.AccessList;
OffsetBins = R.OffsetBins;
RemoteIMap = R.RemoteIMap;
+ ReachesReturn = R.ReachesReturn;
return *this;
}
@@ -837,6 +838,7 @@ struct AA::PointerInfo::State : public AbstractState {
std::swap(AccessList, R.AccessList);
std::swap(OffsetBins, R.OffsetBins);
std::swap(RemoteIMap, R.RemoteIMap);
+ std::swap(ReachesReturn, R.ReachesReturn);
return *this;
}
@@ -878,11 +880,16 @@ struct AA::PointerInfo::State : public AbstractState {
AAPointerInfo::OffsetBinsTy OffsetBins;
DenseMap<const Instruction *, SmallVector<unsigned>> RemoteIMap;
+ /// Flag to determine if the underlying pointer is reaching a return statement
+ /// in the associated function or not. Returns in other functions cause
+ /// invalidation.
+ bool ReachesReturn = false;
+
/// See AAPointerInfo::forallInterferingAccesses.
bool forallInterferingAccesses(
AA::RangeTy Range,
function_ref<bool(const AAPointerInfo::Access &, bool)> CB) const {
- if (!isValidState())
+ if (!isValidState() || ReachesReturn)
return false;
for (const auto &It : OffsetBins) {
@@ -904,7 +911,7 @@ struct AA::PointerInfo::State : public AbstractState {
Instruction &I,
function_ref<bool(const AAPointerInfo::Access &, bool)> CB,
AA::RangeTy &Range) const {
- if (!isValidState())
+ if (!isValidState() || ReachesReturn)
return false;
auto LocalList = RemoteIMap.find(&I);
@@ -1071,7 +1078,8 @@ struct AAPointerInfoImpl
return std::string("PointerInfo ") +
(isValidState() ? (std::string("#") +
std::to_string(OffsetBins.size()) + " bins")
- : "<invalid>");
+ : "<invalid>") +
+ (ReachesReturn ? " (returned)" : "");
}
/// See AbstractAttribute::manifest(...).
@@ -1084,6 +1092,7 @@ struct AAPointerInfoImpl
virtual int64_t numOffsetBins() const override {
return State::numOffsetBins();
}
+ virtual bool reachesReturn() const override { return ReachesReturn; }
bool forallInterferingAccesses(
AA::RangeTy Range,
@@ -1373,6 +1382,7 @@ struct AAPointerInfoImpl
const auto &OtherAAImpl = static_cast<const AAPointerInfoImpl &>(OtherAA);
bool IsByval = OtherAAImpl.getAssociatedArgument()->hasByValAttr();
+ ReachesReturn = OtherAAImpl.ReachesReturn;
// Combine the accesses bin by bin.
ChangeStatus Changed = ChangeStatus::UNCHANGED;
@@ -1666,8 +1676,13 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) {
}
if (isa<PtrToIntInst>(Usr))
return false;
- if (isa<CastInst>(Usr) || isa<SelectInst>(Usr) || isa<ReturnInst>(Usr))
+ if (isa<CastInst>(Usr) || isa<SelectInst>(Usr))
return HandlePassthroughUser(Usr, CurPtr, Follow);
+ // Returns are allowed if they are in the associated functions. Users can
+ // then check the call site return. Returns from other functions can't be
+ // tracked and are cause for invalidation.
+ if (auto *RI = dyn_cast<ReturnInst>(Usr))
+ return ReachesReturn = RI->getFunction() == getAssociatedFunction();
// For PHIs we need to take care of the recurrence explicitly as the value
// might change while we iterate through a loop. For now, we give up if
@@ -1898,15 +1913,37 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) {
DepClassTy::REQUIRED);
if (!CSArgPI)
return false;
- bool IsMustAcc = (getUnderlyingObject(CurPtr) == &AssociatedValue);
+ bool IsArgMustAcc = (getUnderlyingObject(CurPtr) == &AssociatedValue);
Changed = translateAndAddState(A, *CSArgPI, OffsetInfoMap[CurPtr], *CB,
- IsMustAcc) |
+ IsArgMustAcc) |
+ Changed;
+ if (!CSArgPI->reachesReturn())
+ return isValidState();
+
+ Function *Callee = CB->getCalledFunction();
+ if (!Callee || Callee->arg_size() <= ArgNo)
+ return false;
+ bool UsedAssumedInformation = false;
+ auto ReturnedValue = A.getAssumedSimplified(
+ IRPosition::returned(*Callee), *this, UsedAssumedInformation,
+ AA::ValueScope::Intraprocedural);
+ auto *ReturnedArg =
+ dyn_cast_or_null<Argument>(ReturnedValue.value_or(nullptr));
+ auto *Arg = Callee->getArg(ArgNo);
+ if (ReturnedArg && Arg != ReturnedArg)
+ return true;
+ bool IsRetMustAcc = IsArgMustAcc && (ReturnedArg == Arg);
+ const auto *CSRetPI = A.getAAFor<AAPointerInfo>(
+ *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
+ if (!CSRetPI)
+ return false;
+ Changed = translateAndAddState(A, *CSRetPI, OffsetInfoMap[CurPtr], *CB,
+ IsRetMustAcc) |
Changed;
return isValidState();
}
LLVM_DEBUG(dbgs() << "[AAPointerInfo] Call user not handled " << *CB
<< "\n");
- // TODO: Allow some call uses
return false;
}
@@ -2342,8 +2379,10 @@ struct AANoFreeFloating : AANoFreeImpl {
Follow = true;
return true;
}
- if (isa<StoreInst>(UserI) || isa<LoadInst>(UserI) ||
- isa<ReturnInst>(UserI))
+ if (isa<StoreInst>(UserI) || isa<LoadInst>(UserI))
+ return true;
+
+ if (isa<ReturnInst>(UserI) && getIRPosition().isArgumentPosition())
return true;
// Unknown user.
@@ -12717,7 +12756,7 @@ struct AAAllocationInfoImpl : public AAAllocationInfo {
if (!PI)
return indicatePessimisticFixpoint();
- if (!PI->getState().isValidState())
+ if (!PI->getState().isValidState() || PI->reachesReturn())
return indicatePessimisticFixpoint();
const DataLayout &DL = A.getDataLayout();
diff --git a/llvm/test/Transforms/Attributor/IPConstantProp/pthreads.ll b/llvm/test/Transforms/Attributor/IPConstantProp/pthreads.ll
index 490894d1290231..01a97821140ec6 100644
--- a/llvm/test/Transforms/Attributor/IPConstantProp/pthreads.ll
+++ b/llvm/test/Transforms/Attributor/IPConstantProp/pthreads.ll
@@ -34,13 +34,13 @@ target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
define dso_local i32 @main() {
; TUNIT-LABEL: define {{[^@]+}}@main() {
; TUNIT-NEXT: entry:
-; TUNIT-NEXT: [[ALLOC11:%.*]] = alloca i8, i32 0, align 8
-; TUNIT-NEXT: [[ALLOC22:%.*]] = alloca i8, i32 0, align 8
+; TUNIT-NEXT: [[ALLOC1:%.*]] = alloca i8, align 8
+; TUNIT-NEXT: [[ALLOC2:%.*]] = alloca i8, align 8
; TUNIT-NEXT: [[THREAD:%.*]] = alloca i64, align 8
; TUNIT-NEXT: [[CALL:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @foo, ptr nofree readnone align 4294967296 undef)
; TUNIT-NEXT: [[CALL1:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @bar, ptr noalias nocapture nofree nonnull readnone align 8 dereferenceable(8) undef)
-; TUNIT-NEXT: [[CALL2:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @baz, ptr noalias nocapture nofree noundef nonnull readnone align 8 dereferenceable(1) [[ALLOC11]])
-; TUNIT-NEXT: [[CALL3:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @buz, ptr noalias nofree noundef nonnull readnone align 8 dereferenceable(1) "no-capture-maybe-returned" [[ALLOC22]])
+; TUNIT-NEXT: [[CALL2:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @baz, ptr noalias nocapture nofree noundef nonnull readnone align 8 dereferenceable(1) [[ALLOC1]])
+; TUNIT-NEXT: [[CALL3:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @buz, ptr noalias nofree noundef nonnull readnone align 8 dereferenceable(1) "no-capture-maybe-returned" [[ALLOC2]])
; TUNIT-NEXT: ret i32 0
;
; CGSCC-LABEL: define {{[^@]+}}@main() {
diff --git a/llvm/test/Transforms/Attributor/value-simplify-pointer-info.ll b/llvm/test/Transforms/Attributor/value-simplify-pointer-info.ll
index 69bff7b5e783ea..378560cc89cd12 100644
--- a/llvm/test/Transforms/Attributor/value-simplify-pointer-info.ll
+++ b/llvm/test/Transforms/Attributor/value-simplify-pointer-info.ll
@@ -3185,10 +3185,7 @@ define i32 @may_access_after_return(i32 noundef %N, i32 noundef %M) {
; TUNIT-NEXT: [[A:%.*]] = alloca i32, align 4
; TUNIT-NEXT: [[B:%.*]] = alloca i32, align 4
; TUNIT-NEXT: call void @write_both(ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[A]], ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[B]]) #[[ATTR18]]
-; TUNIT-NEXT: [[TMP0:%.*]] = load i32, ptr [[A]], align 4
-; TUNIT-NEXT: [[TMP1:%.*]] = load i32, ptr [[B]], align 4
-; TUNIT-NEXT: [[ADD:%.*]] = add nsw i32 [[TMP0]], [[TMP1]]
-; TUNIT-NEXT: ret i32 [[ADD]]
+; TUNIT-NEXT: ret i32 8
;
; CGSCC: Function Attrs: mustprogress nofree nosync nounwind willreturn memory(none)
; CGSCC-LABEL: define {{[^@]+}}@may_access_after_return
@@ -3304,10 +3301,7 @@ define i32 @may_access_after_return_no_choice1(i32 noundef %N, i32 noundef %M) {
; TUNIT-NEXT: [[A:%.*]] = alloca i32, align 4
; TUNIT-NEXT: [[B:%.*]] = alloca i32, align 4
; TUNIT-NEXT: call void @write_both(ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[A]], ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[B]]) #[[ATTR18]]
-; TUNIT-NEXT: [[TMP0:%.*]] = load i32, ptr [[A]], align 4
-; TUNIT-NEXT: [[TMP1:%.*]] = load i32, ptr [[B]], align 4
-; TUNIT-NEXT: [[ADD:%.*]] = add nsw i32 [[TMP0]], [[TMP1]]
-; TUNIT-NEXT: ret i32 [[ADD]]
+; TUNIT-NEXT: ret i32 8
;
; CGSCC: Function Attrs: mustprogress nofree nosync nounwind willreturn memory(none)
; CGSCC-LABEL: define {{[^@]+}}@may_access_after_return_no_choice1
@@ -3342,10 +3336,7 @@ define i32 @may_access_after_return_no_choice2(i32 noundef %N, i32 noundef %M) {
; TUNIT-NEXT: [[A:%.*]] = alloca i32, align 4
; TUNIT-NEXT: [[B:%.*]] = alloca i32, align 4
; TUNIT-NEXT: call void @write_both(ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[B]], ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[A]]) #[[ATTR18]]
-; TUNIT-NEXT: [[TMP0:%.*]] = load i32, ptr [[A]], align 4
-; TUNIT-NEXT: [[TMP1:%.*]] = load i32, ptr [[B]], align 4
-; TUNIT-NEXT: [[ADD:%.*]] = add nsw i32 [[TMP0]], [[TMP1]]
-; TUNIT-NEXT: ret i32 [[ADD]]
+; TUNIT-NEXT: ret i32 8
;
; CGSCC: Function Attrs: mustprogress nofree nosync nounwind willreturn memory(none)
; CGSCC-LABEL: define {{[^@]+}}@may_access_after_return_no_choice2
More information about the llvm-commits
mailing list