[llvm] [IR] Introduce captures attribute (PR #116990)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 20 07:35:52 PST 2024


https://github.com/nikic created https://github.com/llvm/llvm-project/pull/116990

This introduces the `captures` attribute as described in: https://discourse.llvm.org/t/rfc-improvements-to-capture-tracking/81420

This initial patch only introduces the IR/bitcode support for the attribute and its in-memory representation as `CaptureInfo`. This will be followed by a patch to upgrade and remove the `nocapture` attribute, and then by actual inference/analysis support.
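Here is a minimal sketch of that in-memory encoding. It is a standalone copy of the bit layout from `ModRef.h` in this patch, not the real `llvm::CaptureInfo`; `CaptureInfoSketch` and the `k*` constants are hypothetical names. The components occupy the upper bits, and bit 0 records the return-only flag.

```cpp
#include <cassert>
#include <cstdint>

// Bit values mirroring the patch's CaptureComponents enum. This sketch uses
// plain constants instead of LLVM's LLVM_MARK_AS_BITMASK_ENUM machinery.
constexpr uint8_t kNone = 0;
constexpr uint8_t kAddress = 1 << 0;
constexpr uint8_t kReadProvenance = 1 << 1;
// Full provenance also sets the read_provenance bit.
constexpr uint8_t kProvenance = (1 << 2) | kReadProvenance;
constexpr uint8_t kAll = kAddress | kProvenance;

// Hypothetical stand-in for llvm::CaptureInfo (not the real class).
struct CaptureInfoSketch {
  uint8_t Components;
  bool ReturnOnly;

  // As in the patch, the return-only flag is dropped when nothing is
  // captured, so "ret: none" becomes plain "none".
  CaptureInfoSketch(uint8_t C, bool RetOnly = false)
      : Components(C), ReturnOnly(C != kNone && RetOnly) {}

  // Components are shifted left by one; bit 0 holds the return-only flag.
  uint32_t toIntValue() const {
    return (uint32_t(Components) << 1) | uint32_t(ReturnOnly);
  }

  static CaptureInfoSketch createFromIntValue(uint32_t Data) {
    return CaptureInfoSketch(uint8_t(Data >> 1), (Data & 1) != 0);
  }

  bool operator==(const CaptureInfoSketch &O) const {
    return Components == O.Components && ReturnOnly == O.ReturnOnly;
  }
};
```

With this layout, `captures(ret: address, provenance)` is stored as `(7 << 1) | 1 = 15`. `captures(ret: none)` normalizes to plain `none`, which encodes as 0.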

Based on the RFC feedback, I've used a syntax similar to the `memory` attribute, though the only "location" that can be specified is `ret`.
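The printed syntax can be sketched in the same way. The helpers below are hypothetical standalone functions, not LLVM APIs. They follow the logic of the patch's `operator<<` overloads in `ModRef.cpp`: components are printed in a fixed order, and an optional `ret:` prefix plays the role of a `memory`-style location.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

constexpr uint8_t kNone = 0;
constexpr uint8_t kAddress = 1 << 0;
constexpr uint8_t kReadProvenance = 1 << 1;
constexpr uint8_t kProvenance = (1 << 2) | kReadProvenance;

// Renders a component mask as "none" or as a comma-separated list in the
// order address, read_provenance, provenance.
std::string componentsToString(uint8_t CC) {
  if (CC == kNone)
    return "none";
  std::string Out;
  auto Append = [&Out](const char *Word) {
    if (!Out.empty())
      Out += ", ";
    Out += Word;
  };
  if (CC & kAddress)
    Append("address");
  // read_provenance is printed only when the full provenance bit is absent.
  if ((CC & kProvenance) == kReadProvenance)
    Append("read_provenance");
  if ((CC & kProvenance) == kProvenance)
    Append("provenance");
  return Out;
}

// Hypothetical helper that prints the whole attribute. "ret: none" is
// normalized to "none", matching the patch's CaptureInfo constructor.
std::string capturesAttrToString(uint8_t CC, bool ReturnOnly) {
  bool Ret = ReturnOnly && CC != kNone;
  return "captures(" + std::string(Ret ? "ret: " : "") +
         componentsToString(CC) + ")";
}
```

Because `provenance` already contains the `read_provenance` bit, `captures(read_provenance, provenance)` prints back as `captures(provenance)`.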

I've added some pretty extensive documentation to LangRef on the semantics. One non-obvious bit here is that using ptrtoint will not result in a "return-only" capture, even if the ptrtoint result is only used in the return value. Without this requirement we wouldn't be able to continue ordinary capture analysis on the return value.
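The ptrtoint rule can be shown with a toy model. The model classifies each direct use of a pointer argument and combines the results using the union rule this patch gives `CaptureInfo::operator|`: components are OR-ed, and the result is return-only only if both inputs are. `UseKind`, `captureOfUse` and `analyzeUses` are hypothetical simplifications for illustration. They are not LLVM's capture tracking.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr uint8_t kNone = 0;
constexpr uint8_t kAddress = 1 << 0;
constexpr uint8_t kReadProvenance = 1 << 1;
constexpr uint8_t kProvenance = (1 << 2) | kReadProvenance;
constexpr uint8_t kAll = kAddress | kProvenance;

struct Capture {
  uint8_t Components;
  bool ReturnOnly;
};

// Union as in the patch's CaptureInfo::operator|.
Capture unite(Capture A, Capture B) {
  uint8_t C = A.Components | B.Components;
  return {C, C != kNone && A.ReturnOnly && B.ReturnOnly};
}

// Hypothetical, heavily simplified classification of a direct use.
enum class UseKind {
  Ret,      // The pointer (or a GEP of it) is returned.
  PtrToInt  // The pointer is converted to an integer.
};

Capture captureOfUse(UseKind K) {
  switch (K) {
  case UseKind::Ret:
    return {kAll, true};       // Escapes only through the return value.
  case UseKind::PtrToInt:
    return {kAddress, false};  // Address inspection is location-independent.
  }
  return {kAll, false};
}

// Folds per-use contributions with unite(). No uses means no capture.
Capture analyzeUses(const std::vector<UseKind> &Uses) {
  if (Uses.empty())
    return {kNone, false};
  Capture Result = captureOfUse(Uses.front());
  for (std::size_t I = 1; I < Uses.size(); ++I)
    Result = unite(Result, captureOfUse(Uses[I]));
  return Result;
}
```

In the LangRef `@lookup` example, the only use of `%a` is the `ptrtoint`. The model therefore reports an address capture that is not return-only, even though the integer only feeds the returned pointer.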

From 49a1a51a2eb85ee421393d731e0ff079bed8ac03 Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Mon, 18 Nov 2024 17:16:57 +0100
Subject: [PATCH] [IR] Introduce captures attribute

This introduces the `captures` attribute as described in:
https://discourse.llvm.org/t/rfc-improvements-to-capture-tracking/81420

This initial patch only introduces the IR/bitcode support for
the attribute and its in-memory representation as `CaptureInfo`.
This will be followed by a patch to remove (and upgrade) the
`nocapture` attribute, and then by actual inference/analysis
support.

Based on the RFC feedback, I've used a syntax similar to the
`memory` attribute, though the only "location" that can be
specified right now is `ret`.

I've added some pretty extensive documentation to LangRef on the
semantics. One non-obvious bit here is that using ptrtoint will
not result in a "return-only" capture, even if the ptrtoint
result is only used in the return value. Without this requirement
we wouldn't be able to continue ordinary capture analysis on the
return value.
---
 llvm/docs/LangRef.rst                       | 129 ++++++++++++++++++--
 llvm/include/llvm/AsmParser/LLParser.h      |   1 +
 llvm/include/llvm/AsmParser/LLToken.h       |   5 +
 llvm/include/llvm/Bitcode/LLVMBitCodes.h    |   1 +
 llvm/include/llvm/IR/Attributes.h           |   7 ++
 llvm/include/llvm/IR/Attributes.td          |   3 +
 llvm/include/llvm/Support/ModRef.h          |  87 +++++++++++++
 llvm/lib/AsmParser/LLLexer.cpp              |   3 +
 llvm/lib/AsmParser/LLParser.cpp             |  49 ++++++++
 llvm/lib/Bitcode/Reader/BitcodeReader.cpp   |   4 +
 llvm/lib/Bitcode/Writer/BitcodeWriter.cpp   |   2 +
 llvm/lib/IR/AttributeImpl.h                 |   1 +
 llvm/lib/IR/Attributes.cpp                  |  31 +++++
 llvm/lib/Support/ModRef.cpp                 |  26 ++++
 llvm/lib/Transforms/Utils/CodeExtractor.cpp |   1 +
 llvm/test/Assembler/captures-errors.ll      |  41 +++++++
 llvm/test/Assembler/captures.ll             |  70 +++++++++++
 llvm/test/Bitcode/attributes.ll             |   5 +
 llvm/unittests/IR/AttributesTest.cpp        |  12 ++
 19 files changed, 469 insertions(+), 9 deletions(-)
 create mode 100644 llvm/test/Assembler/captures-errors.ll
 create mode 100644 llvm/test/Assembler/captures.ll

diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 9f4c90ba82a419..df8599a818b6fc 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -1379,6 +1379,36 @@ Currently, only the following parameter attributes are defined:
     function, returning a pointer to allocated storage disjoint from the
     storage for any other object accessible to the caller.
 
+``captures(...)``
+    This attribute restricts the ways in which the callee may capture the
+    pointer. This is not a valid attribute for return values. This attribute
+    applies only to the particular copy of the pointer passed in this argument.
+
+    The argument of ``captures`` is a list of captured pointer components,
+    which may be ``none``, or a combination of:
+
+    - ``address``: The integral address of the pointer.
+    - ``provenance``: The ability to access the pointer for both read and write
+      after the function returns.
+    - ``read_provenance``: The ability to access the pointer only for reads
+      after the function returns.
+
+    Additionally, it is possible to specify that the pointer is captured via
+    the return value only, by using ``captures(ret: ...)``.
+
+    The :ref:`pointer capture section <pointercapture>` discusses these semantics
+    in more detail.
+
+    Some examples of how to use the attribute:
+
+    - ``captures(none)``: Pointer not captured.
+    - ``captures(address, provenance)``: Equivalent to omitting the attribute.
+    - ``captures(address)``: Address may be captured, but not provenance.
+    - ``captures(address, read_provenance)``: Both address and provenance
+      captured, but only for read-only access.
+    - ``captures(ret: address, provenance)``: Pointer captured through return
+      value only.
+
 .. _nocapture:
 
 ``nocapture``
@@ -3318,10 +3348,91 @@ Pointer Capture
 ---------------
 
 Given a function call and a pointer that is passed as an argument or stored in
-the memory before the call, a pointer is *captured* by the call if it makes a
-copy of any part of the pointer that outlives the call.
-To be precise, a pointer is captured if one or more of the following conditions
-hold:
+memory before the call, the call may capture two components of the pointer:
+
+  * The address of the pointer, which is its integral value. This also includes
+    parts of the address or any information about the address, including the
+    fact that it does not equal one specific value.
+  * The provenance of the pointer, which is the ability to perform memory
+    accesses through the pointer, in the sense of the :ref:`pointer aliasing
+    rules <pointeraliasing>`. We further distinguish whether only read accesses
+    are allowed, or both reads and writes.
+
+For example, the following function captures the address of ``%a``, because
+it is compared to a pointer, leaking information about the identity of the
+pointer:
+
+.. code-block:: llvm
+
+    @glb = global i8 0
+
+    define i1 @f(ptr %a) {
+      %c = icmp eq ptr %a, @glb
+      ret i1 %c
+    }
+
+The function does not capture the provenance of the pointer, because the
+``icmp`` instruction only operates on the pointer address. The following
+function captures both the address and provenance of the pointer, as both
+may be read from ``@glb`` after the function returns:
+
+.. code-block:: llvm
+
+    @glb = global ptr null
+
+    define void @f(ptr %a) {
+      store ptr %a, ptr @glb
+      ret void
+    }
+
+The following function captures *neither* the address nor the provenance of
+the pointer:
+
+.. code-block:: llvm
+
+    define i32 @f(ptr %a) {
+      %v = load i32, ptr %a
+      ret i32 %v
+    }
+
+While address capture includes uses of the address within the body of the
+function, provenance capture refers exclusively to the ability to perform
+accesses *after* the function returns. Memory accesses within the function
+itself are not considered pointer captures.
+
+We can further specify that a capture occurs only through a specific location.
+In the following example, the pointer (both address and provenance) is captured
+through the return value only:
+
+.. code-block:: llvm
+
+    define ptr @f(ptr %a) {
+      %gep = getelementptr i8, ptr %a, i64 4
+      ret ptr %gep
+    }
+
+However, we always consider direct inspection of the pointer address
+(e.g. using ``ptrtoint``) to be location-independent. The following example
+is *not* considered a return-only capture, even though the ``ptrtoint``
+ultimately only contributes to the return value:
+
+.. code-block:: llvm
+
+    @lookup = constant [4 x i8] [i8 0, i8 1, i8 2, i8 3]
+
+    define ptr @f(ptr %a) {
+      %a.addr = ptrtoint ptr %a to i64
+      %mask = and i64 %a.addr, 3
+      %gep = getelementptr i8, ptr @lookup, i64 %mask
+      ret ptr %gep
+    }
+
+This definition is chosen to allow capture analysis to continue with the return
+value in the usual fashion.
+
+The following describes possible ways to capture a pointer in more detail,
+where unqualified uses of the word "capture" refer to capturing both address
+and provenance.
 
 1. The call stores any bit of the pointer carrying information into a place,
    and the stored bits can be read from the place by the caller after this call
@@ -3360,13 +3471,14 @@ hold:
     @lock = global i1 true
 
     define void @f(ptr %a) {
-      store ptr %a, ptr* @glb
+      store ptr %a, ptr @glb
       store atomic i1 false, ptr @lock release ; %a is captured because another thread can safely read @glb
       store ptr null, ptr @glb
       ret void
     }
 
-3. The call's behavior depends on any bit of the pointer carrying information.
+3. The call's behavior depends on any bit of the pointer carrying information
+   (address capture only).
 
 .. code-block:: llvm
 
@@ -3374,7 +3486,7 @@ hold:
 
     define void @f(ptr %a) {
       %c = icmp eq ptr %a, @glb
-      br i1 %c, label %BB_EXIT, label %BB_CONTINUE ; escapes %a
+      br i1 %c, label %BB_EXIT, label %BB_CONTINUE ; captures address of %a only
     BB_EXIT:
       call void @exit()
       unreachable
@@ -3382,8 +3494,7 @@ hold:
       ret void
     }
 
-4. The pointer is used in a volatile access as its address.
-
+4. The pointer is used as the pointer operand of a volatile access.
 
 .. _volatile:
 
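As a rough illustration of the LangRef semantics for the two provenance components, the standalone sketch below reuses the bit layout of the patch's `CaptureComponents` enum. It mirrors the `capturesReadProvenanceOnly()` and `capturesFullProvenance()` predicates. `mayAccessAfterReturn` is a hypothetical query added only for illustration and is not part of this patch or of LLVM's API.

```cpp
#include <cassert>
#include <cstdint>

constexpr uint8_t kNone = 0;
constexpr uint8_t kAddress = 1 << 0;
constexpr uint8_t kReadProvenance = 1 << 1;
// Full provenance includes the read_provenance bit.
constexpr uint8_t kProvenance = (1 << 2) | kReadProvenance;

// Mirrors of the predicates added to ModRef.h.
bool capturesReadProvenanceOnly(uint8_t CC) {
  return (CC & kProvenance) == kReadProvenance;
}
bool capturesFullProvenance(uint8_t CC) {
  return (CC & kProvenance) == kProvenance;
}

// Hypothetical query: may a copy of the pointer that escaped through the
// call be used for an access of the given kind after the call returns?
bool mayAccessAfterReturn(uint8_t CC, bool IsWrite) {
  if (capturesFullProvenance(CC))
    return true;
  return !IsWrite && capturesReadProvenanceOnly(CC);
}
```

Under this reading, `captures(address)` permits no later access through the escaped copy, and `captures(address, read_provenance)` permits later reads but not writes.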
diff --git a/llvm/include/llvm/AsmParser/LLParser.h b/llvm/include/llvm/AsmParser/LLParser.h
index 1ef8b8ffc39660..bc95a57da3c2ae 100644
--- a/llvm/include/llvm/AsmParser/LLParser.h
+++ b/llvm/include/llvm/AsmParser/LLParser.h
@@ -376,6 +376,7 @@ namespace llvm {
                                     bool inAttrGrp, LocTy &BuiltinLoc);
     bool parseRangeAttr(AttrBuilder &B);
     bool parseInitializesAttr(AttrBuilder &B);
+    bool parseCapturesAttr(AttrBuilder &B);
     bool parseRequiredTypeAttr(AttrBuilder &B, lltok::Kind AttrToken,
                                Attribute::AttrKind AttrKind);
 
diff --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h
index 178c911120b4ce..48b8aa0158c660 100644
--- a/llvm/include/llvm/AsmParser/LLToken.h
+++ b/llvm/include/llvm/AsmParser/LLToken.h
@@ -207,6 +207,11 @@ enum Kind {
   kw_inaccessiblememonly,
   kw_inaccessiblemem_or_argmemonly,
 
+  // Captures attribute:
+  kw_address,
+  kw_provenance,
+  kw_read_provenance,
+
   // nofpclass attribute:
   kw_all,
   kw_nan,
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index a0fb32f67e3858..4a7af55fce871a 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -783,6 +783,7 @@ enum AttributeKindCodes {
   ATTR_KIND_CORO_ELIDE_SAFE = 98,
   ATTR_KIND_NO_EXT = 99,
   ATTR_KIND_NO_DIVERGENCE_SOURCE = 100,
+  ATTR_KIND_CAPTURES = 101,
 };
 
 enum ComdatSelectionKindCodes {
diff --git a/llvm/include/llvm/IR/Attributes.h b/llvm/include/llvm/IR/Attributes.h
index 2755ced404dddb..7612e553fe32e6 100644
--- a/llvm/include/llvm/IR/Attributes.h
+++ b/llvm/include/llvm/IR/Attributes.h
@@ -284,6 +284,9 @@ class Attribute {
   /// Returns memory effects.
   MemoryEffects getMemoryEffects() const;
 
+  /// Returns information from captures attribute.
+  CaptureInfo getCaptureInfo() const;
+
   /// Return the FPClassTest for nofpclass
   FPClassTest getNoFPClass() const;
 
@@ -436,6 +439,7 @@ class AttributeSet {
   UWTableKind getUWTableKind() const;
   AllocFnKind getAllocKind() const;
   MemoryEffects getMemoryEffects() const;
+  CaptureInfo getCaptureInfo() const;
   FPClassTest getNoFPClass() const;
   std::string getAsString(bool InAttrGrp = false) const;
 
@@ -1260,6 +1264,9 @@ class AttrBuilder {
   /// Add memory effect attribute.
   AttrBuilder &addMemoryAttr(MemoryEffects ME);
 
+  /// Add captures attribute.
+  AttrBuilder &addCapturesAttr(CaptureInfo CI);
+
   // Add nofpclass attribute
   AttrBuilder &addNoFPClassAttr(FPClassTest NoFPClassMask);
 
diff --git a/llvm/include/llvm/IR/Attributes.td b/llvm/include/llvm/IR/Attributes.td
index 49f4527bde66e7..e6e9846c412a7f 100644
--- a/llvm/include/llvm/IR/Attributes.td
+++ b/llvm/include/llvm/IR/Attributes.td
@@ -183,6 +183,9 @@ def NoCallback : EnumAttr<"nocallback", IntersectAnd, [FnAttr]>;
 /// Function creates no aliases of pointer.
 def NoCapture : EnumAttr<"nocapture", IntersectAnd, [ParamAttr]>;
 
+/// Specify how the pointer may be captured.
+def Captures : IntAttr<"captures", IntersectCustom, [ParamAttr]>;
+
 /// Function is not a source of divergence.
 def NoDivergenceSource : EnumAttr<"nodivergencesource", IntersectAnd, [FnAttr]>;
 
diff --git a/llvm/include/llvm/Support/ModRef.h b/llvm/include/llvm/Support/ModRef.h
index 5a9d80c87ae27a..d610aa5eaac6b5 100644
--- a/llvm/include/llvm/Support/ModRef.h
+++ b/llvm/include/llvm/Support/ModRef.h
@@ -273,6 +273,93 @@ raw_ostream &operator<<(raw_ostream &OS, MemoryEffects RMRB);
 // Legacy alias.
 using FunctionModRefBehavior = MemoryEffects;
 
+/// Components of the pointer that may be captured.
+enum class CaptureComponents : uint8_t {
+  None = 0,
+  Address = (1 << 0),
+  ReadProvenance = (1 << 1),
+  Provenance = (1 << 2) | ReadProvenance,
+  All = Address | Provenance,
+  LLVM_MARK_AS_BITMASK_ENUM(Provenance),
+};
+
+inline bool capturesNothing(CaptureComponents CC) {
+  return CC == CaptureComponents::None;
+}
+
+inline bool capturesAnything(CaptureComponents CC) {
+  return CC != CaptureComponents::None;
+}
+
+inline bool capturesAddress(CaptureComponents CC) {
+  return (CC & CaptureComponents::Address) != CaptureComponents::None;
+}
+
+inline bool capturesReadProvenanceOnly(CaptureComponents CC) {
+  return (CC & CaptureComponents::Provenance) ==
+         CaptureComponents::ReadProvenance;
+}
+
+inline bool capturesFullProvenance(CaptureComponents CC) {
+  return (CC & CaptureComponents::Provenance) == CaptureComponents::Provenance;
+}
+
+raw_ostream &operator<<(raw_ostream &OS, CaptureComponents CC);
+
+/// Represents which components of the pointer may be captured and whether
+/// the capture is via the return value only. This represents the captures(...)
+/// attribute in IR.
+///
+/// For more information on the precise semantics see LangRef.
+class CaptureInfo {
+  CaptureComponents Components;
+  bool ReturnOnly;
+
+public:
+  CaptureInfo(CaptureComponents Components, bool ReturnOnly = false)
+      : Components(Components),
+        ReturnOnly(capturesAnything(Components) && ReturnOnly) {}
+
+  /// Create CaptureInfo that may capture all components of the pointer.
+  static CaptureInfo all() { return CaptureInfo(CaptureComponents::All); }
+
+  /// Get the potentially captured components of the pointer.
+  operator CaptureComponents() const { return Components; }
+
+  /// Whether the pointer is captured through the return value only.
+  bool isReturnOnly() const { return ReturnOnly; }
+
+  bool operator==(CaptureInfo Other) const {
+    return Components == Other.Components && ReturnOnly == Other.ReturnOnly;
+  }
+
+  bool operator!=(CaptureInfo Other) const { return !(*this == Other); }
+
+  /// Compute union of CaptureInfos.
+  CaptureInfo operator|(CaptureInfo Other) const {
+    return CaptureInfo(Components | Other.Components,
+                       ReturnOnly && Other.ReturnOnly);
+  }
+
+  /// Compute intersection of CaptureInfos.
+  CaptureInfo operator&(CaptureInfo Other) const {
+    return CaptureInfo(Components & Other.Components,
+                       ReturnOnly || Other.ReturnOnly);
+  }
+
+  static CaptureInfo createFromIntValue(uint32_t Data) {
+    return CaptureInfo(CaptureComponents(Data >> 1), Data & 1);
+  }
+
+  /// Convert CaptureInfo into an encoded integer value (used by captures
+  /// attribute).
+  uint32_t toIntValue() const {
+    return (uint32_t(Components) << 1) | ReturnOnly;
+  }
+};
+
+raw_ostream &operator<<(raw_ostream &OS, CaptureInfo Info);
+
 } // namespace llvm
 
 #endif
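The union and intersection operators on `CaptureInfo` can be tried out without LLVM by using a standalone copy of their logic. `AttributeSet::intersectWith` in this patch merges two `captures` attributes with the union, because the merged attribute must allow every capture that either input allows. `Info`, `unite` and `intersect` are hypothetical names. The expected union result mirrors the `SetIntersect` unit test in this patch.

```cpp
#include <cassert>
#include <cstdint>

constexpr uint8_t kNone = 0;
constexpr uint8_t kAddress = 1 << 0;
constexpr uint8_t kReadProvenance = 1 << 1;
constexpr uint8_t kProvenance = (1 << 2) | kReadProvenance;
constexpr uint8_t kAll = kAddress | kProvenance;

// Stand-in for CaptureInfo. As in the patch, the constructor drops the
// return-only flag when nothing is captured.
struct Info {
  uint8_t Components;
  bool ReturnOnly;
  Info(uint8_t C, bool R = false)
      : Components(C), ReturnOnly(C != kNone && R) {}
};

// Union (operator| in the patch): may capture whatever either side may,
// and is return-only only if both sides are.
Info unite(Info A, Info B) {
  return Info(A.Components | B.Components, A.ReturnOnly && B.ReturnOnly);
}

// Intersection (operator& in the patch): captures only what both sides
// allow, and is return-only if either side is.
Info intersect(Info A, Info B) {
  return Info(A.Components & B.Components, A.ReturnOnly || B.ReturnOnly);
}
```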
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index 1b8e033134f51b..7b6be79723b96b 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -704,6 +704,9 @@ lltok::Kind LLLexer::LexIdentifier() {
   KEYWORD(argmemonly);
   KEYWORD(inaccessiblememonly);
   KEYWORD(inaccessiblemem_or_argmemonly);
+  KEYWORD(address);
+  KEYWORD(provenance);
+  KEYWORD(read_provenance);
 
   // nofpclass attribute
   KEYWORD(all);
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index b8a8df71d4de21..880fe543810570 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -1644,6 +1644,8 @@ bool LLParser::parseEnumAttribute(Attribute::AttrKind Attr, AttrBuilder &B,
     return parseRangeAttr(B);
   case Attribute::Initializes:
     return parseInitializesAttr(B);
+  case Attribute::Captures:
+    return parseCapturesAttr(B);
   default:
     B.addAttribute(Attr);
     Lex.Lex();
@@ -3165,6 +3167,53 @@ bool LLParser::parseInitializesAttr(AttrBuilder &B) {
   return false;
 }
 
+bool LLParser::parseCapturesAttr(AttrBuilder &B) {
+  CaptureComponents CC = CaptureComponents::None;
+  bool ReturnOnly = false;
+
+  // We use syntax like captures(ret: address, provenance), so the colon
+  // should not be interpreted as a label terminator.
+  Lex.setIgnoreColonInIdentifiers(true);
+  auto _ = make_scope_exit([&] { Lex.setIgnoreColonInIdentifiers(false); });
+
+  Lex.Lex();
+  if (parseToken(lltok::lparen, "expected '('"))
+    return true;
+
+  if (EatIfPresent(lltok::kw_ret)) {
+    if (parseToken(lltok::colon, "expected ':'"))
+      return true;
+
+    ReturnOnly = true;
+  }
+
+  if (EatIfPresent(lltok::kw_none)) {
+    if (parseToken(lltok::rparen, "expected ')'"))
+      return true;
+  } else {
+    while (true) {
+      if (EatIfPresent(lltok::kw_address))
+        CC |= CaptureComponents::Address;
+      else if (EatIfPresent(lltok::kw_provenance))
+        CC |= CaptureComponents::Provenance;
+      else if (EatIfPresent(lltok::kw_read_provenance))
+        CC |= CaptureComponents::ReadProvenance;
+      else
+        return tokError(
+            "expected one of 'address', 'provenance' or 'read_provenance'");
+
+      if (EatIfPresent(lltok::rparen))
+        break;
+
+      if (parseToken(lltok::comma, "expected ',' or ')'"))
+        return true;
+    }
+  }
+
+  B.addCapturesAttr(CaptureInfo(CC, ReturnOnly));
+  return false;
+}
+
 /// parseOptionalOperandBundles
 ///    ::= /*empty*/
 ///    ::= '[' OperandBundle [, OperandBundle ]* ']'
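The grammar accepted by `parseCapturesAttr` can be sketched as a standalone parser over a plain string: `'captures' '(' ['ret' ':'] ('none' ')' | component (',' component)* ')')`. `tokenize` and `parseCaptures` are hypothetical helpers, not `LLParser` or `LLLexer` APIs. The sketch tokenizes `:` as punctuation, so it needs no counterpart to `setIgnoreColonInIdentifiers`.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

constexpr uint8_t kNone = 0;
constexpr uint8_t kAddress = 1 << 0;
constexpr uint8_t kReadProvenance = 1 << 1;
constexpr uint8_t kProvenance = (1 << 2) | kReadProvenance;

struct ParsedCaptures {
  uint8_t Components;
  bool ReturnOnly;
};

// Splits the input into identifiers ([A-Za-z_]+) and single-character
// punctuation tokens, skipping whitespace.
std::vector<std::string> tokenize(const std::string &S) {
  std::vector<std::string> Toks;
  std::size_t I = 0;
  auto IsIdent = [&S](std::size_t K) {
    return std::isalpha(static_cast<unsigned char>(S[K])) || S[K] == '_';
  };
  while (I < S.size()) {
    if (std::isspace(static_cast<unsigned char>(S[I]))) {
      ++I;
    } else if (IsIdent(I)) {
      std::size_t Start = I;
      while (I < S.size() && IsIdent(I))
        ++I;
      Toks.push_back(S.substr(Start, I - Start));
    } else {
      Toks.push_back(std::string(1, S[I++]));
    }
  }
  return Toks;
}

// Returns std::nullopt on any syntax error.
std::optional<ParsedCaptures> parseCaptures(const std::string &Src) {
  std::vector<std::string> T = tokenize(Src);
  std::size_t I = 0;
  auto Eat = [&](const char *Tok) {
    if (I < T.size() && T[I] == Tok) {
      ++I;
      return true;
    }
    return false;
  };
  if (!Eat("captures") || !Eat("("))
    return std::nullopt;
  ParsedCaptures R{kNone, false};
  if (Eat("ret")) {
    if (!Eat(":"))
      return std::nullopt;  // Like "expected ':'".
    R.ReturnOnly = true;
  }
  if (Eat("none")) {
    if (!Eat(")"))
      return std::nullopt;  // Like "expected ')'".
  } else {
    while (true) {
      if (Eat("address"))
        R.Components |= kAddress;
      else if (Eat("provenance"))
        R.Components |= kProvenance;
      else if (Eat("read_provenance"))
        R.Components |= kReadProvenance;
      else
        return std::nullopt;  // Unknown component.
      if (Eat(")"))
        break;
      if (!Eat(","))
        return std::nullopt;  // Like "expected ',' or ')'".
    }
  }
  if (I != T.size())
    return std::nullopt;  // Trailing tokens.
  // "ret: none" collapses to "none", as the CaptureInfo constructor does.
  if (R.Components == kNone)
    R.ReturnOnly = false;
  return R;
}
```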
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index 3e6abacac27261..e74422cfe90fb5 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -2215,6 +2215,8 @@ static Attribute::AttrKind getAttrFromCode(uint64_t Code) {
     return Attribute::CoroElideSafe;
   case bitc::ATTR_KIND_NO_EXT:
     return Attribute::NoExt;
+  case bitc::ATTR_KIND_CAPTURES:
+    return Attribute::Captures;
   }
 }
 
@@ -2354,6 +2356,8 @@ Error BitcodeReader::parseAttributeGroupBlock() {
             B.addAllocKindAttr(static_cast<AllocFnKind>(Record[++i]));
           else if (Kind == Attribute::Memory)
             B.addMemoryAttr(MemoryEffects::createFromIntValue(Record[++i]));
+          else if (Kind == Attribute::Captures)
+            B.addCapturesAttr(CaptureInfo::createFromIntValue(Record[++i]));
           else if (Kind == Attribute::NoFPClass)
             B.addNoFPClassAttr(
                 static_cast<FPClassTest>(Record[++i] & fcAllFlags));
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index 24a4c2e8303d5a..83a5f753ee5685 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -902,6 +902,8 @@ static uint64_t getAttrKindEncoding(Attribute::AttrKind Kind) {
     return bitc::ATTR_KIND_INITIALIZES;
   case Attribute::NoExt:
     return bitc::ATTR_KIND_NO_EXT;
+  case Attribute::Captures:
+    return bitc::ATTR_KIND_CAPTURES;
   case Attribute::EndAttrKinds:
     llvm_unreachable("Can not encode end-attribute kinds marker.");
   case Attribute::None:
diff --git a/llvm/lib/IR/AttributeImpl.h b/llvm/lib/IR/AttributeImpl.h
index 82c501dcafcb7f..59cc489ade40de 100644
--- a/llvm/lib/IR/AttributeImpl.h
+++ b/llvm/lib/IR/AttributeImpl.h
@@ -346,6 +346,7 @@ class AttributeSetNode final
   UWTableKind getUWTableKind() const;
   AllocFnKind getAllocKind() const;
   MemoryEffects getMemoryEffects() const;
+  CaptureInfo getCaptureInfo() const;
   FPClassTest getNoFPClass() const;
   std::string getAsString(bool InAttrGrp) const;
   Type *getAttributeType(Attribute::AttrKind Kind) const;
diff --git a/llvm/lib/IR/Attributes.cpp b/llvm/lib/IR/Attributes.cpp
index e9daa01b899e8f..052998698321a6 100644
--- a/llvm/lib/IR/Attributes.cpp
+++ b/llvm/lib/IR/Attributes.cpp
@@ -487,6 +487,12 @@ MemoryEffects Attribute::getMemoryEffects() const {
   return MemoryEffects::createFromIntValue(pImpl->getValueAsInt());
 }
 
+CaptureInfo Attribute::getCaptureInfo() const {
+  assert(hasAttribute(Attribute::Captures) &&
+         "Can only call getCaptureInfo() on captures attribute");
+  return CaptureInfo::createFromIntValue(pImpl->getValueAsInt());
+}
+
 FPClassTest Attribute::getNoFPClass() const {
   assert(hasAttribute(Attribute::NoFPClass) &&
          "Can only call getNoFPClass() on nofpclass attribute");
@@ -647,6 +653,13 @@ std::string Attribute::getAsString(bool InAttrGrp) const {
     return Result;
   }
 
+  if (hasAttribute(Attribute::Captures)) {
+    std::string Result;
+    raw_string_ostream OS(Result);
+    OS << getCaptureInfo();
+    return Result;
+  }
+
   if (hasAttribute(Attribute::NoFPClass)) {
     std::string Result = "nofpclass";
     raw_string_ostream OS(Result);
@@ -1050,6 +1063,10 @@ AttributeSet::intersectWith(LLVMContext &C, AttributeSet Other) const {
         Intersected.addMemoryAttr(Attr0.getMemoryEffects() |
                                   Attr1.getMemoryEffects());
         break;
+      case Attribute::Captures:
+        Intersected.addCapturesAttr(Attr0.getCaptureInfo() |
+                                    Attr1.getCaptureInfo());
+        break;
       case Attribute::NoFPClass:
         Intersected.addNoFPClassAttr(Attr0.getNoFPClass() &
                                      Attr1.getNoFPClass());
@@ -1170,6 +1187,10 @@ MemoryEffects AttributeSet::getMemoryEffects() const {
   return SetNode ? SetNode->getMemoryEffects() : MemoryEffects::unknown();
 }
 
+CaptureInfo AttributeSet::getCaptureInfo() const {
+  return SetNode ? SetNode->getCaptureInfo() : CaptureInfo::all();
+}
+
 FPClassTest AttributeSet::getNoFPClass() const {
   return SetNode ? SetNode->getNoFPClass() : fcNone;
 }
@@ -1358,6 +1379,12 @@ MemoryEffects AttributeSetNode::getMemoryEffects() const {
   return MemoryEffects::unknown();
 }
 
+CaptureInfo AttributeSetNode::getCaptureInfo() const {
+  if (auto A = findEnumAttribute(Attribute::Captures))
+    return A->getCaptureInfo();
+  return CaptureInfo::all();
+}
+
 FPClassTest AttributeSetNode::getNoFPClass() const {
   if (auto A = findEnumAttribute(Attribute::NoFPClass))
     return A->getNoFPClass();
@@ -2190,6 +2217,10 @@ AttrBuilder &AttrBuilder::addMemoryAttr(MemoryEffects ME) {
   return addRawIntAttr(Attribute::Memory, ME.toIntValue());
 }
 
+AttrBuilder &AttrBuilder::addCapturesAttr(CaptureInfo CI) {
+  return addRawIntAttr(Attribute::Captures, CI.toIntValue());
+}
+
 AttrBuilder &AttrBuilder::addNoFPClassAttr(FPClassTest Mask) {
   if (Mask == fcNone)
     return *this;
diff --git a/llvm/lib/Support/ModRef.cpp b/llvm/lib/Support/ModRef.cpp
index a4eb70edd38d10..da7b060ddda55a 100644
--- a/llvm/lib/Support/ModRef.cpp
+++ b/llvm/lib/Support/ModRef.cpp
@@ -12,6 +12,7 @@
 
 #include "llvm/Support/ModRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringExtras.h"
 
 using namespace llvm;
 
@@ -50,3 +51,28 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, MemoryEffects ME) {
   });
   return OS;
 }
+
+raw_ostream &llvm::operator<<(raw_ostream &OS, CaptureComponents CC) {
+  if (capturesNothing(CC)) {
+    OS << "none";
+    return OS;
+  }
+
+  ListSeparator LS;
+  if (capturesAddress(CC))
+    OS << LS << "address";
+  if (capturesReadProvenanceOnly(CC))
+    OS << LS << "read_provenance";
+  if (capturesFullProvenance(CC))
+    OS << LS << "provenance";
+
+  return OS;
+}
+
+raw_ostream &llvm::operator<<(raw_ostream &OS, CaptureInfo CI) {
+  OS << "captures(";
+  if (CI.isReturnOnly())
+    OS << "ret: ";
+  OS << CaptureComponents(CI) << ")";
+  return OS;
+}
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index 6539f924c2edf4..0b9b8ad9fe6f52 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -956,6 +956,7 @@ Function *CodeExtractor::constructFunctionDeclaration(
       case Attribute::AllocatedPointer:
       case Attribute::AllocAlign:
       case Attribute::ByVal:
+      case Attribute::Captures:
       case Attribute::Dereferenceable:
       case Attribute::DereferenceableOrNull:
       case Attribute::ElementType:
diff --git a/llvm/test/Assembler/captures-errors.ll b/llvm/test/Assembler/captures-errors.ll
new file mode 100644
index 00000000000000..d94f4b33af0fd0
--- /dev/null
+++ b/llvm/test/Assembler/captures-errors.ll
@@ -0,0 +1,41 @@
+; RUN: split-file --leading-lines %s %t
+; RUN: not llvm-as < %t/missing-lparen.ll 2>&1 | FileCheck %s --check-prefix=CHECK-MISSING-LPAREN
+; RUN: not llvm-as < %t/missing-rparen.ll 2>&1 | FileCheck %s --check-prefix=CHECK-MISSING-RPAREN
+; RUN: not llvm-as < %t/missing-rparen-none.ll 2>&1 | FileCheck %s --check-prefix=CHECK-MISSING-RPAREN-NONE
+; RUN: not llvm-as < %t/missing-colon.ll 2>&1 | FileCheck %s --check-prefix=CHECK-MISSING-COLON
+; RUN: not llvm-as < %t/invalid-component.ll 2>&1 | FileCheck %s --check-prefix=CHECK-INVALID-COMPONENT
+
+;--- missing-lparen.ll
+
+; CHECK-MISSING-LPAREN: <stdin>:[[@LINE+1]]:32: error: expected '('
+define void @test(ptr captures %p) {
+  ret void
+}
+
+;--- missing-rparen.ll
+
+; CHECK-MISSING-RPAREN: <stdin>:[[@LINE+1]]:40: error: expected ',' or ')'
+define void @test(ptr captures(address %p) {
+  ret void
+}
+
+;--- missing-rparen-none.ll
+
+; CHECK-MISSING-RPAREN-NONE: <stdin>:[[@LINE+1]]:37: error: expected ')'
+define void @test(ptr captures(none %p) {
+  ret void
+}
+
+;--- missing-colon.ll
+
+; CHECK-MISSING-COLON: <stdin>:[[@LINE+1]]:36: error: expected ':'
+define void @test(ptr captures(ret address) %p) {
+  ret void
+}
+
+;--- invalid-component.ll
+
+; CHECK-INVALID-COMPONENT: <stdin>:[[@LINE+1]]:32: error: expected one of 'address', 'provenance' or 'read_provenance'
+define void @test(ptr captures(foo) %p) {
+  ret void
+}
diff --git a/llvm/test/Assembler/captures.ll b/llvm/test/Assembler/captures.ll
new file mode 100644
index 00000000000000..3f2e7eec7ca626
--- /dev/null
+++ b/llvm/test/Assembler/captures.ll
@@ -0,0 +1,70 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S < %s | FileCheck %s
+; RUN: llvm-as < %s | llvm-dis | FileCheck %s
+
+define void @test_none(ptr captures(none) %p) {
+; CHECK-LABEL: define void @test_none(
+; CHECK-SAME: ptr captures(none) [[P:%.*]]) {
+; CHECK-NEXT:    ret void
+;
+  ret void
+}
+
+define void @test_address(ptr captures(address) %p) {
+; CHECK-LABEL: define void @test_address(
+; CHECK-SAME: ptr captures(address) [[P:%.*]]) {
+; CHECK-NEXT:    ret void
+;
+  ret void
+}
+
+define void @test_address_provenance(ptr captures(address, provenance) %p) {
+; CHECK-LABEL: define void @test_address_provenance(
+; CHECK-SAME: ptr captures(address, provenance) [[P:%.*]]) {
+; CHECK-NEXT:    ret void
+;
+  ret void
+}
+
+define void @test_address_read_provenance(ptr captures(address, read_provenance) %p) {
+; CHECK-LABEL: define void @test_address_read_provenance(
+; CHECK-SAME: ptr captures(address, read_provenance) [[P:%.*]]) {
+; CHECK-NEXT:    ret void
+;
+  ret void
+}
+
+define void @test_ret(ptr captures(ret: address, provenance) %p) {
+; CHECK-LABEL: define void @test_ret(
+; CHECK-SAME: ptr captures(ret: address, provenance) [[P:%.*]]) {
+; CHECK-NEXT:    ret void
+;
+  ret void
+}
+
+; Duplicates collapse into one.
+define void @test_duplicate(ptr captures(address, address) %p) {
+; CHECK-LABEL: define void @test_duplicate(
+; CHECK-SAME: ptr captures(address) [[P:%.*]]) {
+; CHECK-NEXT:    ret void
+;
+  ret void
+}
+
+; read_provenance is a subset of provenance.
+define void @test_duplicate_read_provenance(ptr captures(read_provenance, provenance) %p) {
+; CHECK-LABEL: define void @test_duplicate_read_provenance(
+; CHECK-SAME: ptr captures(provenance) [[P:%.*]]) {
+; CHECK-NEXT:    ret void
+;
+  ret void
+}
+
+; Return-only none is same as plain none.
+define void @test_ret_none(ptr captures(ret: none) %p) {
+; CHECK-LABEL: define void @test_ret_none(
+; CHECK-SAME: ptr captures(none) [[P:%.*]]) {
+; CHECK-NEXT:    ret void
+;
+  ret void
+}
diff --git a/llvm/test/Bitcode/attributes.ll b/llvm/test/Bitcode/attributes.ll
index 492de663884df4..1da9291c719964 100644
--- a/llvm/test/Bitcode/attributes.ll
+++ b/llvm/test/Bitcode/attributes.ll
@@ -562,6 +562,11 @@ define void @initializes(ptr initializes((-4, 0), (4, 8)) %a) {
   ret void
 }
 
+; CHECK: define void @captures(ptr captures(address) %p)
+define void @captures(ptr captures(address) %p) {
+  ret void
+}
+
 ; CHECK: attributes #0 = { noreturn }
 ; CHECK: attributes #1 = { nounwind }
 ; CHECK: attributes #2 = { memory(none) }
diff --git a/llvm/unittests/IR/AttributesTest.cpp b/llvm/unittests/IR/AttributesTest.cpp
index f73f2b20e9fea5..8b5800e6cf0dd6 100644
--- a/llvm/unittests/IR/AttributesTest.cpp
+++ b/llvm/unittests/IR/AttributesTest.cpp
@@ -437,6 +437,12 @@ TEST(Attributes, SetIntersect) {
         break;
       case Attribute::Range:
         break;
+      case Attribute::Captures:
+        V0 = CaptureInfo(CaptureComponents::Address, /*ReturnOnly=*/false)
+                 .toIntValue();
+        V1 = CaptureInfo(CaptureComponents::ReadProvenance, /*ReturnOnly=*/true)
+                 .toIntValue();
+        break;
       default:
         ASSERT_FALSE(true);
       }
@@ -516,6 +522,12 @@ TEST(Attributes, SetIntersect) {
         ASSERT_EQ(Res->getAttribute(Kind).getRange(),
                   ConstantRange(APInt(32, 0), APInt(32, 20)));
         break;
+      case Attribute::Captures:
+        ASSERT_EQ(Res->getCaptureInfo(),
+                  CaptureInfo(CaptureComponents::Address |
+                                  CaptureComponents::ReadProvenance,
+                              /*ReturnOnly=*/false));
+        break;
       default:
         ASSERT_FALSE(true);
       }