[llvm] 56a0334 - [Attributor] Keep track of reached returns in AAPointerInfo (#107479)

via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 10 08:13:24 PDT 2024


Author: Johannes Doerfert
Date: 2024-09-10T08:13:21-07:00
New Revision: 56a033462ed28083042e8a99f3a0fb16c1933ba4

URL: https://github.com/llvm/llvm-project/commit/56a033462ed28083042e8a99f3a0fb16c1933ba4
DIFF: https://github.com/llvm/llvm-project/commit/56a033462ed28083042e8a99f3a0fb16c1933ba4.diff

LOG: [Attributor] Keep track of reached returns in AAPointerInfo (#107479)

Instead of visiting call sites in Attributor::checkForAllUses, we now
keep track of returns in AAPointerInfo and use the call site return
information as required. This way, the user of
AAPointerInfo(CallSite)Argument can determine if the call return should
be visited. We do not collect them as "may accesses" in the
AAPointerInfo(CallSite)Argument itself in case a return user is found.

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/IPO/Attributor.h
    llvm/lib/Transforms/IPO/Attributor.cpp
    llvm/lib/Transforms/IPO/AttributorAttributes.cpp
    llvm/test/Transforms/Attributor/IPConstantProp/pthreads.ll
    llvm/test/Transforms/Attributor/value-simplify-pointer-info.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h
index 5844fb8b0f8938..6ab63ba582c546 100644
--- a/llvm/include/llvm/Transforms/IPO/Attributor.h
+++ b/llvm/include/llvm/Transforms/IPO/Attributor.h
@@ -6119,6 +6119,7 @@ struct AAPointerInfo : public AbstractAttribute {
   virtual const_bin_iterator begin() const = 0;
   virtual const_bin_iterator end() const = 0;
   virtual int64_t numOffsetBins() const = 0;
+  virtual bool reachesReturn() const = 0;
 
   /// Call \p CB on all accesses that might interfere with \p Range and return
   /// true if all such accesses were known and the callback returned true for

diff  --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp
index 38b61b6a88357c..56d1133b25549a 100644
--- a/llvm/lib/Transforms/IPO/Attributor.cpp
+++ b/llvm/lib/Transforms/IPO/Attributor.cpp
@@ -1852,22 +1852,6 @@ bool Attributor::checkForAllUses(
 
     User &Usr = *U->getUser();
     AddUsers(Usr, /* OldUse */ nullptr);
-
-    auto *RI = dyn_cast<ReturnInst>(&Usr);
-    if (!RI)
-      continue;
-
-    Function &F = *RI->getFunction();
-    auto CallSitePred = [&](AbstractCallSite ACS) {
-      return AddUsers(*ACS.getInstruction(), U);
-    };
-    if (!checkForAllCallSites(CallSitePred, F, /* RequireAllCallSites */ true,
-                              &QueryingAA, UsedAssumedInformation)) {
-      LLVM_DEBUG(dbgs() << "[Attributor] Could not follow return instruction "
-                           "to all call sites: "
-                        << *RI << "\n");
-      return false;
-    }
   }
 
   return true;

diff  --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
index 6fbb16e250f878..b73f5265f87874 100644
--- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
+++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
@@ -827,6 +827,7 @@ struct AA::PointerInfo::State : public AbstractState {
     AccessList = R.AccessList;
     OffsetBins = R.OffsetBins;
     RemoteIMap = R.RemoteIMap;
+    ReachesReturn = R.ReachesReturn;
     return *this;
   }
 
@@ -837,6 +838,7 @@ struct AA::PointerInfo::State : public AbstractState {
     std::swap(AccessList, R.AccessList);
     std::swap(OffsetBins, R.OffsetBins);
     std::swap(RemoteIMap, R.RemoteIMap);
+    std::swap(ReachesReturn, R.ReachesReturn);
     return *this;
   }
 
@@ -878,11 +880,16 @@ struct AA::PointerInfo::State : public AbstractState {
   AAPointerInfo::OffsetBinsTy OffsetBins;
   DenseMap<const Instruction *, SmallVector<unsigned>> RemoteIMap;
 
+  /// Flag to determine if the underlying pointer is reaching a return statement
+  /// in the associated function or not. Returns in other functions cause
+  /// invalidation.
+  bool ReachesReturn = false;
+
   /// See AAPointerInfo::forallInterferingAccesses.
   bool forallInterferingAccesses(
       AA::RangeTy Range,
       function_ref<bool(const AAPointerInfo::Access &, bool)> CB) const {
-    if (!isValidState())
+    if (!isValidState() || ReachesReturn)
       return false;
 
     for (const auto &It : OffsetBins) {
@@ -904,7 +911,7 @@ struct AA::PointerInfo::State : public AbstractState {
       Instruction &I,
       function_ref<bool(const AAPointerInfo::Access &, bool)> CB,
       AA::RangeTy &Range) const {
-    if (!isValidState())
+    if (!isValidState() || ReachesReturn)
       return false;
 
     auto LocalList = RemoteIMap.find(&I);
@@ -1071,7 +1078,8 @@ struct AAPointerInfoImpl
     return std::string("PointerInfo ") +
            (isValidState() ? (std::string("#") +
                               std::to_string(OffsetBins.size()) + " bins")
-                           : "<invalid>");
+                           : "<invalid>") +
+           (ReachesReturn ? " (returned)" : "");
   }
 
   /// See AbstractAttribute::manifest(...).
@@ -1084,6 +1092,7 @@ struct AAPointerInfoImpl
   virtual int64_t numOffsetBins() const override {
     return State::numOffsetBins();
   }
+  virtual bool reachesReturn() const override { return ReachesReturn; }
 
   bool forallInterferingAccesses(
       AA::RangeTy Range,
@@ -1373,6 +1382,7 @@ struct AAPointerInfoImpl
 
     const auto &OtherAAImpl = static_cast<const AAPointerInfoImpl &>(OtherAA);
     bool IsByval = OtherAAImpl.getAssociatedArgument()->hasByValAttr();
+    ReachesReturn = OtherAAImpl.ReachesReturn;
 
     // Combine the accesses bin by bin.
     ChangeStatus Changed = ChangeStatus::UNCHANGED;
@@ -1666,8 +1676,13 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) {
     }
     if (isa<PtrToIntInst>(Usr))
       return false;
-    if (isa<CastInst>(Usr) || isa<SelectInst>(Usr) || isa<ReturnInst>(Usr))
+    if (isa<CastInst>(Usr) || isa<SelectInst>(Usr))
       return HandlePassthroughUser(Usr, CurPtr, Follow);
+    // Returns are allowed if they are in the associated functions. Users can
+    // then check the call site return. Returns from other functions can't be
+    // tracked and are cause for invalidation.
+    if (auto *RI = dyn_cast<ReturnInst>(Usr))
+      return ReachesReturn = RI->getFunction() == getAssociatedFunction();
 
     // For PHIs we need to take care of the recurrence explicitly as the value
     // might change while we iterate through a loop. For now, we give up if
@@ -1898,15 +1913,37 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) {
             DepClassTy::REQUIRED);
         if (!CSArgPI)
           return false;
-        bool IsMustAcc = (getUnderlyingObject(CurPtr) == &AssociatedValue);
+        bool IsArgMustAcc = (getUnderlyingObject(CurPtr) == &AssociatedValue);
         Changed = translateAndAddState(A, *CSArgPI, OffsetInfoMap[CurPtr], *CB,
-                                       IsMustAcc) |
+                                       IsArgMustAcc) |
+                  Changed;
+        if (!CSArgPI->reachesReturn())
+          return isValidState();
+
+        Function *Callee = CB->getCalledFunction();
+        if (!Callee || Callee->arg_size() <= ArgNo)
+          return false;
+        bool UsedAssumedInformation = false;
+        auto ReturnedValue = A.getAssumedSimplified(
+            IRPosition::returned(*Callee), *this, UsedAssumedInformation,
+            AA::ValueScope::Intraprocedural);
+        auto *ReturnedArg =
+            dyn_cast_or_null<Argument>(ReturnedValue.value_or(nullptr));
+        auto *Arg = Callee->getArg(ArgNo);
+        if (ReturnedArg && Arg != ReturnedArg)
+          return true;
+        bool IsRetMustAcc = IsArgMustAcc && (ReturnedArg == Arg);
+        const auto *CSRetPI = A.getAAFor<AAPointerInfo>(
+            *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
+        if (!CSRetPI)
+          return false;
+        Changed = translateAndAddState(A, *CSRetPI, OffsetInfoMap[CurPtr], *CB,
+                                       IsRetMustAcc) |
                   Changed;
         return isValidState();
       }
       LLVM_DEBUG(dbgs() << "[AAPointerInfo] Call user not handled " << *CB
                         << "\n");
-      // TODO: Allow some call uses
       return false;
     }
 
@@ -2342,8 +2379,10 @@ struct AANoFreeFloating : AANoFreeImpl {
         Follow = true;
         return true;
       }
-      if (isa<StoreInst>(UserI) || isa<LoadInst>(UserI) ||
-          isa<ReturnInst>(UserI))
+      if (isa<StoreInst>(UserI) || isa<LoadInst>(UserI))
+        return true;
+
+      if (isa<ReturnInst>(UserI) && getIRPosition().isArgumentPosition())
         return true;
 
       // Unknown user.
@@ -12740,7 +12779,7 @@ struct AAAllocationInfoImpl : public AAAllocationInfo {
     if (!PI)
       return indicatePessimisticFixpoint();
 
-    if (!PI->getState().isValidState())
+    if (!PI->getState().isValidState() || PI->reachesReturn())
       return indicatePessimisticFixpoint();
 
     const DataLayout &DL = A.getDataLayout();

diff  --git a/llvm/test/Transforms/Attributor/IPConstantProp/pthreads.ll b/llvm/test/Transforms/Attributor/IPConstantProp/pthreads.ll
index 490894d1290231..01a97821140ec6 100644
--- a/llvm/test/Transforms/Attributor/IPConstantProp/pthreads.ll
+++ b/llvm/test/Transforms/Attributor/IPConstantProp/pthreads.ll
@@ -34,13 +34,13 @@ target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
 define dso_local i32 @main() {
 ; TUNIT-LABEL: define {{[^@]+}}@main() {
 ; TUNIT-NEXT:  entry:
-; TUNIT-NEXT:    [[ALLOC11:%.*]] = alloca i8, i32 0, align 8
-; TUNIT-NEXT:    [[ALLOC22:%.*]] = alloca i8, i32 0, align 8
+; TUNIT-NEXT:    [[ALLOC1:%.*]] = alloca i8, align 8
+; TUNIT-NEXT:    [[ALLOC2:%.*]] = alloca i8, align 8
 ; TUNIT-NEXT:    [[THREAD:%.*]] = alloca i64, align 8
 ; TUNIT-NEXT:    [[CALL:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @foo, ptr nofree readnone align 4294967296 undef)
 ; TUNIT-NEXT:    [[CALL1:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @bar, ptr noalias nocapture nofree nonnull readnone align 8 dereferenceable(8) undef)
-; TUNIT-NEXT:    [[CALL2:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @baz, ptr noalias nocapture nofree noundef nonnull readnone align 8 dereferenceable(1) [[ALLOC11]])
-; TUNIT-NEXT:    [[CALL3:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @buz, ptr noalias nofree noundef nonnull readnone align 8 dereferenceable(1) "no-capture-maybe-returned" [[ALLOC22]])
+; TUNIT-NEXT:    [[CALL2:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @baz, ptr noalias nocapture nofree noundef nonnull readnone align 8 dereferenceable(1) [[ALLOC1]])
+; TUNIT-NEXT:    [[CALL3:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @buz, ptr noalias nofree noundef nonnull readnone align 8 dereferenceable(1) "no-capture-maybe-returned" [[ALLOC2]])
 ; TUNIT-NEXT:    ret i32 0
 ;
 ; CGSCC-LABEL: define {{[^@]+}}@main() {

diff  --git a/llvm/test/Transforms/Attributor/value-simplify-pointer-info.ll b/llvm/test/Transforms/Attributor/value-simplify-pointer-info.ll
index 69bff7b5e783ea..378560cc89cd12 100644
--- a/llvm/test/Transforms/Attributor/value-simplify-pointer-info.ll
+++ b/llvm/test/Transforms/Attributor/value-simplify-pointer-info.ll
@@ -3185,10 +3185,7 @@ define i32 @may_access_after_return(i32 noundef %N, i32 noundef %M) {
 ; TUNIT-NEXT:    [[A:%.*]] = alloca i32, align 4
 ; TUNIT-NEXT:    [[B:%.*]] = alloca i32, align 4
 ; TUNIT-NEXT:    call void @write_both(ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[A]], ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[B]]) #[[ATTR18]]
-; TUNIT-NEXT:    [[TMP0:%.*]] = load i32, ptr [[A]], align 4
-; TUNIT-NEXT:    [[TMP1:%.*]] = load i32, ptr [[B]], align 4
-; TUNIT-NEXT:    [[ADD:%.*]] = add nsw i32 [[TMP0]], [[TMP1]]
-; TUNIT-NEXT:    ret i32 [[ADD]]
+; TUNIT-NEXT:    ret i32 8
 ;
 ; CGSCC: Function Attrs: mustprogress nofree nosync nounwind willreturn memory(none)
 ; CGSCC-LABEL: define {{[^@]+}}@may_access_after_return
@@ -3304,10 +3301,7 @@ define i32 @may_access_after_return_no_choice1(i32 noundef %N, i32 noundef %M) {
 ; TUNIT-NEXT:    [[A:%.*]] = alloca i32, align 4
 ; TUNIT-NEXT:    [[B:%.*]] = alloca i32, align 4
 ; TUNIT-NEXT:    call void @write_both(ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[A]], ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[B]]) #[[ATTR18]]
-; TUNIT-NEXT:    [[TMP0:%.*]] = load i32, ptr [[A]], align 4
-; TUNIT-NEXT:    [[TMP1:%.*]] = load i32, ptr [[B]], align 4
-; TUNIT-NEXT:    [[ADD:%.*]] = add nsw i32 [[TMP0]], [[TMP1]]
-; TUNIT-NEXT:    ret i32 [[ADD]]
+; TUNIT-NEXT:    ret i32 8
 ;
 ; CGSCC: Function Attrs: mustprogress nofree nosync nounwind willreturn memory(none)
 ; CGSCC-LABEL: define {{[^@]+}}@may_access_after_return_no_choice1
@@ -3342,10 +3336,7 @@ define i32 @may_access_after_return_no_choice2(i32 noundef %N, i32 noundef %M) {
 ; TUNIT-NEXT:    [[A:%.*]] = alloca i32, align 4
 ; TUNIT-NEXT:    [[B:%.*]] = alloca i32, align 4
 ; TUNIT-NEXT:    call void @write_both(ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[B]], ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[A]]) #[[ATTR18]]
-; TUNIT-NEXT:    [[TMP0:%.*]] = load i32, ptr [[A]], align 4
-; TUNIT-NEXT:    [[TMP1:%.*]] = load i32, ptr [[B]], align 4
-; TUNIT-NEXT:    [[ADD:%.*]] = add nsw i32 [[TMP0]], [[TMP1]]
-; TUNIT-NEXT:    ret i32 [[ADD]]
+; TUNIT-NEXT:    ret i32 8
 ;
 ; CGSCC: Function Attrs: mustprogress nofree nosync nounwind willreturn memory(none)
 ; CGSCC-LABEL: define {{[^@]+}}@may_access_after_return_no_choice2


        


More information about the llvm-commits mailing list