[clang] [llvm] [Clang][CodeGen] Implement code generation for __builtin_infer_alloc_token() (PR #156842)

Marco Elver via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 28 05:56:49 PDT 2025


https://github.com/melver updated https://github.com/llvm/llvm-project/pull/156842

>From 019499fd0f441ee14e71392f6d1b1219734163a3 Mon Sep 17 00:00:00 2001
From: Marco Elver <elver at google.com>
Date: Wed, 15 Oct 2025 15:13:27 +0200
Subject: [PATCH 1/2] force rebase

---
 clang/docs/AllocToken.rst                     | 43 ++++++++---
 clang/docs/ReleaseNotes.rst                   |  3 +
 clang/lib/CodeGen/BackendUtil.cpp             | 28 ++++---
 clang/lib/CodeGen/CGBuiltin.cpp               |  9 +++
 clang/test/CodeGen/lto-newpm-pipeline.c       |  8 +-
 clang/test/CodeGenCXX/alloc-token-builtin.cpp | 77 +++++++++++++++++++
 6 files changed, 146 insertions(+), 22 deletions(-)
 create mode 100644 clang/test/CodeGenCXX/alloc-token-builtin.cpp

diff --git a/clang/docs/AllocToken.rst b/clang/docs/AllocToken.rst
index b65e18ccfa967..1a740e5e22c29 100644
--- a/clang/docs/AllocToken.rst
+++ b/clang/docs/AllocToken.rst
@@ -49,6 +49,39 @@ change or removal. These may (experimentally) be selected with ``-Xclang
 * ``increment``: This mode assigns a simple, incrementally increasing token ID
   to each allocation site.
 
+The following command-line options affect generated token IDs:
+
+* ``-falloc-token-max=<N>``
+    Configures the maximum number of tokens. No max by default (tokens bounded
+    by ``SIZE_MAX``).
+
+Querying Token IDs with ``__builtin_infer_alloc_token``
+=======================================================
+
+For use cases where the token ID must be known at compile time, Clang provides
+a builtin function:
+
+.. code-block:: c
+
+    size_t __builtin_infer_alloc_token(<args>, ...);
+
+This builtin returns the token ID inferred from its argument expressions, which
+mirror the arguments normally passed to an allocation function. The argument
+expressions are **unevaluated**, so expressions with side effects may be passed
+without any runtime impact.
+
+For example, it can be used as follows:
+
+.. code-block:: c
+
+    struct MyType { ... };
+    void *__partition_alloc(size_t size, size_t partition);
+    #define partition_alloc(...) __partition_alloc(__VA_ARGS__, __builtin_infer_alloc_token(__VA_ARGS__))
+
+    void foo(void) {
+        struct MyType *x = partition_alloc(sizeof(*x));
+    }
+
 Allocation Token Instrumentation
 ================================
 
@@ -70,16 +103,6 @@ example:
     // Instrumented:
     ptr = __alloc_token_malloc(size, <token id>);
 
-The following command-line options affect generated token IDs:
-
-* ``-falloc-token-max=<N>``
-    Configures the maximum number of tokens. No max by default (tokens bounded
-    by ``SIZE_MAX``).
-
-    .. code-block:: console
-
-        % clang++ -fsanitize=alloc-token -falloc-token-max=512 example.cc
-
 Runtime Interface
 -----------------
 
diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
index db2b0f6fd5027..5bd3a071ee33e 100644
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -268,6 +268,9 @@ Non-comprehensive list of changes in this release
   allocator-level heap organization strategies. A feature to instrument all
   allocation functions with a token ID can be enabled via the
   ``-fsanitize=alloc-token`` flag.
+- A builtin ``__builtin_infer_alloc_token(<args>, ...)`` is provided to allow
+  compile-time querying of allocation token IDs, where the builtin arguments
+  mirror those normally passed to an allocation function.
 
 New Compiler Flags
 ------------------
diff --git a/clang/lib/CodeGen/BackendUtil.cpp b/clang/lib/CodeGen/BackendUtil.cpp
index 23ad11ac9f792..2be61f2479e44 100644
--- a/clang/lib/CodeGen/BackendUtil.cpp
+++ b/clang/lib/CodeGen/BackendUtil.cpp
@@ -803,16 +803,6 @@ static void addSanitizers(const Triple &TargetTriple,
       MPM.addPass(DataFlowSanitizerPass(LangOpts.NoSanitizeFiles,
                                         PB.getVirtualFileSystemPtr()));
     }
-
-    if (LangOpts.Sanitize.has(SanitizerKind::AllocToken)) {
-      if (Level == OptimizationLevel::O0) {
-        // The default pass builder only infers libcall function attrs when
-        // optimizing, so we insert it here because we need it for accurate
-        // memory allocation function detection.
-        MPM.addPass(InferFunctionAttrsPass());
-      }
-      MPM.addPass(AllocTokenPass(getAllocTokenOptions(LangOpts, CodeGenOpts)));
-    }
   };
   if (ClSanitizeOnOptimizerEarlyEP) {
     PB.registerOptimizerEarlyEPCallback(
@@ -855,6 +845,23 @@ static void addSanitizers(const Triple &TargetTriple,
   }
 }
 
+static void addAllocTokenPass(const Triple &TargetTriple,
+                              const CodeGenOptions &CodeGenOpts,
+                              const LangOptions &LangOpts, PassBuilder &PB) {
+  PB.registerOptimizerLastEPCallback([&](ModulePassManager &MPM,
+                                         OptimizationLevel Level,
+                                         ThinOrFullLTOPhase) {
+    if (Level == OptimizationLevel::O0 &&
+        LangOpts.Sanitize.has(SanitizerKind::AllocToken)) {
+      // The default pass builder only infers libcall function attrs when
+      // optimizing, so we insert it here because we need it for accurate
+      // memory allocation function detection with -fsanitize=alloc-token.
+      MPM.addPass(InferFunctionAttrsPass());
+    }
+    MPM.addPass(AllocTokenPass(getAllocTokenOptions(LangOpts, CodeGenOpts)));
+  });
+}
+
 void EmitAssemblyHelper::RunOptimizationPipeline(
     BackendAction Action, std::unique_ptr<raw_pwrite_stream> &OS,
     std::unique_ptr<llvm::ToolOutputFile> &ThinLinkOS, BackendConsumer *BC) {
@@ -1109,6 +1116,7 @@ void EmitAssemblyHelper::RunOptimizationPipeline(
     if (!IsThinLTOPostLink) {
       addSanitizers(TargetTriple, CodeGenOpts, LangOpts, PB);
       addKCFIPass(TargetTriple, LangOpts, PB);
+      addAllocTokenPass(TargetTriple, CodeGenOpts, LangOpts, PB);
     }
 
     if (std::optional<GCOVOptions> Options =
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 9ee810c9d5775..8a85789a256d5 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -4525,6 +4525,15 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
     return RValue::get(AI);
   }
 
+  case Builtin::BI__builtin_infer_alloc_token: {
+    llvm::MDNode *MDN = buildAllocToken(E);
+    llvm::Value *MDV = MetadataAsValue::get(getLLVMContext(), MDN);
+    llvm::Function *F =
+        CGM.getIntrinsic(llvm::Intrinsic::alloc_token_id, {IntPtrTy});
+    llvm::CallBase *TokenID = Builder.CreateCall(F, MDV);
+    return RValue::get(TokenID);
+  }
+
   case Builtin::BIbzero:
   case Builtin::BI__builtin_bzero: {
     Address Dest = EmitPointerWithAlignment(E->getArg(0));
diff --git a/clang/test/CodeGen/lto-newpm-pipeline.c b/clang/test/CodeGen/lto-newpm-pipeline.c
index ea9784a76f923..dceaaf136ebfc 100644
--- a/clang/test/CodeGen/lto-newpm-pipeline.c
+++ b/clang/test/CodeGen/lto-newpm-pipeline.c
@@ -32,10 +32,12 @@
 // CHECK-FULL-O0-NEXT: Running pass: AlwaysInlinerPass
 // CHECK-FULL-O0-NEXT: Running analysis: ProfileSummaryAnalysis
 // CHECK-FULL-O0-NEXT: Running pass: CoroConditionalWrapper
+// CHECK-FULL-O0-NEXT: Running pass: AllocTokenPass
+// CHECK-FULL-O0-NEXT: Running analysis: OptimizationRemarkEmitterAnalysis
+// CHECK-FULL-O0-NEXT: Running analysis: TargetLibraryAnalysis
 // CHECK-FULL-O0-NEXT: Running pass: CanonicalizeAliasesPass
 // CHECK-FULL-O0-NEXT: Running pass: NameAnonGlobalPass
 // CHECK-FULL-O0-NEXT: Running pass: AnnotationRemarksPass
-// CHECK-FULL-O0-NEXT: Running analysis: TargetLibraryAnalysis
 // CHECK-FULL-O0-NEXT: Running pass: VerifierPass
 // CHECK-FULL-O0-NEXT: Running pass: BitcodeWriterPass
 
@@ -46,10 +48,12 @@
 // CHECK-THIN-O0-NEXT: Running pass: AlwaysInlinerPass
 // CHECK-THIN-O0-NEXT: Running analysis: ProfileSummaryAnalysis
 // CHECK-THIN-O0-NEXT: Running pass: CoroConditionalWrapper
+// CHECK-THIN-O0-NEXT: Running pass: AllocTokenPass
+// CHECK-THIN-O0-NEXT: Running analysis: OptimizationRemarkEmitterAnalysis
+// CHECK-THIN-O0-NEXT: Running analysis: TargetLibraryAnalysis
 // CHECK-THIN-O0-NEXT: Running pass: CanonicalizeAliasesPass
 // CHECK-THIN-O0-NEXT: Running pass: NameAnonGlobalPass
 // CHECK-THIN-O0-NEXT: Running pass: AnnotationRemarksPass
-// CHECK-THIN-O0-NEXT: Running analysis: TargetLibraryAnalysis
 // CHECK-THIN-O0-NEXT: Running pass: VerifierPass
 // CHECK-THIN-O0-NEXT: Running pass: ThinLTOBitcodeWriterPass
 
diff --git a/clang/test/CodeGenCXX/alloc-token-builtin.cpp b/clang/test/CodeGenCXX/alloc-token-builtin.cpp
new file mode 100644
index 0000000000000..2d0b8e1666faf
--- /dev/null
+++ b/clang/test/CodeGenCXX/alloc-token-builtin.cpp
@@ -0,0 +1,77 @@
+// To test IR generation of the builtin without evaluating the LLVM intrinsic,
+// we set the mode to a stateful mode, which prohibits constant evaluation.
+// RUN: %clang_cc1 -triple x86_64-linux-gnu -Werror -std=c++20 -emit-llvm -falloc-token-mode=random -disable-llvm-passes %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-CODEGEN
+// RUN: %clang_cc1 -triple x86_64-linux-gnu -Werror -std=c++20 -emit-llvm -falloc-token-max=2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-LOWER
+
+extern "C" void *my_malloc(unsigned long, unsigned long);
+
+struct NoPtr {
+  int x;
+  long y;
+};
+
+struct WithPtr {
+  int a;
+  char *buf;
+};
+
+int unevaluated_fn();
+
+// CHECK-LABEL: @_Z16test_builtin_intv(
+// CHECK-CODEGEN: call i64 @llvm.alloc.token.id.i64(metadata ![[META_INT:[0-9]+]])
+// CHECK-LOWER: ret i64 0
+unsigned long test_builtin_int() {
+  return __builtin_infer_alloc_token(sizeof(1));
+}
+
+// CHECK-LABEL: @_Z16test_builtin_ptrv(
+// CHECK-CODEGEN: call i64 @llvm.alloc.token.id.i64(metadata ![[META_PTR:[0-9]+]])
+// CHECK-LOWER: ret i64 1
+unsigned long test_builtin_ptr() {
+  return __builtin_infer_alloc_token(sizeof(int *));
+}
+
+// CHECK-LABEL: @_Z25test_builtin_struct_noptrv(
+// CHECK-CODEGEN: call i64 @llvm.alloc.token.id.i64(metadata ![[META_NOPTR:[0-9]+]])
+// CHECK-LOWER: ret i64 0
+unsigned long test_builtin_struct_noptr() {
+  return __builtin_infer_alloc_token(sizeof(NoPtr));
+}
+
+// CHECK-LABEL: @_Z25test_builtin_struct_w_ptrv(
+// CHECK-CODEGEN: call i64 @llvm.alloc.token.id.i64(metadata ![[META_WITHPTR:[0-9]+]])
+// CHECK-LOWER: ret i64 1
+unsigned long test_builtin_struct_w_ptr() {
+  return __builtin_infer_alloc_token(sizeof(WithPtr), 123);
+}
+
+// CHECK-LABEL: @_Z24test_builtin_unevaluatedv(
+// CHECK-NOT: call{{.*}}unevaluated_fn
+// CHECK-CODEGEN: call i64 @llvm.alloc.token.id.i64(metadata ![[META_INT]])
+// CHECK-LOWER: ret i64 0
+unsigned long test_builtin_unevaluated() {
+  return __builtin_infer_alloc_token(sizeof(int) * unevaluated_fn());
+}
+
+// CHECK-LABEL: @_Z36test_builtin_unsequenced_unevaluatedi(
+// CHECK:     add nsw
+// CHECK-NOT: add nsw
+// CHECK-CODEGEN: %[[REG:[0-9]+]] = call i64 @llvm.alloc.token.id.i64(metadata ![[META_UNKNOWN:[0-9]+]])
+// CHECK-CODEGEN: call{{.*}}@my_malloc({{.*}}, i64 noundef %[[REG]])
+// CHECK-LOWER: call{{.*}}@my_malloc({{.*}}, i64 noundef 0)
+void test_builtin_unsequenced_unevaluated(int x) {
+  my_malloc(++x, __builtin_infer_alloc_token(++x));
+}
+
+// CHECK-LABEL: @_Z20test_builtin_unknownv(
+// CHECK-CODEGEN: call i64 @llvm.alloc.token.id.i64(metadata ![[META_UNKNOWN]])
+// CHECK-LOWER: ret i64 0
+unsigned long test_builtin_unknown() {
+  return __builtin_infer_alloc_token(4096);
+}
+
+// CHECK-CODEGEN: ![[META_INT]] = !{!"int", i1 false}
+// CHECK-CODEGEN: ![[META_PTR]] = !{!"int *", i1 true}
+// CHECK-CODEGEN: ![[META_NOPTR]] = !{!"NoPtr", i1 false}
+// CHECK-CODEGEN: ![[META_WITHPTR]] = !{!"WithPtr", i1 true}
+// CHECK-CODEGEN: ![[META_UNKNOWN]] = !{}

>From 3c1e721f2bc566b97f29157616ac81354cef7a3e Mon Sep 17 00:00:00 2001
From: Marco Elver <elver at google.com>
Date: Fri, 17 Oct 2025 20:14:51 +0200
Subject: [PATCH 2/2] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20ch?=
 =?UTF-8?q?anges=20introduced=20through=20rebase?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Created using spr 1.3.8-beta.1

[skip ci]
---
 .../include/clang/Basic/DiagnosticASTKinds.td |  6 ++++
 clang/lib/AST/ByteCode/InterpBuiltin.cpp      | 27 +++++++-------
 clang/lib/AST/ExprConstant.cpp                |  9 ++---
 clang/lib/AST/InferAlloc.cpp                  | 35 +++++++++----------
 clang/lib/Frontend/CompilerInvocation.cpp     |  9 +----
 clang/test/SemaCXX/alloc-token.cpp            | 13 +++----
 llvm/include/llvm/Support/AllocToken.h        | 12 +++++--
 llvm/lib/Passes/PassBuilder.cpp               |  9 +----
 llvm/lib/Support/AllocToken.cpp               | 35 +++++++++++++------
 .../Transforms/Instrumentation/AllocToken.cpp | 12 +++----
 10 files changed, 89 insertions(+), 78 deletions(-)

diff --git a/clang/include/clang/Basic/DiagnosticASTKinds.td b/clang/include/clang/Basic/DiagnosticASTKinds.td
index 0be9146f70364..5c462f9646b3b 100644
--- a/clang/include/clang/Basic/DiagnosticASTKinds.td
+++ b/clang/include/clang/Basic/DiagnosticASTKinds.td
@@ -403,6 +403,12 @@ def note_constexpr_assumption_failed : Note<
 def note_constexpr_countzeroes_zero : Note<
   "evaluation of %select{__builtin_elementwise_clzg|__builtin_elementwise_ctzg}0 "
   "with a zero value is undefined">;
+def note_constexpr_infer_alloc_token_type_inference_failed : Note<
+  "could not infer allocation type for __builtin_infer_alloc_token">;
+def note_constexpr_infer_alloc_token_no_metadata : Note<
+  "could not get token metadata for inferred type">;
+def note_constexpr_infer_alloc_token_stateful_mode : Note<
+  "stateful alloc token mode not supported in constexpr">;
 def err_experimental_clang_interp_failed : Error<
   "the experimental clang interpreter failed to evaluate an expression">;
 
diff --git a/clang/lib/AST/ByteCode/InterpBuiltin.cpp b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
index 7b3670ef46f0e..ca91a58bd4d17 100644
--- a/clang/lib/AST/ByteCode/InterpBuiltin.cpp
+++ b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
@@ -1311,12 +1311,12 @@ interp__builtin_ptrauth_string_discriminator(InterpState &S, CodePtr OpPC,
 static bool interp__builtin_infer_alloc_token(InterpState &S, CodePtr OpPC,
                                               const InterpFrame *Frame,
                                               const CallExpr *Call) {
-  const ASTContext &Ctx = S.getASTContext();
-  const uint64_t BitWidth = Ctx.getTypeSize(Ctx.getSizeType());
-  const auto Mode =
-      Ctx.getLangOpts().AllocTokenMode.value_or(llvm::DefaultAllocTokenMode);
-  const uint64_t MaxTokens =
-      Ctx.getLangOpts().AllocTokenMax.value_or(~0ULL >> (64 - BitWidth));
+  const ASTContext &ASTCtx = S.getASTContext();
+  uint64_t BitWidth = ASTCtx.getTypeSize(ASTCtx.getSizeType());
+  auto Mode =
+      ASTCtx.getLangOpts().AllocTokenMode.value_or(llvm::DefaultAllocTokenMode);
+  uint64_t MaxTokens =
+      ASTCtx.getLangOpts().AllocTokenMax.value_or(~0ULL >> (64 - BitWidth));
 
   // We do not read any of the arguments; discard them.
   for (int I = Call->getNumArgs() - 1; I >= 0; --I)
@@ -1324,25 +1324,26 @@ static bool interp__builtin_infer_alloc_token(InterpState &S, CodePtr OpPC,
 
   // Note: Type inference from a surrounding cast is not supported in
   // constexpr evaluation.
-  QualType AllocType = infer_alloc::inferPossibleType(Call, Ctx, nullptr);
+  QualType AllocType = infer_alloc::inferPossibleType(Call, ASTCtx, nullptr);
   if (AllocType.isNull()) {
-    S.CCEDiag(Call) << "could not infer allocation type";
+    S.CCEDiag(Call,
+              diag::note_constexpr_infer_alloc_token_type_inference_failed);
     return false;
   }
 
-  auto ATMD = infer_alloc::getAllocTokenMetadata(AllocType, Ctx);
+  auto ATMD = infer_alloc::getAllocTokenMetadata(AllocType, ASTCtx);
   if (!ATMD) {
-    S.CCEDiag(Call) << "could not get token metadata for type";
+    S.CCEDiag(Call, diag::note_constexpr_infer_alloc_token_no_metadata);
     return false;
   }
 
-  auto MaybeToken = llvm::getAllocTokenHash(Mode, *ATMD, MaxTokens);
+  auto MaybeToken = llvm::getAllocToken(Mode, *ATMD, MaxTokens);
   if (!MaybeToken) {
-    S.CCEDiag(Call) << "stateful alloc token mode not supported in constexpr";
+    S.CCEDiag(Call, diag::note_constexpr_infer_alloc_token_stateful_mode);
     return false;
   }
 
-  pushInteger(S, llvm::APInt(BitWidth, *MaybeToken), Ctx.getSizeType());
+  pushInteger(S, llvm::APInt(BitWidth, *MaybeToken), ASTCtx.getSizeType());
   return true;
 }
 
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 137fdb6f0c82b..6f9c06dd50a5a 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -14421,18 +14421,19 @@ bool IntExprEvaluator::VisitBuiltinCallExpr(const CallExpr *E,
     // can be checked with __builtin_constant_p(...).
     QualType AllocType = infer_alloc::inferPossibleType(E, Info.Ctx, nullptr);
     if (AllocType.isNull())
-      return Error(E);
+      return Error(
+          E, diag::note_constexpr_infer_alloc_token_type_inference_failed);
     auto ATMD = infer_alloc::getAllocTokenMetadata(AllocType, Info.Ctx);
     if (!ATMD)
-      return Error(E);
+      return Error(E, diag::note_constexpr_infer_alloc_token_no_metadata);
     auto Mode =
         Info.getLangOpts().AllocTokenMode.value_or(llvm::DefaultAllocTokenMode);
     uint64_t BitWidth = Info.Ctx.getTypeSize(Info.Ctx.getSizeType());
     uint64_t MaxTokens =
         Info.getLangOpts().AllocTokenMax.value_or(~0ULL >> (64 - BitWidth));
-    auto MaybeToken = llvm::getAllocTokenHash(Mode, *ATMD, MaxTokens);
+    auto MaybeToken = llvm::getAllocToken(Mode, *ATMD, MaxTokens);
     if (!MaybeToken)
-      return Error(E);
+      return Error(E, diag::note_constexpr_infer_alloc_token_stateful_mode);
     return Success(llvm::APInt(BitWidth, *MaybeToken), E);
   }
 
diff --git a/clang/lib/AST/InferAlloc.cpp b/clang/lib/AST/InferAlloc.cpp
index c21fcfccaef0f..3ec55c26ac366 100644
--- a/clang/lib/AST/InferAlloc.cpp
+++ b/clang/lib/AST/InferAlloc.cpp
@@ -19,11 +19,13 @@
 #include "clang/Basic/IdentifierTable.h"
 #include "llvm/ADT/SmallPtrSet.h"
 
-namespace clang {
-namespace {
-bool typeContainsPointer(QualType T,
-                         llvm::SmallPtrSet<const RecordDecl *, 4> &VisitedRD,
-                         bool &IncompleteType) {
+using namespace clang;
+using namespace infer_alloc;
+
+static bool
+typeContainsPointer(QualType T,
+                    llvm::SmallPtrSet<const RecordDecl *, 4> &VisitedRD,
+                    bool &IncompleteType) {
   QualType CanonicalType = T.getCanonicalType();
   if (CanonicalType->isPointerType())
     return true; // base case
@@ -70,7 +72,7 @@ bool typeContainsPointer(QualType T,
 }
 
 /// Infer type from a simple sizeof expression.
-QualType inferTypeFromSizeofExpr(const Expr *E) {
+static QualType inferTypeFromSizeofExpr(const Expr *E) {
   const Expr *Arg = E->IgnoreParenImpCasts();
   if (const auto *UET = dyn_cast<UnaryExprOrTypeTraitExpr>(Arg)) {
     if (UET->getKind() == UETT_SizeOf) {
@@ -96,7 +98,7 @@ QualType inferTypeFromSizeofExpr(const Expr *E) {
 ///
 ///   malloc(sizeof(HasFlexArray) + sizeof(int) * 32);  // infers 'HasFlexArray'
 ///
-QualType inferPossibleTypeFromArithSizeofExpr(const Expr *E) {
+static QualType inferPossibleTypeFromArithSizeofExpr(const Expr *E) {
   const Expr *Arg = E->IgnoreParenImpCasts();
   // The argument is a lone sizeof expression.
   if (QualType T = inferTypeFromSizeofExpr(Arg); !T.isNull())
@@ -132,7 +134,7 @@ QualType inferPossibleTypeFromArithSizeofExpr(const Expr *E) {
 ///   size_t my_size = sizeof(MyType);
 ///   void *x = malloc(my_size);  // infers 'MyType'
 ///
-QualType inferPossibleTypeFromVarInitSizeofExpr(const Expr *E) {
+static QualType inferPossibleTypeFromVarInitSizeofExpr(const Expr *E) {
   const Expr *Arg = E->IgnoreParenImpCasts();
   if (const auto *DRE = dyn_cast<DeclRefExpr>(Arg)) {
     if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
@@ -148,8 +150,8 @@ QualType inferPossibleTypeFromVarInitSizeofExpr(const Expr *E) {
 ///
 ///   MyType *x = (MyType *)malloc(4096);  // infers 'MyType'
 ///
-QualType inferPossibleTypeFromCastExpr(const CallExpr *CallE,
-                                       const CastExpr *CastE) {
+static QualType inferPossibleTypeFromCastExpr(const CallExpr *CallE,
+                                              const CastExpr *CastE) {
   if (!CastE)
     return QualType();
   QualType PtrType = CastE->getType();
@@ -157,12 +159,10 @@ QualType inferPossibleTypeFromCastExpr(const CallExpr *CallE,
     return PtrType->getPointeeType();
   return QualType();
 }
-} // anonymous namespace
-
-namespace infer_alloc {
 
-QualType inferPossibleType(const CallExpr *E, const ASTContext &Ctx,
-                           const CastExpr *CastE) {
+QualType clang::infer_alloc::inferPossibleType(const CallExpr *E,
+                                               const ASTContext &Ctx,
+                                               const CastExpr *CastE) {
   QualType AllocType;
   // First check arguments.
   for (const Expr *Arg : E->arguments()) {
@@ -179,7 +179,7 @@ QualType inferPossibleType(const CallExpr *E, const ASTContext &Ctx,
 }
 
 std::optional<llvm::AllocTokenMetadata>
-getAllocTokenMetadata(QualType T, const ASTContext &Ctx) {
+clang::infer_alloc::getAllocTokenMetadata(QualType T, const ASTContext &Ctx) {
   llvm::AllocTokenMetadata ATMD;
 
   // Get unique type name.
@@ -199,6 +199,3 @@ getAllocTokenMetadata(QualType T, const ASTContext &Ctx) {
 
   return ATMD;
 }
-
-} // namespace infer_alloc
-} // namespace clang
diff --git a/clang/lib/Frontend/CompilerInvocation.cpp b/clang/lib/Frontend/CompilerInvocation.cpp
index 9ce1df728336e..85ba85099500a 100644
--- a/clang/lib/Frontend/CompilerInvocation.cpp
+++ b/clang/lib/Frontend/CompilerInvocation.cpp
@@ -4565,14 +4565,7 @@ bool CompilerInvocation::ParseLangArgs(LangOptions &Opts, ArgList &Args,
 
   if (const auto *Arg = Args.getLastArg(options::OPT_falloc_token_mode_EQ)) {
     StringRef S = Arg->getValue();
-    auto Mode = llvm::StringSwitch<std::optional<llvm::AllocTokenMode>>(S)
-                    .Case("increment", llvm::AllocTokenMode::Increment)
-                    .Case("random", llvm::AllocTokenMode::Random)
-                    .Case("typehash", llvm::AllocTokenMode::TypeHash)
-                    .Case("typehashpointersplit",
-                          llvm::AllocTokenMode::TypeHashPointerSplit)
-                    .Default(std::nullopt);
-    if (Mode)
+    if (auto Mode = getAllocTokenModeFromString(S))
       Opts.AllocTokenMode = Mode;
     else
       Diags.Report(diag::err_drv_invalid_value) << Arg->getAsString(Args) << S;
diff --git a/clang/test/SemaCXX/alloc-token.cpp b/clang/test/SemaCXX/alloc-token.cpp
index 4956f517b708b..79c6f00856621 100644
--- a/clang/test/SemaCXX/alloc-token.cpp
+++ b/clang/test/SemaCXX/alloc-token.cpp
@@ -7,10 +7,6 @@
 #error "missing __builtin_infer_alloc_token"
 #endif
 
-#ifndef TOKEN_MAX
-#define TOKEN_MAX 0
-#endif
-
 struct NoPtr {
   int x;
   long y;
@@ -27,11 +23,15 @@ static_assert(__builtin_infer_alloc_token(sizeof(int)) == 2689373973731826898ULL
 static_assert(__builtin_infer_alloc_token(sizeof(char*)) == 2250492667400517147ULL);
 static_assert(__builtin_infer_alloc_token(sizeof(NoPtr)) == 7465259095297095368ULL);
 static_assert(__builtin_infer_alloc_token(sizeof(WithPtr)) == 11898882936532569145ULL);
-#elif TOKEN_MAX == 2
+#elif defined(TOKEN_MAX)
+#  if TOKEN_MAX == 2
 static_assert(__builtin_infer_alloc_token(sizeof(int)) == 0);
 static_assert(__builtin_infer_alloc_token(sizeof(char*)) == 1);
 static_assert(__builtin_infer_alloc_token(sizeof(NoPtr)) == 0);
 static_assert(__builtin_infer_alloc_token(sizeof(WithPtr)) == 1);
+#  else
+#    error "unhandled TOKEN_MAX case"
+#  endif
 #else
 static_assert(__builtin_infer_alloc_token(sizeof(int)) == 2689373973731826898ULL);
 static_assert(__builtin_infer_alloc_token(sizeof(char*)) == 11473864704255292954ULL);
@@ -53,5 +53,6 @@ static_assert(__builtin_infer_alloc_token(sizeof(NoPtr) << 8) == get_token<NoPtr
 void negative_tests() {
   __builtin_infer_alloc_token(); // expected-error {{too few arguments to function call}}
   __builtin_infer_alloc_token((void)0); // expected-error {{argument may not have 'void' type}}
-  constexpr auto inference_fail = __builtin_infer_alloc_token(123); // expected-error {{must be initialized by a constant expression}}
+  constexpr auto inference_fail = __builtin_infer_alloc_token(123); // expected-error {{must be initialized by a constant expression}} \
+                                                                    // expected-note {{could not infer allocation type for __builtin_infer_alloc_token}}
 }
diff --git a/llvm/include/llvm/Support/AllocToken.h b/llvm/include/llvm/Support/AllocToken.h
index 48db026957443..e40d8163a9d7c 100644
--- a/llvm/include/llvm/Support/AllocToken.h
+++ b/llvm/include/llvm/Support/AllocToken.h
@@ -14,6 +14,7 @@
 #define LLVM_SUPPORT_ALLOCTOKEN_H
 
 #include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/StringRef.h"
 #include <cstdint>
 #include <optional>
 
@@ -40,6 +41,11 @@ enum class AllocTokenMode {
 inline constexpr AllocTokenMode DefaultAllocTokenMode =
     AllocTokenMode::TypeHashPointerSplit;
 
+/// Returns the AllocTokenMode from its canonical string name; if an invalid
+/// name was provided returns nullopt.
+LLVM_ABI std::optional<AllocTokenMode>
+getAllocTokenModeFromString(StringRef Name);
+
 /// Metadata about an allocation used to generate a token ID.
 struct AllocTokenMetadata {
   SmallString<64> TypeName;
@@ -53,9 +59,9 @@ struct AllocTokenMetadata {
 /// \param Metadata The metadata about the allocation.
 /// \param MaxTokens The maximum number of tokens (must not be 0)
 /// \return The calculated allocation token ID, or std::nullopt.
-std::optional<uint64_t> getAllocTokenHash(AllocTokenMode Mode,
-                                          const AllocTokenMetadata &Metadata,
-                                          uint64_t MaxTokens);
+LLVM_ABI std::optional<uint64_t>
+getAllocToken(AllocTokenMode Mode, const AllocTokenMetadata &Metadata,
+              uint64_t MaxTokens);
 
 } // end namespace llvm
 
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index c3522a38eb9e1..4cebb0bb32e4e 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -1102,14 +1102,7 @@ Expected<AllocTokenOptions> parseAllocTokenPassOptions(StringRef Params) {
     std::tie(ParamName, Params) = Params.split(';');
 
     if (ParamName.consume_front("mode=")) {
-      auto Mode = StringSwitch<std::optional<AllocTokenMode>>(ParamName)
-                      .Case("increment", AllocTokenMode::Increment)
-                      .Case("random", AllocTokenMode::Random)
-                      .Case("typehash", AllocTokenMode::TypeHash)
-                      .Case("typehashpointersplit",
-                            AllocTokenMode::TypeHashPointerSplit)
-                      .Default(std::nullopt);
-      if (Mode)
+      if (auto Mode = getAllocTokenModeFromString(ParamName))
         Result.Mode = *Mode;
       else
         return make_error<StringError>(
diff --git a/llvm/lib/Support/AllocToken.cpp b/llvm/lib/Support/AllocToken.cpp
index 6c6f80ac4997c..8e9e89f0df353 100644
--- a/llvm/lib/Support/AllocToken.cpp
+++ b/llvm/lib/Support/AllocToken.cpp
@@ -11,14 +11,31 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Support/AllocToken.h"
+#include "llvm/ADT/StringSwitch.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/SipHash.h"
 
-namespace llvm {
-std::optional<uint64_t> getAllocTokenHash(AllocTokenMode Mode,
-                                          const AllocTokenMetadata &Metadata,
-                                          uint64_t MaxTokens) {
-  assert(MaxTokens && "Must provide concrete max tokens");
+using namespace llvm;
+
+std::optional<AllocTokenMode>
+llvm::getAllocTokenModeFromString(StringRef Name) {
+  return StringSwitch<std::optional<AllocTokenMode>>(Name)
+      .Case("increment", AllocTokenMode::Increment)
+      .Case("random", AllocTokenMode::Random)
+      .Case("typehash", AllocTokenMode::TypeHash)
+      .Case("typehashpointersplit", AllocTokenMode::TypeHashPointerSplit)
+      .Default(std::nullopt);
+}
+
+static uint64_t getStableHash(const AllocTokenMetadata &Metadata,
+                              uint64_t MaxTokens) {
+  return getStableSipHash(Metadata.TypeName) % MaxTokens;
+}
+
+std::optional<uint64_t> llvm::getAllocToken(AllocTokenMode Mode,
+                                            const AllocTokenMetadata &Metadata,
+                                            uint64_t MaxTokens) {
+  assert(MaxTokens && "Must provide non-zero max tokens");
 
   switch (Mode) {
   case AllocTokenMode::Increment:
@@ -26,15 +43,14 @@ std::optional<uint64_t> getAllocTokenHash(AllocTokenMode Mode,
     // Stateful modes cannot be implemented as a pure function.
     return std::nullopt;
 
-  case AllocTokenMode::TypeHash: {
-    return getStableSipHash(Metadata.TypeName) % MaxTokens;
-  }
+  case AllocTokenMode::TypeHash:
+    return getStableHash(Metadata, MaxTokens);
 
   case AllocTokenMode::TypeHashPointerSplit: {
     if (MaxTokens == 1)
       return 0;
     const uint64_t HalfTokens = MaxTokens / 2;
-    uint64_t Hash = getStableSipHash(Metadata.TypeName) % HalfTokens;
+    uint64_t Hash = getStableHash(Metadata, HalfTokens);
     if (Metadata.ContainsPointer)
       Hash += HalfTokens;
     return Hash;
@@ -43,4 +59,3 @@ std::optional<uint64_t> getAllocTokenHash(AllocTokenMode Mode,
 
   llvm_unreachable("");
 }
-} // namespace llvm
diff --git a/llvm/lib/Transforms/Instrumentation/AllocToken.cpp b/llvm/lib/Transforms/Instrumentation/AllocToken.cpp
index bfda56b1f746d..8181e4ef1d74f 100644
--- a/llvm/lib/Transforms/Instrumentation/AllocToken.cpp
+++ b/llvm/lib/Transforms/Instrumentation/AllocToken.cpp
@@ -189,8 +189,7 @@ class TypeHashMode : public ModeBase {
     if (MDNode *N = getAllocTokenMetadata(CB)) {
       MDString *S = cast<MDString>(N->getOperand(0));
       AllocTokenMetadata Metadata{S->getString(), containsPointer(N)};
-      if (auto Token =
-              getAllocTokenHash(TokenMode::TypeHash, Metadata, MaxTokens))
+      if (auto Token = getAllocToken(TokenMode::TypeHash, Metadata, MaxTokens))
         return *Token;
     }
     // Fallback.
@@ -222,8 +221,8 @@ class TypeHashPointerSplitMode : public TypeHashMode {
     if (MDNode *N = getAllocTokenMetadata(CB)) {
       MDString *S = cast<MDString>(N->getOperand(0));
       AllocTokenMetadata Metadata{S->getString(), containsPointer(N)};
-      if (auto Token = getAllocTokenHash(TokenMode::TypeHashPointerSplit,
-                                         Metadata, MaxTokens))
+      if (auto Token = getAllocToken(TokenMode::TypeHashPointerSplit, Metadata,
+                                     MaxTokens))
         return *Token;
     }
     // Pick the fallback token (ClFallbackToken), which by default is 0, meaning
@@ -357,9 +356,8 @@ bool AllocToken::instrumentFunction(Function &F) {
   }
 
   if (!IntrinsicInsts.empty()) {
-    for (auto *II : IntrinsicInsts) {
+    for (auto *II : IntrinsicInsts)
       replaceIntrinsicInst(II, ORE);
-    }
     Modified = true;
     NumFunctionsModified++;
   }
@@ -381,7 +379,7 @@ AllocToken::shouldInstrumentCall(const CallBase &CB,
   if (TLI.getLibFunc(*Callee, Func)) {
     if (isInstrumentableLibFunc(Func, CB, TLI))
       return Func;
-  } else if (Options.Extended && getAllocTokenMetadata(CB)) {
+  } else if (Options.Extended && CB.getMetadata(LLVMContext::MD_alloc_token)) {
     return NotLibFunc;
   }
 



More information about the llvm-commits mailing list