[clang] [llvm] [CaptureTracking][FunctionAttrs] Add support for CaptureInfo (PR #125880)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 7 10:52:24 PST 2025


https://github.com/nikic updated https://github.com/llvm/llvm-project/pull/125880

>From 2f698e27ae61b91019544cc707c134e0aec9ecd3 Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Thu, 30 Jan 2025 12:08:01 +0100
Subject: [PATCH 1/8] [CaptureTracking][FunctionAttrs] Add support for
 CaptureInfo

---
 clang/test/CodeGen/allow-ubsan-check.c        |   6 +-
 .../RelativeVTablesABI/dynamic-cast.cpp       |   6 +-
 .../RelativeVTablesABI/type-info.cpp          |   2 +-
 .../CodeGenOpenCL/amdgcn-buffer-rsrc-type.cl  |   4 +-
 clang/test/CodeGenOpenCL/as_type.cl           |   2 +-
 llvm/include/llvm/Analysis/CaptureTracking.h  |  52 +++++--
 llvm/include/llvm/Support/ModRef.h            |  20 +++
 llvm/lib/Analysis/CaptureTracking.cpp         | 111 ++++++++------
 llvm/lib/Analysis/InstructionSimplify.cpp     |   8 +-
 .../Transforms/IPO/AttributorAttributes.cpp   |  36 ++---
 llvm/lib/Transforms/IPO/FunctionAttrs.cpp     | 143 ++++++++++++------
 .../InstCombine/InstCombineCompares.cpp       |   8 +-
 .../lib/Transforms/Scalar/MemCpyOptimizer.cpp |  46 +++---
 .../FunctionAttrs/2009-01-02-LocalStores.ll   |   2 +-
 .../Transforms/FunctionAttrs/arg_returned.ll  |  24 +--
 .../Transforms/FunctionAttrs/nocapture.ll     |  59 +++++---
 llvm/test/Transforms/FunctionAttrs/nonnull.ll |  28 ++--
 llvm/test/Transforms/FunctionAttrs/noundef.ll |   8 +-
 .../Transforms/FunctionAttrs/readattrs.ll     |   8 +-
 llvm/test/Transforms/FunctionAttrs/stats.ll   |   4 +-
 .../AArch64/block_scaling_decompr_8bit.ll     |   2 +-
 .../PhaseOrdering/bitcast-store-branch.ll     |   2 +-
 .../dce-after-argument-promotion-loads.ll     |   2 +-
 .../enable-loop-header-duplication-oz.ll      |   4 +-
 .../Analysis/CaptureTrackingTest.cpp          |   5 +-
 25 files changed, 360 insertions(+), 232 deletions(-)

diff --git a/clang/test/CodeGen/allow-ubsan-check.c b/clang/test/CodeGen/allow-ubsan-check.c
index 0cd81a77f5cc594..c116604288546dc 100644
--- a/clang/test/CodeGen/allow-ubsan-check.c
+++ b/clang/test/CodeGen/allow-ubsan-check.c
@@ -86,7 +86,7 @@ int div(int x, int y) {
 }
 
 // CHECK-LABEL: define dso_local i32 @null(
-// CHECK-SAME: ptr noundef readonly [[X:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// CHECK-SAME: ptr noundef readonly captures(address_is_null) [[X:%.*]]) local_unnamed_addr #[[ATTR0]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
 // CHECK-NEXT:    [[TMP0:%.*]] = icmp eq ptr [[X]], null, !nosanitize [[META2]]
 //
@@ -102,7 +102,7 @@ int div(int x, int y) {
 // CHECK-NEXT:    ret i32 [[TMP2]]
 //
 // TR-LABEL: define dso_local i32 @null(
-// TR-SAME: ptr noundef readonly [[X:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// TR-SAME: ptr noundef readonly captures(address_is_null) [[X:%.*]]) local_unnamed_addr #[[ATTR0]] {
 // TR-NEXT:  [[ENTRY:.*:]]
 // TR-NEXT:    [[TMP0:%.*]] = icmp eq ptr [[X]], null, !nosanitize [[META2]]
 // TR-NEXT:    [[TMP1:%.*]] = tail call i1 @llvm.allow.ubsan.check(i8 29), !nosanitize [[META2]]
@@ -116,7 +116,7 @@ int div(int x, int y) {
 // TR-NEXT:    ret i32 [[TMP2]]
 //
 // REC-LABEL: define dso_local i32 @null(
-// REC-SAME: ptr noundef readonly [[X:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// REC-SAME: ptr noundef readonly captures(address_is_null) [[X:%.*]]) local_unnamed_addr #[[ATTR0]] {
 // REC-NEXT:  [[ENTRY:.*:]]
 // REC-NEXT:    [[TMP0:%.*]] = icmp eq ptr [[X]], null, !nosanitize [[META2]]
 // REC-NEXT:    [[TMP1:%.*]] = tail call i1 @llvm.allow.ubsan.check(i8 29), !nosanitize [[META2]]
diff --git a/clang/test/CodeGenCXX/RelativeVTablesABI/dynamic-cast.cpp b/clang/test/CodeGenCXX/RelativeVTablesABI/dynamic-cast.cpp
index 83daf57be22ffcc..3662a270713b697 100644
--- a/clang/test/CodeGenCXX/RelativeVTablesABI/dynamic-cast.cpp
+++ b/clang/test/CodeGenCXX/RelativeVTablesABI/dynamic-cast.cpp
@@ -3,7 +3,7 @@
 
 // RUN: %clang_cc1 %s -triple=aarch64-unknown-fuchsia -O3 -o - -emit-llvm | FileCheck %s
 
-// CHECK:      define{{.*}} ptr @_Z6upcastP1B(ptr noundef readnone returned %b) local_unnamed_addr
+// CHECK:      define{{.*}} ptr @_Z6upcastP1B(ptr noundef readnone returned captures(ret: address, provenance) %b) local_unnamed_addr
 // CHECK-NEXT: entry:
 // CHECK-NEXT:   ret ptr %b
 // CHECK-NEXT: }
@@ -22,12 +22,12 @@
 
 // CHECK: declare ptr @__dynamic_cast(ptr, ptr, ptr, i64) local_unnamed_addr
 
-// CHECK:      define{{.*}} ptr @_Z8selfcastP1B(ptr noundef readnone returned %b) local_unnamed_addr
+// CHECK:      define{{.*}} ptr @_Z8selfcastP1B(ptr noundef readnone returned captures(ret: address, provenance) %b) local_unnamed_addr
 // CHECK-NEXT: entry
 // CHECK-NEXT:   ret ptr %b
 // CHECK-NEXT: }
 
-// CHECK: define{{.*}} ptr @_Z9void_castP1B(ptr noundef readonly %b) local_unnamed_addr
+// CHECK: define{{.*}} ptr @_Z9void_castP1B(ptr noundef readonly captures(address_is_null, ret: address, provenance) %b) local_unnamed_addr
 // CHECK-NEXT: entry:
 // CHECK-NEXT:   [[isnull:%[0-9]+]] = icmp eq ptr %b, null
 // CHECK-NEXT:   br i1 [[isnull]], label %[[dynamic_cast_end:[a-z0-9._]+]], label %[[dynamic_cast_notnull:[a-z0-9._]+]]
diff --git a/clang/test/CodeGenCXX/RelativeVTablesABI/type-info.cpp b/clang/test/CodeGenCXX/RelativeVTablesABI/type-info.cpp
index c471e5dbd7b33ce..2a838708ca23152 100644
--- a/clang/test/CodeGenCXX/RelativeVTablesABI/type-info.cpp
+++ b/clang/test/CodeGenCXX/RelativeVTablesABI/type-info.cpp
@@ -24,7 +24,7 @@
 // CHECK-NEXT:   ret ptr @_ZTS1A
 // CHECK-NEXT: }
 
-// CHECK:      define{{.*}} i1 @_Z5equalP1A(ptr noundef readonly %a) local_unnamed_addr
+// CHECK:      define{{.*}} i1 @_Z5equalP1A(ptr noundef readonly captures(address_is_null) %a) local_unnamed_addr
 // CHECK-NEXT: entry:
 // CHECK-NEXT:   [[isnull:%[0-9]+]] = icmp eq ptr %a, null
 // CHECK-NEXT:   br i1 [[isnull]], label %[[bad_typeid:[a-z0-9._]+]], label %[[end:[a-z0-9.+]+]]
diff --git a/clang/test/CodeGenOpenCL/amdgcn-buffer-rsrc-type.cl b/clang/test/CodeGenOpenCL/amdgcn-buffer-rsrc-type.cl
index 0aadaad2dca5c52..62fd20c4d141455 100644
--- a/clang/test/CodeGenOpenCL/amdgcn-buffer-rsrc-type.cl
+++ b/clang/test/CodeGenOpenCL/amdgcn-buffer-rsrc-type.cl
@@ -22,7 +22,7 @@ __amdgpu_buffer_rsrc_t getBuffer(void *p) {
 }
 
 // CHECK-LABEL: define {{[^@]+}}@consumeBufferPtr
-// CHECK-SAME: (ptr addrspace(5) noundef readonly [[P:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// CHECK-SAME: (ptr addrspace(5) noundef readonly captures(address) [[P:%.*]]) local_unnamed_addr #[[ATTR0]] {
 // CHECK-NEXT:  entry:
 // CHECK-NEXT:    [[TOBOOL_NOT:%.*]] = icmp eq ptr addrspace(5) [[P]], addrspacecast (ptr null to ptr addrspace(5))
 // CHECK-NEXT:    br i1 [[TOBOOL_NOT]], label [[IF_END:%.*]], label [[IF_THEN:%.*]]
@@ -39,7 +39,7 @@ void consumeBufferPtr(__amdgpu_buffer_rsrc_t *p) {
 }
 
 // CHECK-LABEL: define {{[^@]+}}@test
-// CHECK-SAME: (ptr addrspace(5) noundef readonly [[A:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// CHECK-SAME: (ptr addrspace(5) noundef readonly captures(address) [[A:%.*]]) local_unnamed_addr #[[ATTR0]] {
 // CHECK-NEXT:  entry:
 // CHECK-NEXT:    [[TMP0:%.*]] = load i32, ptr addrspace(5) [[A]], align 16, !tbaa [[TBAA8:![0-9]+]]
 // CHECK-NEXT:    [[TOBOOL_NOT:%.*]] = icmp eq i32 [[TMP0]], 0
diff --git a/clang/test/CodeGenOpenCL/as_type.cl b/clang/test/CodeGenOpenCL/as_type.cl
index 1fe26fbeafdb4b8..2c6cdc3810b4da4 100644
--- a/clang/test/CodeGenOpenCL/as_type.cl
+++ b/clang/test/CodeGenOpenCL/as_type.cl
@@ -67,7 +67,7 @@ int3 f8(char16 x) {
   return __builtin_astype(x, int3);
 }
 
-//CHECK: define{{.*}} spir_func noundef ptr addrspace(1) @addr_cast(ptr noundef readnone %[[x:.*]])
+//CHECK: define{{.*}} spir_func noundef ptr addrspace(1) @addr_cast(ptr noundef readnone captures(ret: address, provenance) %[[x:.*]])
 //CHECK: %[[cast:.*]] ={{.*}} addrspacecast ptr %[[x]] to ptr addrspace(1)
 //CHECK: ret ptr addrspace(1) %[[cast]]
 global int* addr_cast(int *x) {
diff --git a/llvm/include/llvm/Analysis/CaptureTracking.h b/llvm/include/llvm/Analysis/CaptureTracking.h
index 06a00d9ae789908..4e09c1e6f3021dc 100644
--- a/llvm/include/llvm/Analysis/CaptureTracking.h
+++ b/llvm/include/llvm/Analysis/CaptureTracking.h
@@ -14,11 +14,13 @@
 #define LLVM_ANALYSIS_CAPTURETRACKING_H
 
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/ModRef.h"
 
 namespace llvm {
 
   class Value;
   class Use;
+  class CaptureInfo;
   class DataLayout;
   class Instruction;
   class DominatorTree;
@@ -94,10 +96,38 @@ namespace llvm {
     /// U->getUser() is always an Instruction.
     virtual bool shouldExplore(const Use *U);
 
-    /// captured - Information about the pointer was captured by the user of
-    /// use U. Return true to stop the traversal or false to continue looking
-    /// for more capturing instructions.
-    virtual bool captured(const Use *U) = 0;
+    /// When returned from captured(), stop the traversal.
+    static std::optional<CaptureComponents> stop() { return std::nullopt; }
+
+    /// When returned from captured(), continue traversal, but do not follow
+    /// the return value of this user, even if it has additional capture
+    /// components. Should only be used if captured() has already taken the
+    /// potential return captures into account.
+    static std::optional<CaptureComponents> continueIgnoringReturn() {
+      return CaptureComponents::None;
+    }
+
+    /// When returned from captured(), continue traversal, and also follow
+    /// the return value of this user if it has additional capture components
+    /// (that is, capture components in Ret that are not part of Other).
+    static std::optional<CaptureComponents> continueDefault(CaptureInfo CI) {
+      CaptureComponents RetCC = CI.getRetComponents();
+      if (!capturesNothing(RetCC & ~CI.getOtherComponents()))
+        return RetCC;
+      return CaptureComponents::None;
+    }
+
+    /// Use U directly captures CI.getOtherComponents() and additionally
+    /// CI.getRetComponents() through the return value of the user of U.
+    ///
+    /// Return std::nullopt to stop the traversal, or the CaptureComponents to
+    /// follow via the return value, which must be a subset of
+    /// CI.getRetComponents().
+    ///
+    /// For convenience, prefer returning one of stop(), continueDefault(CI) or
+    /// continueIgnoringReturn().
+    virtual std::optional<CaptureComponents> captured(const Use *U,
+                                                      CaptureInfo CI) = 0;
 
     /// isDereferenceableOrNull - Overload to allow clients with additional
     /// knowledge about pointer dereferenceability to provide it and thereby
@@ -105,20 +135,14 @@ namespace llvm {
     virtual bool isDereferenceableOrNull(Value *O, const DataLayout &DL);
   };
 
-  /// Types of use capture kinds, see \p DetermineUseCaptureKind.
-  enum class UseCaptureKind {
-    NO_CAPTURE,
-    MAY_CAPTURE,
-    PASSTHROUGH,
-  };
-
   /// Determine what kind of capture behaviour \p U may exhibit.
   ///
-  /// A use can be no-capture, a use can potentially capture, or a use can be
-  /// passthrough such that the uses of the user or \p U should be inspected.
+  /// The Other part of the returned CaptureInfo indicates which components of
+  /// the pointer may be captured directly by the use. The Ret part indicates
+  /// which components may be captured by following uses of the user of \p U.
   /// The \p IsDereferenceableOrNull callback is used to rule out capturing for
   /// certain comparisons.
-  UseCaptureKind
+  CaptureInfo
   DetermineUseCaptureKind(const Use &U,
                           llvm::function_ref<bool(Value *, const DataLayout &)>
                               IsDereferenceableOrNull);
diff --git a/llvm/include/llvm/Support/ModRef.h b/llvm/include/llvm/Support/ModRef.h
index a8ce9a8e6e69c47..716ed2cb8cd4871 100644
--- a/llvm/include/llvm/Support/ModRef.h
+++ b/llvm/include/llvm/Support/ModRef.h
@@ -309,6 +309,10 @@ inline bool capturesFullProvenance(CaptureComponents CC) {
   return (CC & CaptureComponents::Provenance) == CaptureComponents::Provenance;
 }
 
+inline bool capturesAll(CaptureComponents CC) {
+  return CC == CaptureComponents::All;
+}
+
 raw_ostream &operator<<(raw_ostream &OS, CaptureComponents CC);
 
 /// Represents which components of the pointer may be captured in which
@@ -333,6 +337,22 @@ class CaptureInfo {
   /// Create CaptureInfo that may capture all components of the pointer.
   static CaptureInfo all() { return CaptureInfo(CaptureComponents::All); }
 
+  /// Create CaptureInfo that may only capture through means other than the
+  /// return value.
+  static CaptureInfo
+  otherOnly(CaptureComponents OtherComponents = CaptureComponents::All) {
+    return CaptureInfo(OtherComponents, CaptureComponents::None);
+  }
+
+  /// Create CaptureInfo that may only capture via the return value.
+  static CaptureInfo
+  retOnly(CaptureComponents RetComponents = CaptureComponents::All) {
+    return CaptureInfo(CaptureComponents::None, RetComponents);
+  }
+
+  /// Whether the pointer is only captured via the return value.
+  bool isRetOnly() const { return capturesNothing(OtherComponents); }
+
   /// Get components potentially captured by the return value.
   CaptureComponents getRetComponents() const { return RetComponents; }
 
diff --git a/llvm/lib/Analysis/CaptureTracking.cpp b/llvm/lib/Analysis/CaptureTracking.cpp
index 49baf2eb84bb3e1..4e403b8825c7f71 100644
--- a/llvm/lib/Analysis/CaptureTracking.cpp
+++ b/llvm/lib/Analysis/CaptureTracking.cpp
@@ -81,14 +81,16 @@ struct SimpleCaptureTracker : public CaptureTracker {
     Captured = true;
   }
 
-  bool captured(const Use *U) override {
+  std::optional<CaptureComponents> captured(const Use *U,
+                                            CaptureInfo CI) override {
+    // TODO(captures): Use CaptureInfo.
     if (isa<ReturnInst>(U->getUser()) && !ReturnCaptures)
-      return false;
+      return continueIgnoringReturn();
 
     LLVM_DEBUG(dbgs() << "Captured by: " << *U->getUser() << "\n");
 
     Captured = true;
-    return true;
+    return stop();
   }
 
   bool ReturnCaptures;
@@ -122,19 +124,22 @@ struct CapturesBefore : public CaptureTracker {
     return !isPotentiallyReachable(I, BeforeHere, nullptr, DT, LI);
   }
 
-  bool captured(const Use *U) override {
+  std::optional<CaptureComponents> captured(const Use *U,
+                                            CaptureInfo CI) override {
+    // TODO(captures): Use CaptureInfo.
     Instruction *I = cast<Instruction>(U->getUser());
     if (isa<ReturnInst>(I) && !ReturnCaptures)
-      return false;
+      return continueIgnoringReturn();
 
     // Check isSafeToPrune() here rather than in shouldExplore() to avoid
     // an expensive reachability query for every instruction we look at.
     // Instead we only do one for actual capturing candidates.
     if (isSafeToPrune(I))
-      return false;
+      // If the use is not reachable, the instruction result isn't either.
+      return continueIgnoringReturn();
 
     Captured = true;
-    return true;
+    return stop();
   }
 
   const Instruction *BeforeHere;
@@ -166,10 +171,12 @@ struct EarliestCaptures : public CaptureTracker {
     EarliestCapture = &*F.getEntryBlock().begin();
   }
 
-  bool captured(const Use *U) override {
+  std::optional<CaptureComponents> captured(const Use *U,
+                                            CaptureInfo CI) override {
+    // TODO(captures): Use CaptureInfo.
     Instruction *I = cast<Instruction>(U->getUser());
     if (isa<ReturnInst>(I) && !ReturnCaptures)
-      return false;
+      return continueIgnoringReturn();
 
     if (!EarliestCapture)
       EarliestCapture = I;
@@ -177,9 +184,10 @@ struct EarliestCaptures : public CaptureTracker {
       EarliestCapture = DT.findNearestCommonDominator(EarliestCapture, I);
     Captured = true;
 
-    // Return false to continue analysis; we need to see all potential
-    // captures.
-    return false;
+    // Continue analysis, as we need to see all potential captures. However,
+    // we do not need to follow the instruction result, as this use will
+    // dominate any captures made through the instruction result.
+    return continueIgnoringReturn();
   }
 
   Instruction *EarliestCapture = nullptr;
@@ -274,25 +282,26 @@ Instruction *llvm::FindEarliestCapture(const Value *V, Function &F,
   return CB.EarliestCapture;
 }
 
-UseCaptureKind llvm::DetermineUseCaptureKind(
+CaptureInfo llvm::DetermineUseCaptureKind(
     const Use &U,
     function_ref<bool(Value *, const DataLayout &)> IsDereferenceableOrNull) {
   Instruction *I = dyn_cast<Instruction>(U.getUser());
 
   // TODO: Investigate non-instruction uses.
   if (!I)
-    return UseCaptureKind::MAY_CAPTURE;
+    return CaptureInfo::otherOnly();
 
   switch (I->getOpcode()) {
   case Instruction::Call:
   case Instruction::Invoke: {
+    // TODO(captures): Make this more precise.
     auto *Call = cast<CallBase>(I);
     // Not captured if the callee is readonly, doesn't return a copy through
     // its return value and doesn't unwind (a readonly function can leak bits
     // by throwing an exception or not depending on the input value).
     if (Call->onlyReadsMemory() && Call->doesNotThrow() &&
         Call->getType()->isVoidTy())
-      return UseCaptureKind::NO_CAPTURE;
+      return CaptureInfo::none();
 
     // The pointer is not captured if returned pointer is not captured.
     // NOTE: CaptureTracking users should not assume that only functions
@@ -300,13 +309,13 @@ UseCaptureKind llvm::DetermineUseCaptureKind(
     // getUnderlyingObject in ValueTracking or DecomposeGEPExpression
     // in BasicAA also need to know about this property.
     if (isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(Call, true))
-      return UseCaptureKind::PASSTHROUGH;
+      return CaptureInfo::retOnly();
 
     // Volatile operations effectively capture the memory location that they
     // load and store to.
     if (auto *MI = dyn_cast<MemIntrinsic>(Call))
       if (MI->isVolatile())
-        return UseCaptureKind::MAY_CAPTURE;
+        return CaptureInfo::otherOnly();
 
     // Calling a function pointer does not in itself cause the pointer to
     // be captured.  This is a subtle point considering that (for example)
@@ -315,30 +324,26 @@ UseCaptureKind llvm::DetermineUseCaptureKind(
     // captured, even though the loaded value might be the pointer itself
     // (think of self-referential objects).
     if (Call->isCallee(&U))
-      return UseCaptureKind::NO_CAPTURE;
+      return CaptureInfo::none();
 
     // Not captured if only passed via 'nocapture' arguments.
     assert(Call->isDataOperand(&U) && "Non-callee must be data operand");
-    if (!Call->doesNotCapture(Call->getDataOperandNo(&U))) {
-      // The parameter is not marked 'nocapture' - captured.
-      return UseCaptureKind::MAY_CAPTURE;
-    }
-    return UseCaptureKind::NO_CAPTURE;
+    return Call->getCaptureInfo(Call->getDataOperandNo(&U));
   }
   case Instruction::Load:
     // Volatile loads make the address observable.
     if (cast<LoadInst>(I)->isVolatile())
-      return UseCaptureKind::MAY_CAPTURE;
-    return UseCaptureKind::NO_CAPTURE;
+      return CaptureInfo::otherOnly();
+    return CaptureInfo::none();
   case Instruction::VAArg:
     // "va-arg" from a pointer does not cause it to be captured.
-    return UseCaptureKind::NO_CAPTURE;
+    return CaptureInfo::none();
   case Instruction::Store:
     // Stored the pointer - conservatively assume it may be captured.
     // Volatile stores make the address observable.
     if (U.getOperandNo() == 0 || cast<StoreInst>(I)->isVolatile())
-      return UseCaptureKind::MAY_CAPTURE;
-    return UseCaptureKind::NO_CAPTURE;
+      return CaptureInfo::otherOnly();
+    return CaptureInfo::none();
   case Instruction::AtomicRMW: {
     // atomicrmw conceptually includes both a load and store from
     // the same location.
@@ -347,8 +352,8 @@ UseCaptureKind llvm::DetermineUseCaptureKind(
     // Volatile stores make the address observable.
     auto *ARMWI = cast<AtomicRMWInst>(I);
     if (U.getOperandNo() == 1 || ARMWI->isVolatile())
-      return UseCaptureKind::MAY_CAPTURE;
-    return UseCaptureKind::NO_CAPTURE;
+      return CaptureInfo::otherOnly();
+    return CaptureInfo::none();
   }
   case Instruction::AtomicCmpXchg: {
     // cmpxchg conceptually includes both a load and store from
@@ -358,31 +363,34 @@ UseCaptureKind llvm::DetermineUseCaptureKind(
     // Volatile stores make the address observable.
     auto *ACXI = cast<AtomicCmpXchgInst>(I);
     if (U.getOperandNo() == 1 || U.getOperandNo() == 2 || ACXI->isVolatile())
-      return UseCaptureKind::MAY_CAPTURE;
-    return UseCaptureKind::NO_CAPTURE;
+      return CaptureInfo::otherOnly();
+    return CaptureInfo::none();
   }
   case Instruction::GetElementPtr:
     // AA does not support pointers of vectors, so GEP vector splats need to
     // be considered as captures.
     if (I->getType()->isVectorTy())
-      return UseCaptureKind::MAY_CAPTURE;
-    return UseCaptureKind::PASSTHROUGH;
+      return CaptureInfo::otherOnly();
+    return CaptureInfo::retOnly();
   case Instruction::BitCast:
   case Instruction::PHI:
   case Instruction::Select:
   case Instruction::AddrSpaceCast:
     // The original value is not captured via this if the new value isn't.
-    return UseCaptureKind::PASSTHROUGH;
+    return CaptureInfo::retOnly();
   case Instruction::ICmp: {
     unsigned Idx = U.getOperandNo();
     unsigned OtherIdx = 1 - Idx;
     if (auto *CPN = dyn_cast<ConstantPointerNull>(I->getOperand(OtherIdx))) {
+      // TODO(captures): Remove these special cases once we make use of
+      // captures(address_is_null).
+
       // Don't count comparisons of a no-alias return value against null as
       // captures. This allows us to ignore comparisons of malloc results
       // with null, for example.
       if (CPN->getType()->getAddressSpace() == 0)
         if (isNoAliasCall(U.get()->stripPointerCasts()))
-          return UseCaptureKind::NO_CAPTURE;
+          return CaptureInfo::none();
       if (!I->getFunction()->nullPointerIsDefined()) {
         auto *O = I->getOperand(Idx)->stripPointerCastsSameRepresentation();
         // Comparing a dereferenceable_or_null pointer against null cannot
@@ -390,17 +398,19 @@ UseCaptureKind llvm::DetermineUseCaptureKind(
         // valid (in-bounds) pointer.
         const DataLayout &DL = I->getDataLayout();
         if (IsDereferenceableOrNull && IsDereferenceableOrNull(O, DL))
-          return UseCaptureKind::NO_CAPTURE;
+          return CaptureInfo::none();
       }
+      return CaptureInfo::otherOnly(CaptureComponents::AddressIsNull);
     }
 
     // Otherwise, be conservative. There are crazy ways to capture pointers
-    // using comparisons.
-    return UseCaptureKind::MAY_CAPTURE;
+    // using comparisons. However, only the address is captured, not the
+    // provenance.
+    return CaptureInfo::otherOnly(CaptureComponents::Address);
   }
   default:
     // Something else - be conservative and say it is captured.
-    return UseCaptureKind::MAY_CAPTURE;
+    return CaptureInfo::otherOnly();
   }
 }
 
@@ -438,18 +448,21 @@ void llvm::PointerMayBeCaptured(const Value *V, CaptureTracker *Tracker,
   };
   while (!Worklist.empty()) {
     const Use *U = Worklist.pop_back_val();
-    switch (DetermineUseCaptureKind(*U, IsDereferenceableOrNull)) {
-    case UseCaptureKind::NO_CAPTURE:
+    CaptureInfo CI = DetermineUseCaptureKind(*U, IsDereferenceableOrNull);
+    if (capturesNothing(CI))
       continue;
-    case UseCaptureKind::MAY_CAPTURE:
-      if (Tracker->captured(U))
+    CaptureComponents RetCC = CI.getRetComponents();
+    if (!capturesNothing(CI.getOtherComponents())) {
+      std::optional<CaptureComponents> Res = Tracker->captured(U, CI);
+      if (!Res)
         return;
-      continue;
-    case UseCaptureKind::PASSTHROUGH:
-      if (!AddUses(U->getUser()))
-        return;
-      continue;
+      assert(capturesNothing(*Res & ~RetCC) &&
+             "captures() result must be subset of getRetComponents()");
+      RetCC = *Res;
     }
+    // TODO(captures): We could keep track of RetCC for the users.
+    if (!capturesNothing(RetCC) && !AddUses(U->getUser()))
+      return;
   }
 
   // All uses examined.
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 3cbc4107433ef3d..58fa165341f09a3 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -2801,7 +2801,9 @@ static Constant *computePointerICmp(CmpPredicate Pred, Value *LHS, Value *RHS,
       struct CustomCaptureTracker : public CaptureTracker {
         bool Captured = false;
         void tooManyUses() override { Captured = true; }
-        bool captured(const Use *U) override {
+        std::optional<CaptureComponents> captured(const Use *U,
+                                                  CaptureInfo CI) override {
+          // TODO(captures): Use CaptureInfo.
           if (auto *ICmp = dyn_cast<ICmpInst>(U->getUser())) {
             // Comparison against value stored in global variable. Given the
             // pointer does not escape, its value cannot be guessed and stored
@@ -2809,11 +2811,11 @@ static Constant *computePointerICmp(CmpPredicate Pred, Value *LHS, Value *RHS,
             unsigned OtherIdx = 1 - U->getOperandNo();
             auto *LI = dyn_cast<LoadInst>(ICmp->getOperand(OtherIdx));
             if (LI && isa<GlobalVariable>(LI->getPointerOperand()))
-              return false;
+              return continueDefault(CI);
           }
 
           Captured = true;
-          return true;
+          return stop();
         }
       };
       CustomCaptureTracker Tracker;
diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
index 17e7fada1082762..0fcb9739385b88f 100644
--- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
+++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
@@ -3970,18 +3970,16 @@ struct AANoAliasCallSiteArgument final : AANoAliasImpl {
       // TODO: We should track the capturing uses in AANoCapture but the problem
       //       is CGSCC runs. For those we would need to "allow" AANoCapture for
       //       a value in the module slice.
-      switch (DetermineUseCaptureKind(U, IsDereferenceableOrNull)) {
-      case UseCaptureKind::NO_CAPTURE:
+      // TODO(captures): Make this more precise.
+      CaptureInfo CI = DetermineUseCaptureKind(U, IsDereferenceableOrNull);
+      if (capturesNothing(CI))
         return true;
-      case UseCaptureKind::MAY_CAPTURE:
-        LLVM_DEBUG(dbgs() << "[AANoAliasCSArg] Unknown user: " << *UserI
-                          << "\n");
-        return false;
-      case UseCaptureKind::PASSTHROUGH:
+      if (CI.isRetOnly()) {
         Follow = true;
         return true;
       }
-      llvm_unreachable("unknown UseCaptureKind");
+      LLVM_DEBUG(dbgs() << "[AANoAliasCSArg] Unknown user: " << *UserI << "\n");
+      return false;
     };
 
     bool IsKnownNoCapture;
@@ -6019,16 +6017,15 @@ ChangeStatus AANoCaptureImpl::updateImpl(Attributor &A) {
   };
 
   auto UseCheck = [&](const Use &U, bool &Follow) -> bool {
-    switch (DetermineUseCaptureKind(U, IsDereferenceableOrNull)) {
-    case UseCaptureKind::NO_CAPTURE:
+    // TODO(captures): Make this more precise.
+    CaptureInfo CI = DetermineUseCaptureKind(U, IsDereferenceableOrNull);
+    if (capturesNothing(CI))
       return true;
-    case UseCaptureKind::MAY_CAPTURE:
-      return checkUse(A, T, U, Follow);
-    case UseCaptureKind::PASSTHROUGH:
+    if (CI.isRetOnly()) {
       Follow = true;
       return true;
     }
-    llvm_unreachable("Unexpected use capture kind!");
+    return checkUse(A, T, U, Follow);
   };
 
   if (!A.checkForAllUses(UseCheck, *this, *V))
@@ -12151,16 +12148,13 @@ struct AAGlobalValueInfoFloating : public AAGlobalValueInfo {
 
     auto UsePred = [&](const Use &U, bool &Follow) -> bool {
       Uses.insert(&U);
-      switch (DetermineUseCaptureKind(U, nullptr)) {
-      case UseCaptureKind::NO_CAPTURE:
-        return checkUse(A, U, Follow, Worklist);
-      case UseCaptureKind::MAY_CAPTURE:
-        return checkUse(A, U, Follow, Worklist);
-      case UseCaptureKind::PASSTHROUGH:
+      // TODO(captures): Make this more precise.
+      CaptureInfo CI = DetermineUseCaptureKind(U, nullptr);
+      if (!capturesNothing(CI) && CI.isRetOnly()) {
         Follow = true;
         return true;
       }
-      return true;
+      return checkUse(A, U, Follow, Worklist);
     };
     auto EquivalentUseCB = [&](const Use &OldU, const Use &NewU) {
       Uses.insert(&OldU);
diff --git a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp
index cf56f67e4de3f53..920ccf28b3f1b04 100644
--- a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp
@@ -71,7 +71,9 @@ using namespace llvm;
 #define DEBUG_TYPE "function-attrs"
 
 STATISTIC(NumMemoryAttr, "Number of functions with improved memory attribute");
-STATISTIC(NumNoCapture, "Number of arguments marked nocapture");
+STATISTIC(NumCapturesNone, "Number of arguments marked captures(none)");
+STATISTIC(NumCapturesOther, "Number of arguments marked with captures "
+                            "attribute other than captures(none)");
 STATISTIC(NumReturned, "Number of arguments marked returned");
 STATISTIC(NumReadNoneArg, "Number of arguments marked readnone");
 STATISTIC(NumReadOnlyArg, "Number of arguments marked readonly");
@@ -108,6 +110,13 @@ static cl::opt<bool> DisableThinLTOPropagation(
     "disable-thinlto-funcattrs", cl::init(true), cl::Hidden,
     cl::desc("Don't propagate function-attrs in thinLTO"));
 
+static void addCapturesStat(CaptureInfo CI) {
+  if (capturesNothing(CI))
+    ++NumCapturesNone;
+  else
+    ++NumCapturesOther;
+}
+
 namespace {
 
 using SCCNodeSet = SmallSetVector<Function *, 8>;
@@ -494,6 +503,9 @@ namespace {
 /// SCC of the arguments.
 struct ArgumentGraphNode {
   Argument *Definition;
+  /// CaptureComponents for this argument, excluding captures via Uses.
+  /// We don't distinguish between other/return captures here.
+  CaptureComponents CC = CaptureComponents::None;
   SmallVector<ArgumentGraphNode *, 4> Uses;
 };
 
@@ -535,18 +547,37 @@ class ArgumentGraph {
 struct ArgumentUsesTracker : public CaptureTracker {
   ArgumentUsesTracker(const SCCNodeSet &SCCNodes) : SCCNodes(SCCNodes) {}
 
-  void tooManyUses() override { Captured = true; }
+  void tooManyUses() override { CI = CaptureInfo::all(); }
+
+  std::optional<CaptureComponents> captured(const Use *U,
+                                            CaptureInfo UseCI) override {
+    if (updateCaptureInfo(U, UseCI.getOtherComponents())) {
+      // Don't bother continuing if we already capture everything.
+      if (capturesAll(CI.getOtherComponents()))
+        return stop();
+      return continueDefault(UseCI);
+    }
+
+    // For SCC argument tracking, we're not going to analyze other/ret
+    // components separately, so don't follow the return value.
+    return continueIgnoringReturn();
+  }
 
-  bool captured(const Use *U) override {
+  bool updateCaptureInfo(const Use *U, CaptureComponents CC) {
     CallBase *CB = dyn_cast<CallBase>(U->getUser());
     if (!CB) {
-      Captured = true;
+      if (isa<ReturnInst>(U->getUser()))
+        CI |= CaptureInfo::retOnly(CC);
+      else
+        // Conservatively assume that the captured value might make its way
+        // into the return value as well. This could be made more precise.
+        CI |= CaptureInfo(CC);
       return true;
     }
 
     Function *F = CB->getCalledFunction();
     if (!F || !F->hasExactDefinition() || !SCCNodes.count(F)) {
-      Captured = true;
+      CI |= CaptureInfo(CC);
       return true;
     }
 
@@ -560,22 +591,24 @@ struct ArgumentUsesTracker : public CaptureTracker {
       // use.  In this case it does not matter if the callee is within our SCC
       // or not -- we've been captured in some unknown way, and we have to be
       // conservative.
-      Captured = true;
+      CI |= CaptureInfo(CC);
       return true;
     }
 
     if (UseIndex >= F->arg_size()) {
       assert(F->isVarArg() && "More params than args in non-varargs call");
-      Captured = true;
+      CI |= CaptureInfo(CC);
       return true;
     }
 
+    // TODO(captures): Could improve precision by remembering maximum
+    // capture components for the argument.
     Uses.push_back(&*std::next(F->arg_begin(), UseIndex));
     return false;
   }
 
-  // True only if certainly captured (used outside our SCC).
-  bool Captured = false;
+  // Does not include potential captures via Uses in the SCC.
+  CaptureInfo CI = CaptureInfo::none();
 
   // Uses within our SCC.
   SmallVector<Argument *, 4> Uses;
@@ -1190,6 +1223,15 @@ static void addArgumentAttrs(const SCCNodeSet &SCCNodes,
                              bool SkipInitializes) {
   ArgumentGraph AG;
 
+  auto DetermineAccessAttrsForSingleton = [](Argument *A) {
+    SmallPtrSet<Argument *, 8> Self;
+    Self.insert(A);
+    Attribute::AttrKind R = determinePointerAccessAttrs(A, Self);
+    if (R != Attribute::None)
+      return addAccessAttr(A, R);
+    return false;
+  };
+
   // Check each function in turn, determining which pointer arguments are not
   // captured.
   for (Function *F : SCCNodes) {
@@ -1210,7 +1252,7 @@ static void addArgumentAttrs(const SCCNodeSet &SCCNodes,
         if (A.getType()->isPointerTy() && !A.hasNoCaptureAttr()) {
           A.addAttr(Attribute::getWithCaptureInfo(A.getContext(),
                                                   CaptureInfo::none()));
-          ++NumNoCapture;
+          ++NumCapturesNone;
           Changed.insert(F);
         }
       }
@@ -1221,21 +1263,23 @@ static void addArgumentAttrs(const SCCNodeSet &SCCNodes,
       if (!A.getType()->isPointerTy())
         continue;
       bool HasNonLocalUses = false;
-      if (!A.hasNoCaptureAttr()) {
+      CaptureInfo OrigCI = A.getAttributes().getCaptureInfo();
+      if (!capturesNothing(OrigCI)) {
         ArgumentUsesTracker Tracker(SCCNodes);
         PointerMayBeCaptured(&A, &Tracker);
-        if (!Tracker.Captured) {
+        CaptureInfo NewCI = Tracker.CI & OrigCI;
+        if (NewCI != OrigCI) {
           if (Tracker.Uses.empty()) {
-            // If it's trivially not captured, mark it nocapture now.
-            A.addAttr(Attribute::getWithCaptureInfo(A.getContext(),
-                                                    CaptureInfo::none()));
-            ++NumNoCapture;
+            // If the information is complete, add the attribute now.
+            A.addAttr(Attribute::getWithCaptureInfo(A.getContext(), NewCI));
+            addCapturesStat(NewCI);
             Changed.insert(F);
           } else {
             // If it's not trivially captured and not trivially not captured,
             // then it must be calling into another function in our SCC. Save
             // its particulars for Argument-SCC analysis later.
             ArgumentGraphNode *Node = AG[&A];
+            Node->CC = CaptureComponents(NewCI);
             for (Argument *Use : Tracker.Uses) {
               Node->Uses.push_back(AG[Use]);
               if (Use != &A)
@@ -1250,12 +1294,8 @@ static void addArgumentAttrs(const SCCNodeSet &SCCNodes,
         // an SCC? Note that we don't allow any calls at all here, or else our
         // result will be dependent on the iteration order through the
         // functions in the SCC.
-        SmallPtrSet<Argument *, 8> Self;
-        Self.insert(&A);
-        Attribute::AttrKind R = determinePointerAccessAttrs(&A, Self);
-        if (R != Attribute::None)
-          if (addAccessAttr(&A, R))
-            Changed.insert(F);
+        if (DetermineAccessAttrsForSingleton(&A))
+          Changed.insert(F);
       }
       if (!SkipInitializes && !A.onlyReadsMemory()) {
         if (inferInitializes(A, *F))
@@ -1281,17 +1321,19 @@ static void addArgumentAttrs(const SCCNodeSet &SCCNodes,
       if (ArgumentSCC[0]->Uses.size() == 1 &&
           ArgumentSCC[0]->Uses[0] == ArgumentSCC[0]) {
         Argument *A = ArgumentSCC[0]->Definition;
-        A->addAttr(Attribute::getWithCaptureInfo(A->getContext(),
-                                                 CaptureInfo::none()));
-        ++NumNoCapture;
-        Changed.insert(A->getParent());
+        // CC merges the other and return components, so intersect with the
+        // existing attribute to avoid replacing it with a weaker one.
+        CaptureInfo OrigCI = A->getAttributes().getCaptureInfo();
+        CaptureInfo NewCI = CaptureInfo(ArgumentSCC[0]->CC) & OrigCI;
+        if (NewCI != OrigCI) {
+          A->addAttr(Attribute::getWithCaptureInfo(A->getContext(), NewCI));
+          addCapturesStat(NewCI);
+          Changed.insert(A->getParent());
+        }
 
-        // Infer the access attributes given the new nocapture one
-        SmallPtrSet<Argument *, 8> Self;
-        Self.insert(&*A);
-        Attribute::AttrKind R = determinePointerAccessAttrs(&*A, Self);
-        if (R != Attribute::None)
-          addAccessAttr(A, R);
+        // Infer the access attributes given the new captures attribute.
+        if (DetermineAccessAttrsForSingleton(A))
+          Changed.insert(A->getParent());
       }
       continue;
     }
@@ -1303,27 +1341,50 @@ static void addArgumentAttrs(const SCCNodeSet &SCCNodes,
       ArgumentSCCNodes.insert(I->Definition);
     }
 
-    bool SCCCaptured = false;
+    // At the SCC level, only track merged CaptureComponents. We're not
+    // currently prepared to handle propagation of return-only captures across
+    // the SCC.
+    CaptureComponents CC = CaptureComponents::None;
     for (ArgumentGraphNode *N : ArgumentSCC) {
       for (ArgumentGraphNode *Use : N->Uses) {
         Argument *A = Use->Definition;
-        if (A->hasNoCaptureAttr() || ArgumentSCCNodes.count(A))
-          continue;
-        SCCCaptured = true;
-        break;
+        // Every use may contribute further capture components, so keep
+        // scanning until we already capture everything.
+        if (ArgumentSCCNodes.count(A))
+          CC |= Use->CC;
+        else
+          CC |= CaptureComponents(A->getAttributes().getCaptureInfo());
+        if (capturesAll(CC))
+          break;
       }
-      if (SCCCaptured)
+      if (capturesAll(CC))
         break;
     }
-    if (SCCCaptured)
-      continue;
 
-    for (ArgumentGraphNode *N : ArgumentSCC) {
-      Argument *A = N->Definition;
-      A->addAttr(
-          Attribute::getWithCaptureInfo(A->getContext(), CaptureInfo::none()));
-      ++NumNoCapture;
-      Changed.insert(A->getParent());
+    if (!capturesAll(CC)) {
+      for (ArgumentGraphNode *N : ArgumentSCC) {
+        Argument *A = N->Definition;
+        // N->CC | CC merges the other and return components, so intersect
+        // with the existing attribute to avoid replacing it with a weaker one.
+        CaptureInfo OrigCI = A->getAttributes().getCaptureInfo();
+        CaptureInfo NewCI = CaptureInfo(N->CC | CC) & OrigCI;
+        if (NewCI != OrigCI) {
+          A->addAttr(Attribute::getWithCaptureInfo(A->getContext(), NewCI));
+          addCapturesStat(NewCI);
+          Changed.insert(A->getParent());
+        }
+      }
+    }
+
+    // TODO(captures): Ignore address-only captures.
+    if (!capturesNothing(CC)) {
+      // As the pointer may be captured, determine the pointer attributes
+      // looking at each argument individually.
+      for (ArgumentGraphNode *N : ArgumentSCC) {
+        if (DetermineAccessAttrsForSingleton(N->Definition))
+          Changed.insert(N->Definition->getParent());
+      }
+      continue;
     }
 
     // We also want to compute readonly/readnone/writeonly. With a small number
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 83534059bfb69a0..6640e15f425b90f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -882,7 +882,9 @@ bool InstCombinerImpl::foldAllocaCmp(AllocaInst *Alloca) {
 
     void tooManyUses() override { Captured = true; }
 
-    bool captured(const Use *U) override {
+    std::optional<CaptureComponents> captured(const Use *U,
+                                              CaptureInfo CI) override {
+      // TODO(captures): Use CaptureInfo.
       auto *ICmp = dyn_cast<ICmpInst>(U->getUser());
       // We need to check that U is based *only* on the alloca, and doesn't
       // have other contributions from a select/phi operand.
@@ -892,11 +894,11 @@ bool InstCombinerImpl::foldAllocaCmp(AllocaInst *Alloca) {
         // Collect equality icmps of the alloca, and don't treat them as
         // captures.
         ICmps[ICmp] |= 1u << U->getOperandNo();
-        return false;
+        return continueDefault(CI);
       }
 
       Captured = true;
-      return true;
+      return stop();
     }
   };
 
diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
index 87b27beb01a0a92..91f8c0101e00c6f 100644
--- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
+++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
@@ -1550,32 +1550,32 @@ bool MemCpyOptPass::performStackMoveOptzn(Instruction *Load, Instruction *Store,
         }
         if (!Visited.insert(&U).second)
           continue;
-        switch (DetermineUseCaptureKind(U, IsDereferenceableOrNull)) {
-        case UseCaptureKind::MAY_CAPTURE:
-          return false;
-        case UseCaptureKind::PASSTHROUGH:
-          // Instructions cannot have non-instruction users.
-          Worklist.push_back(UI);
-          continue;
-        case UseCaptureKind::NO_CAPTURE: {
-          if (UI->isLifetimeStartOrEnd()) {
-            // We note the locations of these intrinsic calls so that we can
-            // delete them later if the optimization succeeds, this is safe
-            // since both llvm.lifetime.start and llvm.lifetime.end intrinsics
-            // practically fill all the bytes of the alloca with an undefined
-            // value, although conceptually marked as alive/dead.
-            int64_t Size = cast<ConstantInt>(UI->getOperand(0))->getSExtValue();
-            if (Size < 0 || Size == DestSize) {
-              LifetimeMarkers.push_back(UI);
-              continue;
-            }
+        CaptureInfo CI = DetermineUseCaptureKind(U, IsDereferenceableOrNull);
+        // TODO(captures): Make this more precise.
+        if (!capturesNothing(CI)) {
+          if (CI.isRetOnly()) {
+            Worklist.push_back(UI);
+            continue;
           }
-          if (UI->hasMetadata(LLVMContext::MD_noalias))
-            NoAliasInstrs.insert(UI);
-          if (!ModRefCallback(UI))
-            return false;
+          return false;
         }
+
+        if (UI->isLifetimeStartOrEnd()) {
+          // We note the locations of these intrinsic calls so that we can
+          // delete them later if the optimization succeeds, this is safe
+          // since both llvm.lifetime.start and llvm.lifetime.end intrinsics
+          // practically fill all the bytes of the alloca with an undefined
+          // value, although conceptually marked as alive/dead.
+          int64_t Size = cast<ConstantInt>(UI->getOperand(0))->getSExtValue();
+          if (Size < 0 || Size == DestSize) {
+            LifetimeMarkers.push_back(UI);
+            continue;
+          }
         }
+        if (UI->hasMetadata(LLVMContext::MD_noalias))
+          NoAliasInstrs.insert(UI);
+        if (!ModRefCallback(UI))
+          return false;
       }
     }
     return true;
diff --git a/llvm/test/Transforms/FunctionAttrs/2009-01-02-LocalStores.ll b/llvm/test/Transforms/FunctionAttrs/2009-01-02-LocalStores.ll
index f706184f9727e24..a3b065667702f0e 100644
--- a/llvm/test/Transforms/FunctionAttrs/2009-01-02-LocalStores.ll
+++ b/llvm/test/Transforms/FunctionAttrs/2009-01-02-LocalStores.ll
@@ -14,7 +14,7 @@ define ptr @b(ptr %q) {
 	ret ptr %tmp
 }
 
-; CHECK: define ptr @c(ptr readnone returned %r)
+; CHECK: define ptr @c(ptr readnone returned captures(address_is_null, ret: address, provenance) %r)
 @g = global i32 0
 define ptr @c(ptr %r) {
 	%a = icmp eq ptr %r, null
diff --git a/llvm/test/Transforms/FunctionAttrs/arg_returned.ll b/llvm/test/Transforms/FunctionAttrs/arg_returned.ll
index 13954694eefe0c7..99406696d33d11c 100644
--- a/llvm/test/Transforms/FunctionAttrs/arg_returned.ll
+++ b/llvm/test/Transforms/FunctionAttrs/arg_returned.ll
@@ -145,8 +145,8 @@ return:                                           ; preds = %cond.end, %if.then3
 
 ; TEST SCC test returning a pointer value argument
 ;
-; FNATTR: define ptr @ptr_sink_r0(ptr readnone returned %r)
-; FNATTR: define ptr @ptr_scc_r1(ptr %a, ptr readnone %r, ptr readnone captures(none) %b)
+; FNATTR: define ptr @ptr_sink_r0(ptr readnone returned captures(ret: address, provenance) %r)
+; FNATTR: define ptr @ptr_scc_r1(ptr readnone %a, ptr readnone %r, ptr readnone captures(none) %b)
 ; FNATTR: define ptr @ptr_scc_r2(ptr readnone %a, ptr readnone %b, ptr readnone %r)
 ;
 ;
@@ -260,8 +260,8 @@ entry:
 
 ; TEST another SCC test
 ;
-; FNATTR:  define ptr @rt2_helper(ptr %a)
-; FNATTR:  define ptr @rt2(ptr readnone %a, ptr readnone %b)
+; FNATTR:  define ptr @rt2_helper(ptr readnone captures(address_is_null) %a)
+; FNATTR:  define ptr @rt2(ptr readnone captures(address_is_null) %a, ptr readnone captures(ret: address, provenance) %b)
 define ptr @rt2_helper(ptr %a) #0 {
 entry:
   %call = call ptr @rt2(ptr %a, ptr %a)
@@ -284,8 +284,8 @@ if.end:
 
 ; TEST another SCC test
 ;
-; FNATTR:  define ptr @rt3_helper(ptr %a, ptr %b)
-; FNATTR:  define ptr @rt3(ptr readnone %a, ptr readnone %b)
+; FNATTR:  define ptr @rt3_helper(ptr readnone captures(address_is_null) %a, ptr readnone %b)
+; FNATTR:  define ptr @rt3(ptr readnone captures(address_is_null) %a, ptr readnone %b)
 define ptr @rt3_helper(ptr %a, ptr %b) #0 {
 entry:
   %call = call ptr @rt3(ptr %a, ptr %b)
@@ -316,7 +316,7 @@ if.end:
 ;  }
 ;
 ;
-; FNATTR:     define ptr @calls_unknown_fn(ptr readnone returned %r)
+; FNATTR:     define ptr @calls_unknown_fn(ptr readnone returned captures(ret: address, provenance) %r)
 declare void @unknown_fn(ptr) #0
 
 define ptr @calls_unknown_fn(ptr %r) #0 {
@@ -415,7 +415,7 @@ if.end:                                           ; preds = %if.then, %entry
 ; }
 ;
 ;
-; FNATTR:     define ptr @bitcast(ptr readnone returned %b)
+; FNATTR:     define ptr @bitcast(ptr readnone returned captures(ret: address, provenance) %b)
 ;
 define ptr @bitcast(ptr %b) #0 {
 entry:
@@ -433,7 +433,7 @@ entry:
 ; }
 ;
 ;
-; FNATTR:     define ptr @bitcasts_select_and_phi(ptr readnone %b)
+; FNATTR:     define ptr @bitcasts_select_and_phi(ptr readnone captures(address_is_null, ret: address, provenance) %b)
 ;
 define ptr @bitcasts_select_and_phi(ptr %b) #0 {
 entry:
@@ -462,7 +462,7 @@ if.end:                                           ; preds = %if.then, %entry
 ; }
 ;
 ;
-; FNATTR:     define ptr @ret_arg_arg_undef(ptr readnone %b)
+; FNATTR:     define ptr @ret_arg_arg_undef(ptr readnone captures(address_is_null, ret: address, provenance) %b)
 ;
 define ptr @ret_arg_arg_undef(ptr %b) #0 {
 entry:
@@ -494,7 +494,7 @@ ret_undef:
 ; }
 ;
 ;
-; FNATTR:     define ptr @ret_undef_arg_arg(ptr readnone %b)
+; FNATTR:     define ptr @ret_undef_arg_arg(ptr readnone captures(address_is_null, ret: address, provenance) %b)
 ;
 define ptr @ret_undef_arg_arg(ptr %b) #0 {
 entry:
@@ -526,7 +526,7 @@ ret_arg1:
 ; }
 ;
 ;
-; FNATTR:     define ptr @ret_undef_arg_undef(ptr readnone %b)
+; FNATTR:     define ptr @ret_undef_arg_undef(ptr readnone captures(address_is_null, ret: address, provenance) %b)
 define ptr @ret_undef_arg_undef(ptr %b) #0 {
 entry:
   %cmp = icmp eq ptr %b, null
diff --git a/llvm/test/Transforms/FunctionAttrs/nocapture.ll b/llvm/test/Transforms/FunctionAttrs/nocapture.ll
index 6164f2adbf5b90e..aeb950bd41e93cb 100644
--- a/llvm/test/Transforms/FunctionAttrs/nocapture.ll
+++ b/llvm/test/Transforms/FunctionAttrs/nocapture.ll
@@ -7,7 +7,7 @@
 define ptr @c1(ptr %q) {
 ; FNATTRS: Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
 ; FNATTRS-LABEL: define ptr @c1
-; FNATTRS-SAME: (ptr readnone returned [[Q:%.*]]) #[[ATTR0:[0-9]+]] {
+; FNATTRS-SAME: (ptr readnone returned captures(ret: address, provenance) [[Q:%.*]]) #[[ATTR0:[0-9]+]] {
 ; FNATTRS-NEXT:    ret ptr [[Q]]
 ;
 ; ATTRIBUTOR: Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
@@ -512,7 +512,7 @@ define void @test4_1(ptr %x4_1, i1 %c) {
 define ptr @test4_2(ptr %x4_2, ptr %y4_2, ptr %z4_2, i1 %c) {
 ; FNATTRS: Function Attrs: nofree nosync nounwind memory(write, argmem: none, inaccessiblemem: none)
 ; FNATTRS-LABEL: define ptr @test4_2
-; FNATTRS-SAME: (ptr readnone captures(none) [[X4_2:%.*]], ptr readnone returned [[Y4_2:%.*]], ptr readnone captures(none) [[Z4_2:%.*]], i1 [[C:%.*]]) #[[ATTR10]] {
+; FNATTRS-SAME: (ptr readnone captures(none) [[X4_2:%.*]], ptr readnone returned captures(ret: address, provenance) [[Y4_2:%.*]], ptr readnone captures(none) [[Z4_2:%.*]], i1 [[C:%.*]]) #[[ATTR10]] {
 ; FNATTRS-NEXT:    br i1 [[C]], label [[T:%.*]], label [[F:%.*]]
 ; FNATTRS:       t:
 ; FNATTRS-NEXT:    call void @test4_1(ptr null, i1 [[C]])
@@ -740,7 +740,7 @@ define void @captureStrip(ptr %p) {
 define i1 @captureICmp(ptr %x) {
 ; FNATTRS: Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
 ; FNATTRS-LABEL: define i1 @captureICmp
-; FNATTRS-SAME: (ptr readnone [[X:%.*]]) #[[ATTR0]] {
+; FNATTRS-SAME: (ptr readnone captures(address_is_null) [[X:%.*]]) #[[ATTR0]] {
 ; FNATTRS-NEXT:    [[TMP1:%.*]] = icmp eq ptr [[X]], null
 ; FNATTRS-NEXT:    ret i1 [[TMP1]]
 ;
@@ -757,7 +757,7 @@ define i1 @captureICmp(ptr %x) {
 define i1 @captureICmpRev(ptr %x) {
 ; FNATTRS: Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
 ; FNATTRS-LABEL: define i1 @captureICmpRev
-; FNATTRS-SAME: (ptr readnone [[X:%.*]]) #[[ATTR0]] {
+; FNATTRS-SAME: (ptr readnone captures(address_is_null) [[X:%.*]]) #[[ATTR0]] {
 ; FNATTRS-NEXT:    [[TMP1:%.*]] = icmp eq ptr null, [[X]]
 ; FNATTRS-NEXT:    ret i1 [[TMP1]]
 ;
@@ -774,7 +774,7 @@ define i1 @captureICmpRev(ptr %x) {
 define i1 @nocaptureInboundsGEPICmp(ptr %x) {
 ; FNATTRS: Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
 ; FNATTRS-LABEL: define i1 @nocaptureInboundsGEPICmp
-; FNATTRS-SAME: (ptr readnone [[X:%.*]]) #[[ATTR0]] {
+; FNATTRS-SAME: (ptr readnone captures(address_is_null) [[X:%.*]]) #[[ATTR0]] {
 ; FNATTRS-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i32, ptr [[X]], i32 5
 ; FNATTRS-NEXT:    [[TMP2:%.*]] = icmp eq ptr [[TMP1]], null
 ; FNATTRS-NEXT:    ret i1 [[TMP2]]
@@ -794,7 +794,7 @@ define i1 @nocaptureInboundsGEPICmp(ptr %x) {
 define i1 @nocaptureInboundsGEPICmpRev(ptr %x) {
 ; FNATTRS: Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
 ; FNATTRS-LABEL: define i1 @nocaptureInboundsGEPICmpRev
-; FNATTRS-SAME: (ptr readnone [[X:%.*]]) #[[ATTR0]] {
+; FNATTRS-SAME: (ptr readnone captures(address_is_null) [[X:%.*]]) #[[ATTR0]] {
 ; FNATTRS-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i32, ptr [[X]], i32 5
 ; FNATTRS-NEXT:    [[TMP2:%.*]] = icmp eq ptr null, [[TMP1]]
 ; FNATTRS-NEXT:    ret i1 [[TMP2]]
@@ -831,7 +831,7 @@ define i1 @nocaptureDereferenceableOrNullICmp(ptr dereferenceable_or_null(4) %x)
 define i1 @captureDereferenceableOrNullICmp(ptr dereferenceable_or_null(4) %x) null_pointer_is_valid {
 ; FNATTRS: Function Attrs: mustprogress nofree norecurse nosync nounwind null_pointer_is_valid willreturn memory(none)
 ; FNATTRS-LABEL: define noundef i1 @captureDereferenceableOrNullICmp
-; FNATTRS-SAME: (ptr readnone dereferenceable_or_null(4) [[X:%.*]]) #[[ATTR16:[0-9]+]] {
+; FNATTRS-SAME: (ptr readnone captures(address_is_null) dereferenceable_or_null(4) [[X:%.*]]) #[[ATTR16:[0-9]+]] {
 ; FNATTRS-NEXT:    [[TMP1:%.*]] = icmp eq ptr [[X]], null
 ; FNATTRS-NEXT:    ret i1 [[TMP1]]
 ;
@@ -903,7 +903,7 @@ define void @readnone_indirec(ptr %f, ptr %p) {
 define ptr @captures_ret_only(ptr %p) {
 ; FNATTRS: Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
 ; FNATTRS-LABEL: define ptr @captures_ret_only
-; FNATTRS-SAME: (ptr readnone [[P:%.*]]) #[[ATTR0]] {
+; FNATTRS-SAME: (ptr readnone captures(ret: address, provenance) [[P:%.*]]) #[[ATTR0]] {
 ; FNATTRS-NEXT:    [[GEP:%.*]] = getelementptr i8, ptr [[P]], i64 8
 ; FNATTRS-NEXT:    ret ptr [[GEP]]
 ;
@@ -917,6 +917,8 @@ define ptr @captures_ret_only(ptr %p) {
   ret ptr %gep
 }
 
+; Even though the ptrtoint is only used in the return value, this should *not*
+; be considered a return-only capture.
 define i64 @captures_not_ret_only(ptr %p) {
 ; FNATTRS: Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
 ; FNATTRS-LABEL: define i64 @captures_not_ret_only
@@ -935,35 +937,52 @@ define i64 @captures_not_ret_only(ptr %p) {
 }
 
 define void @captures_read_provenance(ptr %p) {
-; COMMON-LABEL: define void @captures_read_provenance
-; COMMON-SAME: (ptr [[P:%.*]]) {
-; COMMON-NEXT:    call void @capture(ptr captures(address, read_provenance) [[P]])
-; COMMON-NEXT:    ret void
+; FNATTRS-LABEL: define void @captures_read_provenance
+; FNATTRS-SAME: (ptr captures(address, read_provenance) [[P:%.*]]) {
+; FNATTRS-NEXT:    call void @capture(ptr captures(address, read_provenance) [[P]])
+; FNATTRS-NEXT:    ret void
+;
+; ATTRIBUTOR-LABEL: define void @captures_read_provenance
+; ATTRIBUTOR-SAME: (ptr [[P:%.*]]) {
+; ATTRIBUTOR-NEXT:    call void @capture(ptr captures(address, read_provenance) [[P]])
+; ATTRIBUTOR-NEXT:    ret void
 ;
   call void @capture(ptr captures(address, read_provenance) %p)
   ret void
 }
 
 define void @captures_unused_ret(ptr %p) {
-; COMMON-LABEL: define void @captures_unused_ret
-; COMMON-SAME: (ptr [[P:%.*]]) {
-; COMMON-NEXT:    [[TMP1:%.*]] = call ptr @capture(ptr captures(address_is_null, ret: address, read_provenance) [[P]])
-; COMMON-NEXT:    ret void
+; FNATTRS-LABEL: define void @captures_unused_ret
+; FNATTRS-SAME: (ptr captures(address_is_null) [[P:%.*]]) {
+; FNATTRS-NEXT:    [[TMP1:%.*]] = call ptr @capture(ptr captures(address_is_null, ret: address, read_provenance) [[P]])
+; FNATTRS-NEXT:    ret void
+;
+; ATTRIBUTOR-LABEL: define void @captures_unused_ret
+; ATTRIBUTOR-SAME: (ptr [[P:%.*]]) {
+; ATTRIBUTOR-NEXT:    [[TMP1:%.*]] = call ptr @capture(ptr captures(address_is_null, ret: address, read_provenance) [[P]])
+; ATTRIBUTOR-NEXT:    ret void
 ;
   call ptr @capture(ptr captures(address_is_null, ret: address, read_provenance) %p)
   ret void
 }
 
 define ptr @captures_used_ret(ptr %p) {
-; COMMON-LABEL: define ptr @captures_used_ret
-; COMMON-SAME: (ptr [[P:%.*]]) {
-; COMMON-NEXT:    [[RET:%.*]] = call ptr @capture(ptr captures(address_is_null, ret: address, read_provenance) [[P]])
-; COMMON-NEXT:    ret ptr [[RET]]
+; FNATTRS-LABEL: define ptr @captures_used_ret
+; FNATTRS-SAME: (ptr captures(address_is_null, ret: address, provenance) [[P:%.*]]) {
+; FNATTRS-NEXT:    [[RET:%.*]] = call ptr @capture(ptr captures(address_is_null, ret: address, read_provenance) [[P]])
+; FNATTRS-NEXT:    ret ptr [[RET]]
+;
+; ATTRIBUTOR-LABEL: define ptr @captures_used_ret
+; ATTRIBUTOR-SAME: (ptr [[P:%.*]]) {
+; ATTRIBUTOR-NEXT:    [[RET:%.*]] = call ptr @capture(ptr captures(address_is_null, ret: address, read_provenance) [[P]])
+; ATTRIBUTOR-NEXT:    ret ptr [[RET]]
 ;
   %ret = call ptr @capture(ptr captures(address_is_null, ret: address, read_provenance) %p)
   ret ptr %ret
 }
 
+; Make sure this does not produce captures(ret: ...). We need to take the
+; return capture components into account when handling argument SCCs.
 define ptr @scc_capture_via_ret(i1 %c, ptr %p) {
 ; FNATTRS: Function Attrs: nofree nosync nounwind memory(write, argmem: none, inaccessiblemem: none)
 ; FNATTRS-LABEL: define ptr @scc_capture_via_ret
diff --git a/llvm/test/Transforms/FunctionAttrs/nonnull.ll b/llvm/test/Transforms/FunctionAttrs/nonnull.ll
index 0f6762f0d43426f..94093568419afa4 100644
--- a/llvm/test/Transforms/FunctionAttrs/nonnull.ll
+++ b/llvm/test/Transforms/FunctionAttrs/nonnull.ll
@@ -19,7 +19,7 @@ define ptr @test1() {
 ; Return a pointer trivially nonnull (argument attribute)
 define ptr @test2(ptr nonnull %p) {
 ; FNATTRS-LABEL: define nonnull ptr @test2(
-; FNATTRS-SAME: ptr nonnull readnone returned [[P:%.*]]) #[[ATTR0:[0-9]+]] {
+; FNATTRS-SAME: ptr nonnull readnone returned captures(ret: address, provenance) [[P:%.*]]) #[[ATTR0:[0-9]+]] {
 ; FNATTRS-NEXT:    ret ptr [[P]]
 ;
 ; ATTRIBUTOR-LABEL: define nonnull ptr @test2(
@@ -194,7 +194,7 @@ exit:
 
 define ptr @test7(ptr %a) {
 ; FNATTRS-LABEL: define ptr @test7(
-; FNATTRS-SAME: ptr readnone returned [[A:%.*]]) #[[ATTR0]] {
+; FNATTRS-SAME: ptr readnone returned captures(ret: address, provenance) [[A:%.*]]) #[[ATTR0]] {
 ; FNATTRS-NEXT:    ret ptr [[A]]
 ;
 ; ATTRIBUTOR-LABEL: define ptr @test7(
@@ -206,7 +206,7 @@ define ptr @test7(ptr %a) {
 
 define ptr @test8(ptr %a) {
 ; FNATTRS-LABEL: define nonnull ptr @test8(
-; FNATTRS-SAME: ptr readnone [[A:%.*]]) #[[ATTR0]] {
+; FNATTRS-SAME: ptr readnone captures(ret: address, provenance) [[A:%.*]]) #[[ATTR0]] {
 ; FNATTRS-NEXT:    [[B:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 1
 ; FNATTRS-NEXT:    ret ptr [[B]]
 ;
@@ -221,7 +221,7 @@ define ptr @test8(ptr %a) {
 
 define ptr @test9(ptr %a, i64 %n) {
 ; FNATTRS-LABEL: define ptr @test9(
-; FNATTRS-SAME: ptr readnone [[A:%.*]], i64 [[N:%.*]]) #[[ATTR0]] {
+; FNATTRS-SAME: ptr readnone captures(ret: address, provenance) [[A:%.*]], i64 [[N:%.*]]) #[[ATTR0]] {
 ; FNATTRS-NEXT:    [[B:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[N]]
 ; FNATTRS-NEXT:    ret ptr [[B]]
 ;
@@ -238,7 +238,7 @@ declare void @llvm.assume(i1)
 ; FIXME: missing nonnull
 define ptr @test10(ptr %a, i64 %n) {
 ; FNATTRS-LABEL: define ptr @test10(
-; FNATTRS-SAME: ptr readnone [[A:%.*]], i64 [[N:%.*]]) #[[ATTR3:[0-9]+]] {
+; FNATTRS-SAME: ptr readnone captures(ret: address, provenance) [[A:%.*]], i64 [[N:%.*]]) #[[ATTR3:[0-9]+]] {
 ; FNATTRS-NEXT:    [[CMP:%.*]] = icmp ne i64 [[N]], 0
 ; FNATTRS-NEXT:    call void @llvm.assume(i1 [[CMP]])
 ; FNATTRS-NEXT:    [[B:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[N]]
@@ -263,7 +263,7 @@ define ptr @test10(ptr %a, i64 %n) {
 ; }
 define ptr @test11(ptr) local_unnamed_addr {
 ; FNATTRS-LABEL: define nonnull ptr @test11(
-; FNATTRS-SAME: ptr readnone [[TMP0:%.*]]) local_unnamed_addr {
+; FNATTRS-SAME: ptr readnone captures(address_is_null, ret: address, provenance) [[TMP0:%.*]]) local_unnamed_addr {
 ; FNATTRS-NEXT:    [[TMP2:%.*]] = icmp eq ptr [[TMP0]], null
 ; FNATTRS-NEXT:    br i1 [[TMP2]], label [[TMP3:%.*]], label [[TMP5:%.*]]
 ; FNATTRS:       3:
@@ -362,7 +362,7 @@ declare nonnull ptr @nonnull()
 define internal ptr @f1(ptr %arg) {
 ; FIXME: missing nonnull It should be nonnull @f1(ptr nonnull readonly %arg)
 ; FNATTRS-LABEL: define internal nonnull ptr @f1(
-; FNATTRS-SAME: ptr readonly [[ARG:%.*]]) #[[ATTR4:[0-9]+]] {
+; FNATTRS-SAME: ptr readonly captures(address_is_null) [[ARG:%.*]]) #[[ATTR4:[0-9]+]] {
 ; FNATTRS-NEXT:  bb:
 ; FNATTRS-NEXT:    [[TMP:%.*]] = icmp eq ptr [[ARG]], null
 ; FNATTRS-NEXT:    br i1 [[TMP]], label [[BB9:%.*]], label [[BB1:%.*]]
@@ -431,7 +431,7 @@ bb9:                                              ; preds = %bb4, %bb
 define internal ptr @f2(ptr %arg) {
 ; FIXME: missing nonnull. It should be nonnull @f2(ptr nonnull %arg)
 ; FNATTRS-LABEL: define internal nonnull ptr @f2(
-; FNATTRS-SAME: ptr [[ARG:%.*]]) #[[ATTR4]] {
+; FNATTRS-SAME: ptr readonly captures(address_is_null) [[ARG:%.*]]) #[[ATTR4]] {
 ; FNATTRS-NEXT:  bb:
 ; FNATTRS-NEXT:    [[TMP:%.*]] = tail call ptr @f1(ptr [[ARG]])
 ; FNATTRS-NEXT:    ret ptr [[TMP]]
@@ -452,7 +452,7 @@ bb:
 define dso_local noalias ptr @f3(ptr %arg) {
 ; FIXME: missing nonnull. It should be nonnull @f3(ptr nonnull readonly %arg)
 ; FNATTRS-LABEL: define dso_local noalias nonnull ptr @f3(
-; FNATTRS-SAME: ptr [[ARG:%.*]]) #[[ATTR4]] {
+; FNATTRS-SAME: ptr readonly captures(address_is_null) [[ARG:%.*]]) #[[ATTR4]] {
 ; FNATTRS-NEXT:  bb:
 ; FNATTRS-NEXT:    [[TMP:%.*]] = call ptr @f1(ptr [[ARG]])
 ; FNATTRS-NEXT:    ret ptr [[TMP]]
@@ -945,7 +945,7 @@ exc:
 
 define ptr @gep1(ptr %p) {
 ; FNATTRS-LABEL: define nonnull ptr @gep1(
-; FNATTRS-SAME: ptr readnone [[P:%.*]]) #[[ATTR0]] {
+; FNATTRS-SAME: ptr readnone captures(ret: address, provenance) [[P:%.*]]) #[[ATTR0]] {
 ; FNATTRS-NEXT:    [[Q:%.*]] = getelementptr inbounds i32, ptr [[P]], i32 1
 ; FNATTRS-NEXT:    ret ptr [[Q]]
 ;
@@ -961,7 +961,7 @@ define ptr @gep1(ptr %p) {
 define ptr @gep1_no_null_opt(ptr %p) #0 {
 ; Should't be able to derive nonnull based on gep.
 ; FNATTRS-LABEL: define ptr @gep1_no_null_opt(
-; FNATTRS-SAME: ptr readnone [[P:%.*]]) #[[ATTR8:[0-9]+]] {
+; FNATTRS-SAME: ptr readnone captures(ret: address, provenance) [[P:%.*]]) #[[ATTR8:[0-9]+]] {
 ; FNATTRS-NEXT:    [[Q:%.*]] = getelementptr inbounds i32, ptr [[P]], i32 1
 ; FNATTRS-NEXT:    ret ptr [[Q]]
 ;
@@ -976,7 +976,7 @@ define ptr @gep1_no_null_opt(ptr %p) #0 {
 
 define ptr addrspace(3) @gep2(ptr addrspace(3) %p) {
 ; FNATTRS-LABEL: define ptr addrspace(3) @gep2(
-; FNATTRS-SAME: ptr addrspace(3) readnone [[P:%.*]]) #[[ATTR0]] {
+; FNATTRS-SAME: ptr addrspace(3) readnone captures(ret: address, provenance) [[P:%.*]]) #[[ATTR0]] {
 ; FNATTRS-NEXT:    [[Q:%.*]] = getelementptr inbounds i32, ptr addrspace(3) [[P]], i32 1
 ; FNATTRS-NEXT:    ret ptr addrspace(3) [[Q]]
 ;
@@ -992,7 +992,7 @@ define ptr addrspace(3) @gep2(ptr addrspace(3) %p) {
 ; FIXME: We should propagate dereferenceable here but *not* nonnull
 define ptr addrspace(3) @as(ptr addrspace(3) dereferenceable(4) %p) {
 ; FNATTRS-LABEL: define noundef ptr addrspace(3) @as(
-; FNATTRS-SAME: ptr addrspace(3) readnone returned dereferenceable(4) [[P:%.*]]) #[[ATTR0]] {
+; FNATTRS-SAME: ptr addrspace(3) readnone returned captures(ret: address, provenance) dereferenceable(4) [[P:%.*]]) #[[ATTR0]] {
 ; FNATTRS-NEXT:    ret ptr addrspace(3) [[P]]
 ;
 ; ATTRIBUTOR-LABEL: define ptr addrspace(3) @as(
@@ -1383,7 +1383,7 @@ define void @PR43833_simple(ptr %0, i32 %1) {
 
 define ptr @pr91177_non_inbounds_gep(ptr nonnull %arg) {
 ; FNATTRS-LABEL: define ptr @pr91177_non_inbounds_gep(
-; FNATTRS-SAME: ptr nonnull readnone [[ARG:%.*]]) #[[ATTR0]] {
+; FNATTRS-SAME: ptr nonnull readnone captures(ret: address, provenance) [[ARG:%.*]]) #[[ATTR0]] {
 ; FNATTRS-NEXT:    [[RES:%.*]] = getelementptr i8, ptr [[ARG]], i64 -8
 ; FNATTRS-NEXT:    ret ptr [[RES]]
 ;
diff --git a/llvm/test/Transforms/FunctionAttrs/noundef.ll b/llvm/test/Transforms/FunctionAttrs/noundef.ll
index b7c583880501a2e..4f53c0880462100 100644
--- a/llvm/test/Transforms/FunctionAttrs/noundef.ll
+++ b/llvm/test/Transforms/FunctionAttrs/noundef.ll
@@ -169,7 +169,7 @@ define i64 @test_trunc_with_constexpr() {
 
 define align 4 ptr @maybe_not_aligned(ptr noundef %p) {
 ; CHECK-LABEL: define align 4 ptr @maybe_not_aligned(
-; CHECK-SAME: ptr noundef readnone returned [[P:%.*]]) #[[ATTR0]] {
+; CHECK-SAME: ptr noundef readnone returned captures(ret: address, provenance) [[P:%.*]]) #[[ATTR0]] {
 ; CHECK-NEXT:    ret ptr [[P]]
 ;
   ret ptr %p
@@ -177,7 +177,7 @@ define align 4 ptr @maybe_not_aligned(ptr noundef %p) {
 
 define align 4 ptr @definitely_aligned(ptr noundef align 4 %p) {
 ; CHECK-LABEL: define noundef align 4 ptr @definitely_aligned(
-; CHECK-SAME: ptr noundef readnone returned align 4 [[P:%.*]]) #[[ATTR0]] {
+; CHECK-SAME: ptr noundef readnone returned align 4 captures(ret: address, provenance) [[P:%.*]]) #[[ATTR0]] {
 ; CHECK-NEXT:    ret ptr [[P]]
 ;
   ret ptr %p
@@ -185,7 +185,7 @@ define align 4 ptr @definitely_aligned(ptr noundef align 4 %p) {
 
 define nonnull ptr @maybe_not_nonnull(ptr noundef %p) {
 ; CHECK-LABEL: define nonnull ptr @maybe_not_nonnull(
-; CHECK-SAME: ptr noundef readnone returned [[P:%.*]]) #[[ATTR0]] {
+; CHECK-SAME: ptr noundef readnone returned captures(ret: address, provenance) [[P:%.*]]) #[[ATTR0]] {
 ; CHECK-NEXT:    ret ptr [[P]]
 ;
   ret ptr %p
@@ -193,7 +193,7 @@ define nonnull ptr @maybe_not_nonnull(ptr noundef %p) {
 
 define nonnull ptr @definitely_nonnull(ptr noundef nonnull %p) {
 ; CHECK-LABEL: define noundef nonnull ptr @definitely_nonnull(
-; CHECK-SAME: ptr noundef nonnull readnone returned [[P:%.*]]) #[[ATTR0]] {
+; CHECK-SAME: ptr noundef nonnull readnone returned captures(ret: address, provenance) [[P:%.*]]) #[[ATTR0]] {
 ; CHECK-NEXT:    ret ptr [[P]]
 ;
   ret ptr %p
diff --git a/llvm/test/Transforms/FunctionAttrs/readattrs.ll b/llvm/test/Transforms/FunctionAttrs/readattrs.ll
index b24c097ad54d08c..5fc88d623c0ec7a 100644
--- a/llvm/test/Transforms/FunctionAttrs/readattrs.ll
+++ b/llvm/test/Transforms/FunctionAttrs/readattrs.ll
@@ -35,7 +35,7 @@ define void @test1_2(ptr %x1_2, ptr %y1_2, ptr %z1_2) {
 define ptr @test2(ptr %p) {
 ; FNATTRS: Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(write, argmem: none, inaccessiblemem: none)
 ; FNATTRS-LABEL: define {{[^@]+}}@test2
-; FNATTRS-SAME: (ptr readnone returned [[P:%.*]]) #[[ATTR0:[0-9]+]] {
+; FNATTRS-SAME: (ptr readnone returned captures(ret: address, provenance) [[P:%.*]]) #[[ATTR0:[0-9]+]] {
 ; FNATTRS-NEXT:    store i32 0, ptr @x, align 4
 ; FNATTRS-NEXT:    ret ptr [[P]]
 ;
@@ -58,7 +58,7 @@ define ptr @test2(ptr %p) {
 define i1 @test3(ptr %p, ptr %q) {
 ; FNATTRS: Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
 ; FNATTRS-LABEL: define {{[^@]+}}@test3
-; FNATTRS-SAME: (ptr readnone [[P:%.*]], ptr readnone [[Q:%.*]]) #[[ATTR1:[0-9]+]] {
+; FNATTRS-SAME: (ptr readnone captures(address) [[P:%.*]], ptr readnone captures(address) [[Q:%.*]]) #[[ATTR1:[0-9]+]] {
 ; FNATTRS-NEXT:    [[A:%.*]] = icmp ult ptr [[P]], [[Q]]
 ; FNATTRS-NEXT:    ret i1 [[A]]
 ;
@@ -197,7 +197,7 @@ define void @test7_2(ptr preallocated(i32) %a) {
 define ptr @test8_1(ptr %p) {
 ; FNATTRS: Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
 ; FNATTRS-LABEL: define {{[^@]+}}@test8_1
-; FNATTRS-SAME: (ptr readnone returned [[P:%.*]]) #[[ATTR1]] {
+; FNATTRS-SAME: (ptr readnone returned captures(ret: address, provenance) [[P:%.*]]) #[[ATTR1]] {
 ; FNATTRS-NEXT:  entry:
 ; FNATTRS-NEXT:    ret ptr [[P]]
 ;
@@ -220,7 +220,7 @@ entry:
 define void @test8_2(ptr %p) {
 ; FNATTRS: Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: write)
 ; FNATTRS-LABEL: define {{[^@]+}}@test8_2
-; FNATTRS-SAME: (ptr writeonly [[P:%.*]]) #[[ATTR4]] {
+; FNATTRS-SAME: (ptr writeonly captures(none) [[P:%.*]]) #[[ATTR4]] {
 ; FNATTRS-NEXT:  entry:
 ; FNATTRS-NEXT:    [[CALL:%.*]] = call ptr @test8_1(ptr [[P]])
 ; FNATTRS-NEXT:    store i32 10, ptr [[CALL]], align 4
diff --git a/llvm/test/Transforms/FunctionAttrs/stats.ll b/llvm/test/Transforms/FunctionAttrs/stats.ll
index 5f007b4078ff3dd..dc0387e57174a92 100644
--- a/llvm/test/Transforms/FunctionAttrs/stats.ll
+++ b/llvm/test/Transforms/FunctionAttrs/stats.ll
@@ -16,8 +16,8 @@ entry:
   ret void
 }
 
-; CHECK:      2 function-attrs - Number of functions with improved memory attribute
-; CHECK-NEXT: 1 function-attrs - Number of arguments marked nocapture
+; CHECK:      1 function-attrs - Number of arguments marked captures(none)
+; CHECK-NEXT: 2 function-attrs - Number of functions with improved memory attribute
 ; CHECK-NEXT: 1 function-attrs - Number of functions marked as nofree
 ; CHECK-NEXT: 2 function-attrs - Number of functions marked as norecurse
 ; CHECK-NEXT: 2 function-attrs - Number of functions marked as nosync
diff --git a/llvm/test/Transforms/PhaseOrdering/AArch64/block_scaling_decompr_8bit.ll b/llvm/test/Transforms/PhaseOrdering/AArch64/block_scaling_decompr_8bit.ll
index e01dba328a3a1b0..7175816963ed145 100644
--- a/llvm/test/Transforms/PhaseOrdering/AArch64/block_scaling_decompr_8bit.ll
+++ b/llvm/test/Transforms/PhaseOrdering/AArch64/block_scaling_decompr_8bit.ll
@@ -9,7 +9,7 @@ target triple = "aarch64"
 
 define dso_local noundef i32 @_Z33block_scaling_decompr_8bitjPK27compressed_data_8bitP20cmplx_int16_tPKS2_(i32 noundef %n_prb, ptr noundef %src, ptr noundef %dst, ptr noundef %scale) #0 {
 ; CHECK-LABEL: define dso_local noundef i32 @_Z33block_scaling_decompr_8bitjPK27compressed_data_8bitP20cmplx_int16_tPKS2_(
-; CHECK-SAME: i32 noundef [[N_PRB:%.*]], ptr noundef readonly captures(none) [[SRC:%.*]], ptr noundef writeonly captures(none) [[DST:%.*]], ptr noundef readonly [[SCALE:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+; CHECK-SAME: i32 noundef [[N_PRB:%.*]], ptr noundef readonly captures(none) [[SRC:%.*]], ptr noundef writeonly captures(none) [[DST:%.*]], ptr noundef readonly captures(address_is_null) [[SCALE:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
 ; CHECK-NEXT:    [[CMP47_NOT:%.*]] = icmp eq i32 [[N_PRB]], 0
 ; CHECK-NEXT:    br i1 [[CMP47_NOT]], label %[[FOR_END:.*]], label %[[FOR_BODY_LR_PH:.*]]
diff --git a/llvm/test/Transforms/PhaseOrdering/bitcast-store-branch.ll b/llvm/test/Transforms/PhaseOrdering/bitcast-store-branch.ll
index bbd4849c32296e1..d5edf83ee52e2fc 100644
--- a/llvm/test/Transforms/PhaseOrdering/bitcast-store-branch.ll
+++ b/llvm/test/Transforms/PhaseOrdering/bitcast-store-branch.ll
@@ -12,7 +12,7 @@ entry:
 
 define ptr @parent(ptr align 8 dereferenceable(72) %f, half %val1, i16 %val2, i32 %val3) align 2 {
 ; CHECK-LABEL: define noundef nonnull ptr @parent
-; CHECK-SAME: (ptr readonly returned align 8 dereferenceable(72) [[F:%.*]], half [[VAL1:%.*]], i16 [[VAL2:%.*]], i32 [[VAL3:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] align 2 {
+; CHECK-SAME: (ptr readonly returned align 8 captures(ret: address, provenance) dereferenceable(72) [[F:%.*]], half [[VAL1:%.*]], i16 [[VAL2:%.*]], i32 [[VAL3:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] align 2 {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds nuw i8, ptr [[F]], i64 64
 ; CHECK-NEXT:    [[F_VAL:%.*]] = load ptr, ptr [[TMP0]], align 8
diff --git a/llvm/test/Transforms/PhaseOrdering/dce-after-argument-promotion-loads.ll b/llvm/test/Transforms/PhaseOrdering/dce-after-argument-promotion-loads.ll
index ee7698b116aa237..4b422f205138af9 100644
--- a/llvm/test/Transforms/PhaseOrdering/dce-after-argument-promotion-loads.ll
+++ b/llvm/test/Transforms/PhaseOrdering/dce-after-argument-promotion-loads.ll
@@ -14,7 +14,7 @@ entry:
 
 define ptr @parent(ptr align 8 dereferenceable(72) %f, i16 %val1, i16 %val2, i32 %val3) align 2 {
 ; CHECK-LABEL: define noundef nonnull ptr @parent
-; CHECK-SAME: (ptr readonly returned align 8 dereferenceable(72) [[F:%.*]], i16 [[VAL1:%.*]], i16 [[VAL2:%.*]], i32 [[VAL3:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] align 2 {
+; CHECK-SAME: (ptr readonly returned align 8 captures(ret: address, provenance) dereferenceable(72) [[F:%.*]], i16 [[VAL1:%.*]], i16 [[VAL2:%.*]], i32 [[VAL3:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] align 2 {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds nuw i8, ptr [[F]], i64 64
 ; CHECK-NEXT:    [[F_VAL:%.*]] = load ptr, ptr [[TMP0]], align 8
diff --git a/llvm/test/Transforms/PhaseOrdering/enable-loop-header-duplication-oz.ll b/llvm/test/Transforms/PhaseOrdering/enable-loop-header-duplication-oz.ll
index 5f75bd788e4bb46..cd2ed37b22db512 100644
--- a/llvm/test/Transforms/PhaseOrdering/enable-loop-header-duplication-oz.ll
+++ b/llvm/test/Transforms/PhaseOrdering/enable-loop-header-duplication-oz.ll
@@ -11,7 +11,7 @@
 
 define void @test(i8* noalias nonnull align 1 %start, i8* %end) unnamed_addr {
 ; NOROTATION-LABEL: define void @test(
-; NOROTATION-SAME: ptr noalias nonnull writeonly align 1 [[START:%.*]], ptr readnone [[END:%.*]]) unnamed_addr #[[ATTR0:[0-9]+]] {
+; NOROTATION-SAME: ptr noalias nonnull writeonly align 1 captures(address) [[START:%.*]], ptr readnone captures(address) [[END:%.*]]) unnamed_addr #[[ATTR0:[0-9]+]] {
 ; NOROTATION-NEXT:  entry:
 ; NOROTATION-NEXT:    br label [[LOOP_HEADER:%.*]]
 ; NOROTATION:       loop.header:
@@ -26,7 +26,7 @@ define void @test(i8* noalias nonnull align 1 %start, i8* %end) unnamed_addr {
 ; NOROTATION-NEXT:    ret void
 ;
 ; ROTATION-LABEL: define void @test(
-; ROTATION-SAME: ptr noalias nonnull writeonly align 1 [[START:%.*]], ptr readnone [[END:%.*]]) unnamed_addr #[[ATTR0:[0-9]+]] {
+; ROTATION-SAME: ptr noalias nonnull writeonly align 1 captures(address) [[START:%.*]], ptr readnone captures(address) [[END:%.*]]) unnamed_addr #[[ATTR0:[0-9]+]] {
 ; ROTATION-NEXT:  entry:
 ; ROTATION-NEXT:    [[_12_I1:%.*]] = icmp eq ptr [[START]], [[END]]
 ; ROTATION-NEXT:    br i1 [[_12_I1]], label [[EXIT:%.*]], label [[LOOP_LATCH_PREHEADER:%.*]]
diff --git a/llvm/unittests/Analysis/CaptureTrackingTest.cpp b/llvm/unittests/Analysis/CaptureTrackingTest.cpp
index 73dd82fb921f761..8e677bb04ec32cf 100644
--- a/llvm/unittests/Analysis/CaptureTrackingTest.cpp
+++ b/llvm/unittests/Analysis/CaptureTrackingTest.cpp
@@ -77,9 +77,10 @@ TEST(CaptureTracking, MaxUsesToExplore) {
 struct CollectingCaptureTracker : public CaptureTracker {
   SmallVector<const Use *, 4> Captures;
   void tooManyUses() override { }
-  bool captured(const Use *U) override {
+  std::optional<CaptureComponents> captured(const Use *U,
+                                            CaptureInfo CI) override {
     Captures.push_back(U);
-    return false;
+    return continueDefault(CI);
   }
 };
 

>From 28d6ce7477b8ddb3262be269f1bbc8065ae5bc9d Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Fri, 7 Feb 2025 09:55:38 +0100
Subject: [PATCH 2/8] Address review

---
 llvm/include/llvm/Analysis/CaptureTracking.h  | 48 +++++++------------
 llvm/lib/Analysis/CaptureTracking.cpp         | 44 +++++++++--------
 llvm/lib/Analysis/InstructionSimplify.cpp     |  7 ++-
 llvm/lib/Transforms/IPO/FunctionAttrs.cpp     | 20 ++++----
 .../InstCombine/InstCombineCompares.cpp       |  7 ++-
 5 files changed, 58 insertions(+), 68 deletions(-)

diff --git a/llvm/include/llvm/Analysis/CaptureTracking.h b/llvm/include/llvm/Analysis/CaptureTracking.h
index 4e09c1e6f3021dc..b1eea6bcc07da26 100644
--- a/llvm/include/llvm/Analysis/CaptureTracking.h
+++ b/llvm/include/llvm/Analysis/CaptureTracking.h
@@ -14,7 +14,6 @@
 #define LLVM_ANALYSIS_CAPTURETRACKING_H
 
 #include "llvm/ADT/DenseMap.h"
-#include "llvm/Support/ModRef.h"
 
 namespace llvm {
 
@@ -83,6 +82,21 @@ namespace llvm {
   /// addition to the interface here, you'll need to provide your own getters
   /// to see whether anything was captured.
   struct CaptureTracker {
+    /// Action returned from captures().
+    enum Action {
+      /// Stop the traversal.
+      Stop,
+      /// Continue traversal, and also follow the return value of the user if
+      /// it has additional capture components (that is, if it has capture
+      /// components in Ret that are not part of Other).
+      Continue,
+      /// Continue traversal, but do not follow the return value of the user,
+      /// even if it has additional capture components. Should only be used if
+      /// captures() has already taken the potential return captures into
+      /// account.
+      ContinueIgnoringReturn,
+    };
+
     virtual ~CaptureTracker();
 
     /// tooManyUses - The depth of traversal has breached a limit. There may be
@@ -96,38 +110,12 @@ namespace llvm {
     /// U->getUser() is always an Instruction.
     virtual bool shouldExplore(const Use *U);
 
-    /// When returned from captures(), stop the traversal.
-    static std::optional<CaptureComponents> stop() { return std::nullopt; }
-
-    /// When returned from captures(), continue traversal, but do not follow
-    /// the return value of this user, even if it has additional capture
-    /// components. Should only be used if captures() has already taken the
-    /// potential return caputres into account.
-    static std::optional<CaptureComponents> continueIgnoringReturn() {
-      return CaptureComponents::None;
-    }
-
-    /// When returned from captures(), continue traversal, and also follow
-    /// the return value of this user if it has additional capture components
-    /// (that is, capture components in Ret that are not part of Other).
-    static std::optional<CaptureComponents> continueDefault(CaptureInfo CI) {
-      CaptureComponents RetCC = CI.getRetComponents();
-      if (!capturesNothing(RetCC & ~CI.getOtherComponents()))
-        return RetCC;
-      return CaptureComponents::None;
-    }
-
     /// Use U directly captures CI.getOtherComponents() and additionally
     /// CI.getRetComponents() through the return value of the user of U.
     ///
-    /// Return std::nullopt to stop the traversal, or the CaptureComponents to
-    /// follow via the return value, which must be a subset of
-    /// CI.getRetComponents().
-    ///
-    /// For convenience, prefer returning one of stop(), continueDefault(CI) or
-    /// continueIgnoringReturn().
-    virtual std::optional<CaptureComponents> captured(const Use *U,
-                                                      CaptureInfo CI) = 0;
+    /// Return one of Stop, Continue or ContinueIgnoringReturn to control
+    /// further traversal.
+    virtual Action captured(const Use *U, CaptureInfo CI) = 0;
 
     /// isDereferenceableOrNull - Overload to allow clients with additional
     /// knowledge about pointer dereferenceability to provide it and thereby
diff --git a/llvm/lib/Analysis/CaptureTracking.cpp b/llvm/lib/Analysis/CaptureTracking.cpp
index 4e403b8825c7f71..98bb2d80908ae13 100644
--- a/llvm/lib/Analysis/CaptureTracking.cpp
+++ b/llvm/lib/Analysis/CaptureTracking.cpp
@@ -81,16 +81,15 @@ struct SimpleCaptureTracker : public CaptureTracker {
     Captured = true;
   }
 
-  std::optional<CaptureComponents> captured(const Use *U,
-                                            CaptureInfo CI) override {
+  Action captured(const Use *U, CaptureInfo CI) override {
     // TODO(captures): Use CaptureInfo.
     if (isa<ReturnInst>(U->getUser()) && !ReturnCaptures)
-      return continueIgnoringReturn();
+      return ContinueIgnoringReturn;
 
     LLVM_DEBUG(dbgs() << "Captured by: " << *U->getUser() << "\n");
 
     Captured = true;
-    return stop();
+    return Stop;
   }
 
   bool ReturnCaptures;
@@ -124,22 +123,21 @@ struct CapturesBefore : public CaptureTracker {
     return !isPotentiallyReachable(I, BeforeHere, nullptr, DT, LI);
   }
 
-  std::optional<CaptureComponents> captured(const Use *U,
-                                            CaptureInfo CI) override {
+  Action captured(const Use *U, CaptureInfo CI) override {
     // TODO(captures): Use CaptureInfo.
     Instruction *I = cast<Instruction>(U->getUser());
     if (isa<ReturnInst>(I) && !ReturnCaptures)
-      return continueIgnoringReturn();
+      return ContinueIgnoringReturn;
 
     // Check isSafeToPrune() here rather than in shouldExplore() to avoid
     // an expensive reachability query for every instruction we look at.
     // Instead we only do one for actual capturing candidates.
     if (isSafeToPrune(I))
       // If the use is not reachable, the instruction result isn't either.
-      return continueIgnoringReturn();
+      return ContinueIgnoringReturn;
 
     Captured = true;
-    return stop();
+    return Stop;
   }
 
   const Instruction *BeforeHere;
@@ -171,12 +169,11 @@ struct EarliestCaptures : public CaptureTracker {
     EarliestCapture = &*F.getEntryBlock().begin();
   }
 
-  std::optional<CaptureComponents> captured(const Use *U,
-                                            CaptureInfo CI) override {
+  Action captured(const Use *U, CaptureInfo CI) override {
     // TODO(captures): Use CaptureInfo.
     Instruction *I = cast<Instruction>(U->getUser());
     if (isa<ReturnInst>(I) && !ReturnCaptures)
-      return continueIgnoringReturn();
+      return ContinueIgnoringReturn;
 
     if (!EarliestCapture)
       EarliestCapture = I;
@@ -187,7 +184,7 @@ struct EarliestCaptures : public CaptureTracker {
     // Continue analysis, as we need to see all potential captures. However,
     // we do not need to follow the instruction result, as this use will
     // dominate any captures made through the instruction result..
-    return continueIgnoringReturn();
+    return ContinueIgnoringReturn;
   }
 
   Instruction *EarliestCapture = nullptr;
@@ -451,17 +448,24 @@ void llvm::PointerMayBeCaptured(const Value *V, CaptureTracker *Tracker,
     CaptureInfo CI = DetermineUseCaptureKind(*U, IsDereferenceableOrNull);
     if (capturesNothing(CI))
       continue;
+    CaptureComponents OtherCC = CI.getOtherComponents();
     CaptureComponents RetCC = CI.getRetComponents();
-    if (!capturesNothing(CI.getOtherComponents())) {
-      std::optional<CaptureComponents> Res = Tracker->captured(U, CI);
-      if (!Res)
+    if (capturesAnything(OtherCC)) {
+      switch (Tracker->captured(U, CI)) {
+      case CaptureTracker::Stop:
         return;
-      assert(capturesNothing(*Res & ~RetCC) &&
-             "captures() result must be subset of getRetComponents()");
-      RetCC = *Res;
+      case CaptureTracker::ContinueIgnoringReturn:
+        continue;
+      case CaptureTracker::Continue:
+        // Fall through to passthrough handling, but only if RetCC contains
+        // additional components that OtherCC does not.
+        if (capturesNothing(RetCC & ~OtherCC))
+          continue;
+        break;
+      }
     }
     // TODO(captures): We could keep track of RetCC for the users.
-    if (!capturesNothing(RetCC) && !AddUses(U->getUser()))
+    if (capturesAnything(RetCC) && !AddUses(U->getUser()))
       return;
   }
 
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 58fa165341f09a3..c030d6bbce35e6b 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -2801,8 +2801,7 @@ static Constant *computePointerICmp(CmpPredicate Pred, Value *LHS, Value *RHS,
       struct CustomCaptureTracker : public CaptureTracker {
         bool Captured = false;
         void tooManyUses() override { Captured = true; }
-        std::optional<CaptureComponents> captured(const Use *U,
-                                                  CaptureInfo CI) override {
+        Action captured(const Use *U, CaptureInfo CI) override {
           // TODO(captures): Use CaptureInfo.
           if (auto *ICmp = dyn_cast<ICmpInst>(U->getUser())) {
             // Comparison against value stored in global variable. Given the
@@ -2811,11 +2810,11 @@ static Constant *computePointerICmp(CmpPredicate Pred, Value *LHS, Value *RHS,
             unsigned OtherIdx = 1 - U->getOperandNo();
             auto *LI = dyn_cast<LoadInst>(ICmp->getOperand(OtherIdx));
             if (LI && isa<GlobalVariable>(LI->getPointerOperand()))
-              return continueDefault(CI);
+              return Continue;
           }
 
           Captured = true;
-          return stop();
+          return Stop;
         }
       };
       CustomCaptureTracker Tracker;
diff --git a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp
index 920ccf28b3f1b04..5df3fbcab71e36f 100644
--- a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp
@@ -72,8 +72,8 @@ using namespace llvm;
 
 STATISTIC(NumMemoryAttr, "Number of functions with improved memory attribute");
 STATISTIC(NumCapturesNone, "Number of arguments marked captures(none)");
-STATISTIC(NumCapturesOther, "Number of arguments marked with captures "
-                            "attribute other than captures(none)");
+STATISTIC(NumCapturesPartial, "Number of arguments marked with captures "
+                              "attribute other than captures(none)");
 STATISTIC(NumReturned, "Number of arguments marked returned");
 STATISTIC(NumReadNoneArg, "Number of arguments marked readnone");
 STATISTIC(NumReadOnlyArg, "Number of arguments marked readonly");
@@ -114,7 +114,7 @@ static void addCapturesStat(CaptureInfo CI) {
   if (capturesNothing(CI))
     ++NumCapturesNone;
   else
-    ++NumCapturesOther;
+    ++NumCapturesPartial;
 }
 
 namespace {
@@ -549,18 +549,17 @@ struct ArgumentUsesTracker : public CaptureTracker {
 
   void tooManyUses() override { CI = CaptureInfo::all(); }
 
-  std::optional<CaptureComponents> captured(const Use *U,
-                                            CaptureInfo UseCI) override {
+  Action captured(const Use *U, CaptureInfo UseCI) override {
     if (updateCaptureInfo(U, UseCI.getOtherComponents())) {
       // Don't bother continuing if we already capture everything.
       if (capturesAll(CI.getOtherComponents()))
-        return stop();
-      return continueDefault(UseCI);
+        return Stop;
+      return Continue;
     }
 
     // For SCC argument tracking, we're not going to analyze other/ret
     // components separately, so don't follow the return value.
-    return continueIgnoringReturn();
+    return ContinueIgnoringReturn;
   }
 
   bool updateCaptureInfo(const Use *U, CaptureComponents CC) {
@@ -1329,7 +1328,8 @@ static void addArgumentAttrs(const SCCNodeSet &SCCNodes,
         }
 
         // Infer the access attributes given the new captures one
-        DetermineAccessAttrsForSingleton(A);
+        if (DetermineAccessAttrsForSingleton(A))
+          Changed.insert(A->getParent());
       }
       continue;
     }
@@ -1369,7 +1369,7 @@ static void addArgumentAttrs(const SCCNodeSet &SCCNodes,
     }
 
     // TODO(captures): Ignore address-only captures.
-    if (!capturesNothing(CC)) {
+    if (capturesAnything(CC)) {
       // As the pointer may be captured, determine the pointer attributes
       // looking at each argument invidivually.
       for (ArgumentGraphNode *N : ArgumentSCC) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 6640e15f425b90f..23595f9d3197e68 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -882,8 +882,7 @@ bool InstCombinerImpl::foldAllocaCmp(AllocaInst *Alloca) {
 
     void tooManyUses() override { Captured = true; }
 
-    std::optional<CaptureComponents> captured(const Use *U,
-                                              CaptureInfo CI) override {
+    Action captured(const Use *U, CaptureInfo CI) override {
       // TODO(captures): Use CaptureInfo.
       auto *ICmp = dyn_cast<ICmpInst>(U->getUser());
       // We need to check that U is based *only* on the alloca, and doesn't
@@ -894,11 +893,11 @@ bool InstCombinerImpl::foldAllocaCmp(AllocaInst *Alloca) {
         // Collect equality icmps of the alloca, and don't treat them as
         // captures.
         ICmps[ICmp] |= 1u << U->getOperandNo();
-        return continueDefault(CI);
+        return Continue;
       }
 
       Captured = true;
-      return stop();
+      return Stop;
     }
   };
 

>From 3d53796e75cb293385e5bb958305434a846a802c Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Fri, 7 Feb 2025 10:21:08 +0100
Subject: [PATCH 3/8] Make sure we don't increase pre-existing captures
 attributes

---
 llvm/lib/Transforms/IPO/FunctionAttrs.cpp     | 16 +++--
 .../Transforms/FunctionAttrs/nocapture.ll     | 67 +++++++++++++++++++
 2 files changed, 77 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp
index 5df3fbcab71e36f..f1302b223bf2999 100644
--- a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp
@@ -1320,8 +1320,9 @@ static void addArgumentAttrs(const SCCNodeSet &SCCNodes,
       if (ArgumentSCC[0]->Uses.size() == 1 &&
           ArgumentSCC[0]->Uses[0] == ArgumentSCC[0]) {
         Argument *A = ArgumentSCC[0]->Definition;
-        CaptureInfo NewCI = ArgumentSCC[0]->CC;
-        if (NewCI != A->getAttributes().getCaptureInfo()) {
+        CaptureInfo OrigCI = A->getAttributes().getCaptureInfo();
+        CaptureInfo NewCI = CaptureInfo(ArgumentSCC[0]->CC) & OrigCI;
+        if (NewCI != OrigCI) {
           A->addAttr(Attribute::getWithCaptureInfo(A->getContext(), NewCI));
           addCapturesStat(NewCI);
           Changed.insert(A->getParent());
@@ -1361,10 +1362,13 @@ static void addArgumentAttrs(const SCCNodeSet &SCCNodes,
     if (!capturesAll(CC)) {
       for (ArgumentGraphNode *N : ArgumentSCC) {
         Argument *A = N->Definition;
-        CaptureInfo CI = N->CC | CC;
-        A->addAttr(Attribute::getWithCaptureInfo(A->getContext(), CI));
-        addCapturesStat(CI);
-        Changed.insert(A->getParent());
+        CaptureInfo OrigCI = A->getAttributes().getCaptureInfo();
+        CaptureInfo NewCI = CaptureInfo(N->CC | CC) & OrigCI;
+        if (NewCI != OrigCI) {
+          A->addAttr(Attribute::getWithCaptureInfo(A->getContext(), NewCI));
+          addCapturesStat(NewCI);
+          Changed.insert(A->getParent());
+        }
       }
     }
 
diff --git a/llvm/test/Transforms/FunctionAttrs/nocapture.ll b/llvm/test/Transforms/FunctionAttrs/nocapture.ll
index aeb950bd41e93cb..ccb5251fb4ea3ed 100644
--- a/llvm/test/Transforms/FunctionAttrs/nocapture.ll
+++ b/llvm/test/Transforms/FunctionAttrs/nocapture.ll
@@ -1018,5 +1018,72 @@ else:
   ret ptr %p
 }
 
+define i1 @improve_existing_captures(ptr captures(address) %p) {
+; FNATTRS: Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
+; FNATTRS-LABEL: define i1 @improve_existing_captures
+; FNATTRS-SAME: (ptr readnone captures(address_is_null) [[P:%.*]]) #[[ATTR0]] {
+; FNATTRS-NEXT:    [[CMP:%.*]] = icmp eq ptr [[P]], null
+; FNATTRS-NEXT:    ret i1 [[CMP]]
+;
+; ATTRIBUTOR: Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
+; ATTRIBUTOR-LABEL: define i1 @improve_existing_captures
+; ATTRIBUTOR-SAME: (ptr nofree readnone captures(address) [[P:%.*]]) #[[ATTR0]] {
+; ATTRIBUTOR-NEXT:    [[CMP:%.*]] = icmp eq ptr [[P]], null
+; ATTRIBUTOR-NEXT:    ret i1 [[CMP]]
+;
+  %cmp = icmp eq ptr %p, null
+  ret i1 %cmp
+}
+
+define void @dont_increase_existing_captures(ptr captures(address) %p) {
+; COMMON-LABEL: define void @dont_increase_existing_captures
+; COMMON-SAME: (ptr captures(address) [[P:%.*]]) {
+; COMMON-NEXT:    call void @capture(ptr [[P]])
+; COMMON-NEXT:    ret void
+;
+  call void @capture(ptr %p)
+  ret void
+}
+
+define void @dont_increase_existing_captures_trivial_scc(ptr captures(address) %p) {
+; COMMON-LABEL: define void @dont_increase_existing_captures_trivial_scc
+; COMMON-SAME: (ptr captures(address) [[P:%.*]]) {
+; COMMON-NEXT:    call void @capture(ptr captures(address, read_provenance) [[P]])
+; COMMON-NEXT:    call void @dont_increase_existing_captures_trivial_scc(ptr [[P]])
+; COMMON-NEXT:    ret void
+;
+  call void @capture(ptr captures(address, read_provenance) %p)
+  call void @dont_increase_existing_captures_trivial_scc(ptr %p)
+  ret void
+}
+
+define void @dont_increase_existing_captures_scc1(ptr captures(address) %p) {
+; COMMON-LABEL: define void @dont_increase_existing_captures_scc1
+; COMMON-SAME: (ptr captures(address) [[P:%.*]]) {
+; COMMON-NEXT:    call void @dont_increase_existing_captures_scc2(ptr [[P]])
+; COMMON-NEXT:    ret void
+;
+  call void @dont_increase_existing_captures_scc2(ptr %p)
+  ret void
+}
+
+define void @dont_increase_existing_captures_scc2(ptr %p) {
+; FNATTRS-LABEL: define void @dont_increase_existing_captures_scc2
+; FNATTRS-SAME: (ptr captures(address, read_provenance) [[P:%.*]]) {
+; FNATTRS-NEXT:    call void @capture(ptr captures(address, read_provenance) [[P]])
+; FNATTRS-NEXT:    call void @dont_increase_existing_captures_scc1(ptr [[P]])
+; FNATTRS-NEXT:    ret void
+;
+; ATTRIBUTOR-LABEL: define void @dont_increase_existing_captures_scc2
+; ATTRIBUTOR-SAME: (ptr [[P:%.*]]) {
+; ATTRIBUTOR-NEXT:    call void @capture(ptr captures(address, read_provenance) [[P]])
+; ATTRIBUTOR-NEXT:    call void @dont_increase_existing_captures_scc1(ptr [[P]])
+; ATTRIBUTOR-NEXT:    ret void
+;
+  call void @capture(ptr captures(address, read_provenance) %p)
+  call void @dont_increase_existing_captures_scc1(ptr %p)
+  ret void
+}
+
 declare ptr @llvm.launder.invariant.group.p0(ptr)
 declare ptr @llvm.strip.invariant.group.p0(ptr)

>From 262c7bcc426e0304207dff025644c1c886ef48f1 Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Fri, 7 Feb 2025 10:25:02 +0100
Subject: [PATCH 4/8] Use capturesAnything in more places

---
 llvm/lib/Transforms/IPO/AttributorAttributes.cpp | 2 +-
 llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp   | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
index 0fcb9739385b88f..83a91e823fa53ba 100644
--- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
+++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
@@ -12150,7 +12150,7 @@ struct AAGlobalValueInfoFloating : public AAGlobalValueInfo {
       Uses.insert(&U);
       // TODO(captures): Make this more precise.
       CaptureInfo CI = DetermineUseCaptureKind(U, nullptr);
-      if (!capturesNothing(CI) && CI.isRetOnly()) {
+      if (capturesAnything(CI) && CI.isRetOnly()) {
         Follow = true;
         return true;
       }
diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
index 91f8c0101e00c6f..3ce3d56f8a566dc 100644
--- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
+++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
@@ -1552,7 +1552,7 @@ bool MemCpyOptPass::performStackMoveOptzn(Instruction *Load, Instruction *Store,
           continue;
         CaptureInfo CI = DetermineUseCaptureKind(U, IsDereferenceableOrNull);
         // TODO(captures): Make this more precise.
-        if (!capturesNothing(CI)) {
+        if (capturesAnything(CI)) {
           if (CI.isRetOnly()) {
             Worklist.push_back(UI);
             continue;

>From 4c702d86f36b7fa9cbe129465ff14392281c2105 Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Fri, 7 Feb 2025 17:38:53 +0100
Subject: [PATCH 5/8] Update unit test

---
 llvm/unittests/Analysis/CaptureTrackingTest.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/llvm/unittests/Analysis/CaptureTrackingTest.cpp b/llvm/unittests/Analysis/CaptureTrackingTest.cpp
index 8e677bb04ec32cf..38fd592ce84a888 100644
--- a/llvm/unittests/Analysis/CaptureTrackingTest.cpp
+++ b/llvm/unittests/Analysis/CaptureTrackingTest.cpp
@@ -77,10 +77,9 @@ TEST(CaptureTracking, MaxUsesToExplore) {
 struct CollectingCaptureTracker : public CaptureTracker {
   SmallVector<const Use *, 4> Captures;
   void tooManyUses() override { }
-  std::optional<CaptureComponents> captured(const Use *U,
-                                            CaptureInfo CI) override {
+  Action captured(const Use *U, CaptureInfo CI) override {
     Captures.push_back(U);
-    return continueDefault(CI);
+    return Continue;
   }
 };
 

>From fdaced620836a9b7cfded1831126eef5d4589c3e Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Fri, 7 Feb 2025 18:08:01 +0100
Subject: [PATCH 6/8] expand comment

---
 llvm/lib/Analysis/CaptureTracking.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Analysis/CaptureTracking.cpp b/llvm/lib/Analysis/CaptureTracking.cpp
index 98bb2d80908ae13..06978617525cf30 100644
--- a/llvm/lib/Analysis/CaptureTracking.cpp
+++ b/llvm/lib/Analysis/CaptureTracking.cpp
@@ -458,7 +458,9 @@ void llvm::PointerMayBeCaptured(const Value *V, CaptureTracker *Tracker,
         continue;
       case CaptureTracker::Continue:
         // Fall through to passthrough handling, but only if RetCC contains
-        // additional components that OtherCC does not.
+        // additional components that OtherCC does not. We assume that a
+        // capture at this point will be strictly more constraining than a
+        // later capture from following the return value.
         if (capturesNothing(RetCC & ~OtherCC))
           continue;
         break;

>From faa3ce549bf9a5d82e5dca5d0653cd1e7da2dde2 Mon Sep 17 00:00:00 2001
From: Nikita Popov <nikita.ppv at gmail.com>
Date: Fri, 7 Feb 2025 19:15:23 +0100
Subject: [PATCH 7/8] Fix address_is_null inference

---
 llvm/include/llvm/Analysis/CaptureTracking.h  |  5 ++-
 llvm/lib/Analysis/CaptureTracking.cpp         | 20 ++++++---
 .../Transforms/IPO/AttributorAttributes.cpp   |  8 ++--
 .../lib/Transforms/Scalar/MemCpyOptimizer.cpp |  3 +-
 .../Transforms/FunctionAttrs/nocapture.ll     | 44 ++++++++++++++++++-
 5 files changed, 68 insertions(+), 12 deletions(-)

diff --git a/llvm/include/llvm/Analysis/CaptureTracking.h b/llvm/include/llvm/Analysis/CaptureTracking.h
index b1eea6bcc07da26..05c32a26b93145a 100644
--- a/llvm/include/llvm/Analysis/CaptureTracking.h
+++ b/llvm/include/llvm/Analysis/CaptureTracking.h
@@ -130,8 +130,11 @@ namespace llvm {
   /// which components may be captured by following uses of the user of \p U.
   /// The \p IsDereferenceableOrNull callback is used to rule out capturing for
   /// certain comparisons.
+  ///
+  /// \p Base is the starting value of the capture analysis, which is
+  /// relevant for address_is_null captures.
   CaptureInfo
-  DetermineUseCaptureKind(const Use &U,
+  DetermineUseCaptureKind(const Use &U, const Value *Base,
                           llvm::function_ref<bool(Value *, const DataLayout &)>
                               IsDereferenceableOrNull);
 
diff --git a/llvm/lib/Analysis/CaptureTracking.cpp b/llvm/lib/Analysis/CaptureTracking.cpp
index 06978617525cf30..516648bceef8ccf 100644
--- a/llvm/lib/Analysis/CaptureTracking.cpp
+++ b/llvm/lib/Analysis/CaptureTracking.cpp
@@ -280,7 +280,7 @@ Instruction *llvm::FindEarliestCapture(const Value *V, Function &F,
 }
 
 CaptureInfo llvm::DetermineUseCaptureKind(
-    const Use &U,
+    const Use &U, const Value *Base,
     function_ref<bool(Value *, const DataLayout &)> IsDereferenceableOrNull) {
   Instruction *I = dyn_cast<Instruction>(U.getUser());
 
@@ -378,14 +378,15 @@ CaptureInfo llvm::DetermineUseCaptureKind(
   case Instruction::ICmp: {
     unsigned Idx = U.getOperandNo();
     unsigned OtherIdx = 1 - Idx;
-    if (auto *CPN = dyn_cast<ConstantPointerNull>(I->getOperand(OtherIdx))) {
+    if (isa<ConstantPointerNull>(I->getOperand(OtherIdx)) &&
+        cast<ICmpInst>(I)->isEquality()) {
       // TODO(captures): Remove these special cases once we make use of
       // captures(address_is_null).
 
       // Don't count comparisons of a no-alias return value against null as
       // captures. This allows us to ignore comparisons of malloc results
       // with null, for example.
-      if (CPN->getType()->getAddressSpace() == 0)
+      if (U->getType()->getPointerAddressSpace() == 0)
         if (isNoAliasCall(U.get()->stripPointerCasts()))
           return CaptureInfo::none();
       if (!I->getFunction()->nullPointerIsDefined()) {
@@ -397,7 +398,16 @@ CaptureInfo llvm::DetermineUseCaptureKind(
         if (IsDereferenceableOrNull && IsDereferenceableOrNull(O, DL))
           return CaptureInfo::none();
       }
-      return CaptureInfo::otherOnly(CaptureComponents::AddressIsNull);
+
+      // Check whether this is a comparison of the base pointer against
+      // null. We can also strip inbounds GEPs, as inbounds preserves
+      // the null-ness of the pointer.
+      Value *Stripped = U.get();
+      if (!NullPointerIsDefined(I->getFunction(),
+                                U->getType()->getPointerAddressSpace()))
+        Stripped = Stripped->stripInBoundsOffsets();
+      if (Stripped == Base)
+        return CaptureInfo::otherOnly(CaptureComponents::AddressIsNull);
     }
 
     // Otherwise, be conservative. There are crazy ways to capture pointers
@@ -445,7 +455,7 @@ void llvm::PointerMayBeCaptured(const Value *V, CaptureTracker *Tracker,
   };
   while (!Worklist.empty()) {
     const Use *U = Worklist.pop_back_val();
-    CaptureInfo CI = DetermineUseCaptureKind(*U, IsDereferenceableOrNull);
+    CaptureInfo CI = DetermineUseCaptureKind(*U, V, IsDereferenceableOrNull);
     if (capturesNothing(CI))
       continue;
     CaptureComponents OtherCC = CI.getOtherComponents();
diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
index 83a91e823fa53ba..fea861188614b87 100644
--- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
+++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
@@ -3971,7 +3971,8 @@ struct AANoAliasCallSiteArgument final : AANoAliasImpl {
       //       is CGSCC runs. For those we would need to "allow" AANoCapture for
       //       a value in the module slice.
       // TODO(captures): Make this more precise.
-      CaptureInfo CI = DetermineUseCaptureKind(U, IsDereferenceableOrNull);
+      CaptureInfo CI =
+          DetermineUseCaptureKind(U, /*Base=*/nullptr, IsDereferenceableOrNull);
       if (capturesNothing(CI))
         return true;
       if (CI.isRetOnly()) {
@@ -6018,7 +6019,8 @@ ChangeStatus AANoCaptureImpl::updateImpl(Attributor &A) {
 
   auto UseCheck = [&](const Use &U, bool &Follow) -> bool {
     // TODO(captures): Make this more precise.
-    CaptureInfo CI = DetermineUseCaptureKind(U, IsDereferenceableOrNull);
+    CaptureInfo CI =
+        DetermineUseCaptureKind(U, /*Base=*/nullptr, IsDereferenceableOrNull);
     if (capturesNothing(CI))
       return true;
     if (CI.isRetOnly()) {
@@ -12149,7 +12151,7 @@ struct AAGlobalValueInfoFloating : public AAGlobalValueInfo {
     auto UsePred = [&](const Use &U, bool &Follow) -> bool {
       Uses.insert(&U);
       // TODO(captures): Make this more precise.
-      CaptureInfo CI = DetermineUseCaptureKind(U, nullptr);
+      CaptureInfo CI = DetermineUseCaptureKind(U, /*Base=*/nullptr, nullptr);
       if (capturesAnything(CI) && CI.isRetOnly()) {
         Follow = true;
         return true;
diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
index 3ce3d56f8a566dc..e0f50e562d04772 100644
--- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
+++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
@@ -1550,7 +1550,8 @@ bool MemCpyOptPass::performStackMoveOptzn(Instruction *Load, Instruction *Store,
         }
         if (!Visited.insert(&U).second)
           continue;
-        CaptureInfo CI = DetermineUseCaptureKind(U, IsDereferenceableOrNull);
+        CaptureInfo CI =
+            DetermineUseCaptureKind(U, AI, IsDereferenceableOrNull);
         // TODO(captures): Make this more precise.
         if (capturesAnything(CI)) {
           if (CI.isRetOnly()) {
diff --git a/llvm/test/Transforms/FunctionAttrs/nocapture.ll b/llvm/test/Transforms/FunctionAttrs/nocapture.ll
index ccb5251fb4ea3ed..3f4fcca610d5d74 100644
--- a/llvm/test/Transforms/FunctionAttrs/nocapture.ll
+++ b/llvm/test/Transforms/FunctionAttrs/nocapture.ll
@@ -811,6 +811,46 @@ define i1 @nocaptureInboundsGEPICmpRev(ptr %x) {
   ret i1 %2
 }
 
+define i1 @notInboundsGEPICmp(ptr %x) {
+; FNATTRS: Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
+; FNATTRS-LABEL: define i1 @notInboundsGEPICmp
+; FNATTRS-SAME: (ptr readnone captures(address) [[X:%.*]]) #[[ATTR0]] {
+; FNATTRS-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[X]], i32 5
+; FNATTRS-NEXT:    [[TMP2:%.*]] = icmp eq ptr [[TMP1]], null
+; FNATTRS-NEXT:    ret i1 [[TMP2]]
+;
+; ATTRIBUTOR: Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
+; ATTRIBUTOR-LABEL: define i1 @notInboundsGEPICmp
+; ATTRIBUTOR-SAME: (ptr nofree readnone [[X:%.*]]) #[[ATTR0]] {
+; ATTRIBUTOR-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[X]], i32 5
+; ATTRIBUTOR-NEXT:    [[TMP2:%.*]] = icmp eq ptr [[TMP1]], null
+; ATTRIBUTOR-NEXT:    ret i1 [[TMP2]]
+;
+  %1 = getelementptr i32, ptr %x, i32 5
+  %2 = icmp eq ptr %1, null
+  ret i1 %2
+}
+
+define i1 @inboundsGEPICmpNullPointerDefined(ptr %x) null_pointer_is_valid {
+; FNATTRS: Function Attrs: mustprogress nofree norecurse nosync nounwind null_pointer_is_valid willreturn memory(none)
+; FNATTRS-LABEL: define i1 @inboundsGEPICmpNullPointerDefined
+; FNATTRS-SAME: (ptr readnone captures(address) [[X:%.*]]) #[[ATTR16:[0-9]+]] {
+; FNATTRS-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[X]], i32 5
+; FNATTRS-NEXT:    [[TMP2:%.*]] = icmp eq ptr [[TMP1]], null
+; FNATTRS-NEXT:    ret i1 [[TMP2]]
+;
+; ATTRIBUTOR: Function Attrs: mustprogress nofree norecurse nosync nounwind null_pointer_is_valid willreturn memory(none)
+; ATTRIBUTOR-LABEL: define i1 @inboundsGEPICmpNullPointerDefined
+; ATTRIBUTOR-SAME: (ptr nofree readnone [[X:%.*]]) #[[ATTR12:[0-9]+]] {
+; ATTRIBUTOR-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[X]], i32 5
+; ATTRIBUTOR-NEXT:    [[TMP2:%.*]] = icmp eq ptr [[TMP1]], null
+; ATTRIBUTOR-NEXT:    ret i1 [[TMP2]]
+;
+  %1 = getelementptr i32, ptr %x, i32 5
+  %2 = icmp eq ptr %1, null
+  ret i1 %2
+}
+
 define i1 @nocaptureDereferenceableOrNullICmp(ptr dereferenceable_or_null(4) %x) {
 ; FNATTRS: Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
 ; FNATTRS-LABEL: define noundef i1 @nocaptureDereferenceableOrNullICmp
@@ -831,13 +871,13 @@ define i1 @nocaptureDereferenceableOrNullICmp(ptr dereferenceable_or_null(4) %x)
 define i1 @captureDereferenceableOrNullICmp(ptr dereferenceable_or_null(4) %x) null_pointer_is_valid {
 ; FNATTRS: Function Attrs: mustprogress nofree norecurse nosync nounwind null_pointer_is_valid willreturn memory(none)
 ; FNATTRS-LABEL: define noundef i1 @captureDereferenceableOrNullICmp
-; FNATTRS-SAME: (ptr readnone captures(address_is_null) dereferenceable_or_null(4) [[X:%.*]]) #[[ATTR16:[0-9]+]] {
+; FNATTRS-SAME: (ptr readnone captures(address_is_null) dereferenceable_or_null(4) [[X:%.*]]) #[[ATTR16]] {
 ; FNATTRS-NEXT:    [[TMP1:%.*]] = icmp eq ptr [[X]], null
 ; FNATTRS-NEXT:    ret i1 [[TMP1]]
 ;
 ; ATTRIBUTOR: Function Attrs: mustprogress nofree norecurse nosync nounwind null_pointer_is_valid willreturn memory(none)
 ; ATTRIBUTOR-LABEL: define i1 @captureDereferenceableOrNullICmp
-; ATTRIBUTOR-SAME: (ptr nofree readnone dereferenceable_or_null(4) [[X:%.*]]) #[[ATTR12:[0-9]+]] {
+; ATTRIBUTOR-SAME: (ptr nofree readnone dereferenceable_or_null(4) [[X:%.*]]) #[[ATTR12]] {
 ; ATTRIBUTOR-NEXT:    [[TMP1:%.*]] = icmp eq ptr [[X]], null
 ; ATTRIBUTOR-NEXT:    ret i1 [[TMP1]]
 ;

>From a874b84a9e6087e064bf3273979ce44fe2055a05 Mon Sep 17 00:00:00 2001
From: Nikita Popov <nikita.ppv at gmail.com>
Date: Fri, 7 Feb 2025 19:46:34 +0100
Subject: [PATCH 8/8] Add test for incorrect predicate

---
 llvm/test/Transforms/FunctionAttrs/nocapture.ll | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/llvm/test/Transforms/FunctionAttrs/nocapture.ll b/llvm/test/Transforms/FunctionAttrs/nocapture.ll
index 3f4fcca610d5d74..5dd936b7587d532 100644
--- a/llvm/test/Transforms/FunctionAttrs/nocapture.ll
+++ b/llvm/test/Transforms/FunctionAttrs/nocapture.ll
@@ -771,6 +771,23 @@ define i1 @captureICmpRev(ptr %x) {
   ret i1 %1
 }
 
+define i1 @captureICmpWrongPred(ptr %x) {
+; FNATTRS: Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
+; FNATTRS-LABEL: define i1 @captureICmpWrongPred
+; FNATTRS-SAME: (ptr readnone captures(address) [[X:%.*]]) #[[ATTR0]] {
+; FNATTRS-NEXT:    [[TMP1:%.*]] = icmp slt ptr [[X]], null
+; FNATTRS-NEXT:    ret i1 [[TMP1]]
+;
+; ATTRIBUTOR: Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
+; ATTRIBUTOR-LABEL: define i1 @captureICmpWrongPred
+; ATTRIBUTOR-SAME: (ptr nofree readnone [[X:%.*]]) #[[ATTR0]] {
+; ATTRIBUTOR-NEXT:    [[TMP1:%.*]] = icmp slt ptr [[X]], null
+; ATTRIBUTOR-NEXT:    ret i1 [[TMP1]]
+;
+  %1 = icmp slt ptr %x, null
+  ret i1 %1
+}
+
 define i1 @nocaptureInboundsGEPICmp(ptr %x) {
 ; FNATTRS: Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
 ; FNATTRS-LABEL: define i1 @nocaptureInboundsGEPICmp



More information about the llvm-commits mailing list