[clang] [llvm] [Clang][Sema] Add __builtin_infer_alloc_token() declaration and semantic checks (PR #163638)
    Marco Elver via llvm-commits 
    llvm-commits at lists.llvm.org
       
    Thu Oct 23 02:25:14 PDT 2025
    
    
  
https://github.com/melver updated https://github.com/llvm/llvm-project/pull/163638
>From 052ff02f6a1e9fd3955da128bc72a7fda9cfd0c7 Mon Sep 17 00:00:00 2001
From: Marco Elver <elver at google.com>
Date: Wed, 15 Oct 2025 23:29:12 +0200
Subject: [PATCH 1/2] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20ch?=
 =?UTF-8?q?anges=20to=20main=20this=20commit=20is=20based=20on?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Created using spr 1.3.8-beta.1
[skip ci]
---
 clang/docs/AllocToken.rst                     |   4 +-
 clang/include/clang/AST/InferAlloc.h          |  35 +++
 clang/include/clang/Basic/CodeGenOptions.h    |   4 -
 clang/include/clang/Basic/LangOptions.h       |   8 +
 clang/include/clang/Driver/Options.td         |   4 +
 clang/lib/AST/CMakeLists.txt                  |   1 +
 clang/lib/AST/InferAlloc.cpp                  | 204 ++++++++++++++++++
 clang/lib/CodeGen/BackendUtil.cpp             |   9 +-
 clang/lib/CodeGen/CGExpr.cpp                  | 198 ++---------------
 clang/lib/CodeGen/CodeGenFunction.h           |   9 +-
 clang/lib/Frontend/CompilerInvocation.cpp     |  60 ++++--
 clang/test/Driver/fsanitize-alloc-token.c     |  11 +
 llvm/include/llvm/IR/Intrinsics.td            |   8 +
 llvm/include/llvm/Support/AllocToken.h        |  62 ++++++
 .../Transforms/Instrumentation/AllocToken.h   |   2 +
 llvm/lib/Passes/PassBuilder.cpp               |  32 +++
 llvm/lib/Passes/PassRegistry.def              |   5 +-
 llvm/lib/Support/AllocToken.cpp               |  46 ++++
 llvm/lib/Support/CMakeLists.txt               |   1 +
 .../Transforms/Instrumentation/AllocToken.cpp | 148 +++++++------
 llvm/test/Instrumentation/AllocToken/basic.ll |   2 +-
 .../Instrumentation/AllocToken/basic32.ll     |   2 +-
 llvm/test/Instrumentation/AllocToken/fast.ll  |   2 +-
 .../Instrumentation/AllocToken/intrinsic.ll   |  32 +++
 .../Instrumentation/AllocToken/intrinsic32.ll |  32 +++
 .../test/Instrumentation/AllocToken/invoke.ll |   2 +-
 .../Instrumentation/AllocToken/nonlibcalls.ll |   2 +-
 .../AllocToken/typehashpointersplit.ll        |   2 +-
 .../utils/gn/secondary/clang/lib/AST/BUILD.gn |   1 +
 .../gn/secondary/llvm/lib/Support/BUILD.gn    |   1 +
 30 files changed, 653 insertions(+), 276 deletions(-)
 create mode 100644 clang/include/clang/AST/InferAlloc.h
 create mode 100644 clang/lib/AST/InferAlloc.cpp
 create mode 100644 llvm/include/llvm/Support/AllocToken.h
 create mode 100644 llvm/lib/Support/AllocToken.cpp
 create mode 100644 llvm/test/Instrumentation/AllocToken/intrinsic.ll
 create mode 100644 llvm/test/Instrumentation/AllocToken/intrinsic32.ll
diff --git a/clang/docs/AllocToken.rst b/clang/docs/AllocToken.rst
index bda84669456ce..b65e18ccfa967 100644
--- a/clang/docs/AllocToken.rst
+++ b/clang/docs/AllocToken.rst
@@ -37,8 +37,8 @@ The default mode to calculate tokens is:
   pointers.
 
 Other token ID assignment modes are supported, but they may be subject to
-change or removal. These may (experimentally) be selected with ``-mllvm
--alloc-token-mode=<mode>``:
+change or removal. These may (experimentally) be selected with ``-Xclang
+-falloc-token-mode=<mode>``:
 
 * ``typehash``: This mode assigns a token ID based on the hash of the allocated
   type's name.
diff --git a/clang/include/clang/AST/InferAlloc.h b/clang/include/clang/AST/InferAlloc.h
new file mode 100644
index 0000000000000..c3dc30204feaf
--- /dev/null
+++ b/clang/include/clang/AST/InferAlloc.h
@@ -0,0 +1,35 @@
+//===--- InferAlloc.h - Allocation type inference ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines interfaces for allocation-related type inference.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_AST_INFERALLOC_H
+#define LLVM_CLANG_AST_INFERALLOC_H
+
+#include "clang/AST/ASTContext.h"
+#include "clang/AST/Expr.h"
+#include "llvm/Support/AllocToken.h"
+#include <optional>
+
+namespace clang {
+namespace infer_alloc {
+
+/// Infer the possible allocated type from an allocation call expression.
+QualType inferPossibleType(const CallExpr *E, const ASTContext &Ctx,
+                           const CastExpr *CastE);
+
+/// Get the information required for construction of an allocation token ID.
+std::optional<llvm::AllocTokenMetadata>
+getAllocTokenMetadata(QualType T, const ASTContext &Ctx);
+
+} // namespace infer_alloc
+} // namespace clang
+
+#endif // LLVM_CLANG_AST_INFERALLOC_H
diff --git a/clang/include/clang/Basic/CodeGenOptions.h b/clang/include/clang/Basic/CodeGenOptions.h
index cae06c3c9495a..5d5cf250b56b9 100644
--- a/clang/include/clang/Basic/CodeGenOptions.h
+++ b/clang/include/clang/Basic/CodeGenOptions.h
@@ -447,10 +447,6 @@ class CodeGenOptions : public CodeGenOptionsBase {
 
   std::optional<double> AllowRuntimeCheckSkipHotCutoff;
 
-  /// Maximum number of allocation tokens (0 = no max), nullopt if none set (use
-  /// pass default).
-  std::optional<uint64_t> AllocTokenMax;
-
   /// List of backend command-line options for -fembed-bitcode.
   std::vector<uint8_t> CmdArgs;
 
diff --git a/clang/include/clang/Basic/LangOptions.h b/clang/include/clang/Basic/LangOptions.h
index 41595ec2a060d..83becb73076f9 100644
--- a/clang/include/clang/Basic/LangOptions.h
+++ b/clang/include/clang/Basic/LangOptions.h
@@ -25,6 +25,7 @@
 #include "llvm/ADT/FloatingPointMode.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/BinaryFormat/DXContainer.h"
+#include "llvm/Support/AllocToken.h"
 #include "llvm/TargetParser/Triple.h"
 #include <optional>
 #include <string>
@@ -565,6 +566,13 @@ class LangOptions : public LangOptionsBase {
   bool AtomicFineGrainedMemory = false;
   bool AtomicIgnoreDenormalMode = false;
 
+  /// Maximum number of allocation tokens (0 = no max), nullopt if none set (use
+  /// target default).
+  std::optional<uint64_t> AllocTokenMax;
+
+  /// The allocation token mode.
+  std::optional<llvm::AllocTokenMode> AllocTokenMode;
+
   LangOptions();
 
   /// Set language defaults for the given input language and
diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td
index 611b68e5281f0..370fedd545fec 100644
--- a/clang/include/clang/Driver/Options.td
+++ b/clang/include/clang/Driver/Options.td
@@ -2751,6 +2751,10 @@ def falloc_token_max_EQ : Joined<["-"], "falloc-token-max=">,
   MetaVarName<"<N>">,
   HelpText<"Limit to maximum N allocation tokens (0 = no max)">;
 
+def falloc_token_mode_EQ : Joined<["-"], "falloc-token-mode=">,
+  Group<f_Group>, Visibility<[CC1Option]>,
+  HelpText<"Set the allocation token mode (experimental)">;
+
 def fallow_runtime_check_skip_hot_cutoff_EQ
     : Joined<["-"], "fallow-runtime-check-skip-hot-cutoff=">,
       Group<f_clang_Group>,
diff --git a/clang/lib/AST/CMakeLists.txt b/clang/lib/AST/CMakeLists.txt
index d4fd7a7f16d53..fd50e956bb865 100644
--- a/clang/lib/AST/CMakeLists.txt
+++ b/clang/lib/AST/CMakeLists.txt
@@ -66,6 +66,7 @@ add_clang_library(clangAST
   ExternalASTMerger.cpp
   ExternalASTSource.cpp
   FormatString.cpp
+  InferAlloc.cpp
   InheritViz.cpp
   ByteCode/BitcastBuffer.cpp
   ByteCode/ByteCodeEmitter.cpp
diff --git a/clang/lib/AST/InferAlloc.cpp b/clang/lib/AST/InferAlloc.cpp
new file mode 100644
index 0000000000000..c21fcfccaef0f
--- /dev/null
+++ b/clang/lib/AST/InferAlloc.cpp
@@ -0,0 +1,204 @@
+//===--- InferAlloc.cpp - Allocation type inference -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements allocation-related type inference.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/AST/InferAlloc.h"
+#include "clang/AST/ASTContext.h"
+#include "clang/AST/Decl.h"
+#include "clang/AST/DeclCXX.h"
+#include "clang/AST/Expr.h"
+#include "clang/AST/Type.h"
+#include "clang/Basic/IdentifierTable.h"
+#include "llvm/ADT/SmallPtrSet.h"
+
+namespace clang {
+namespace {
+bool typeContainsPointer(QualType T,
+                         llvm::SmallPtrSet<const RecordDecl *, 4> &VisitedRD,
+                         bool &IncompleteType) {
+  QualType CanonicalType = T.getCanonicalType();
+  if (CanonicalType->isPointerType())
+    return true; // base case
+
+  // Look through typedef chain to check for special types.
+  for (QualType CurrentT = T; const auto *TT = CurrentT->getAs<TypedefType>();
+       CurrentT = TT->getDecl()->getUnderlyingType()) {
+    const IdentifierInfo *II = TT->getDecl()->getIdentifier();
+    // Special Case: Syntactically uintptr_t is not a pointer; semantically,
+    // however, very likely used as such. Therefore, classify uintptr_t as a
+    // pointer, too.
+    if (II && II->isStr("uintptr_t"))
+      return true;
+  }
+
+  // The type is an array; check the element type.
+  if (const ArrayType *AT = dyn_cast<ArrayType>(CanonicalType))
+    return typeContainsPointer(AT->getElementType(), VisitedRD, IncompleteType);
+  // The type is a struct, class, or union.
+  if (const RecordDecl *RD = CanonicalType->getAsRecordDecl()) {
+    if (!RD->isCompleteDefinition()) {
+      IncompleteType = true;
+      return false;
+    }
+    if (!VisitedRD.insert(RD).second)
+      return false; // already visited
+    // Check all fields.
+    for (const FieldDecl *Field : RD->fields()) {
+      if (typeContainsPointer(Field->getType(), VisitedRD, IncompleteType))
+        return true;
+    }
+    // For C++ classes, also check base classes.
+    if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+      // Polymorphic types require a vptr.
+      if (CXXRD->isDynamicClass())
+        return true;
+      for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
+        if (typeContainsPointer(Base.getType(), VisitedRD, IncompleteType))
+          return true;
+      }
+    }
+  }
+  return false;
+}
+
+/// Infer type from a simple sizeof expression.
+QualType inferTypeFromSizeofExpr(const Expr *E) {
+  const Expr *Arg = E->IgnoreParenImpCasts();
+  if (const auto *UET = dyn_cast<UnaryExprOrTypeTraitExpr>(Arg)) {
+    if (UET->getKind() == UETT_SizeOf) {
+      if (UET->isArgumentType())
+        return UET->getArgumentTypeInfo()->getType();
+      else
+        return UET->getArgumentExpr()->getType();
+    }
+  }
+  return QualType();
+}
+
+/// Infer type from an arithmetic expression involving a sizeof. For example:
+///
+///   malloc(sizeof(MyType) + padding);  // infers 'MyType'
+///   malloc(sizeof(MyType) * 32);       // infers 'MyType'
+///   malloc(32 * sizeof(MyType));       // infers 'MyType'
+///   malloc(sizeof(MyType) << 1);       // infers 'MyType'
+///   ...
+///
+/// More complex arithmetic expressions are supported, but are a heuristic, e.g.
+/// when considering allocations for structs with flexible array members:
+///
+///   malloc(sizeof(HasFlexArray) + sizeof(int) * 32);  // infers 'HasFlexArray'
+///
+QualType inferPossibleTypeFromArithSizeofExpr(const Expr *E) {
+  const Expr *Arg = E->IgnoreParenImpCasts();
+  // The argument is a lone sizeof expression.
+  if (QualType T = inferTypeFromSizeofExpr(Arg); !T.isNull())
+    return T;
+  if (const auto *BO = dyn_cast<BinaryOperator>(Arg)) {
+    // Argument is an arithmetic expression. Cover common arithmetic patterns
+    // involving sizeof.
+    switch (BO->getOpcode()) {
+    case BO_Add:
+    case BO_Div:
+    case BO_Mul:
+    case BO_Shl:
+    case BO_Shr:
+    case BO_Sub:
+      if (QualType T = inferPossibleTypeFromArithSizeofExpr(BO->getLHS());
+          !T.isNull())
+        return T;
+      if (QualType T = inferPossibleTypeFromArithSizeofExpr(BO->getRHS());
+          !T.isNull())
+        return T;
+      break;
+    default:
+      break;
+    }
+  }
+  return QualType();
+}
+
+/// If the expression E is a reference to a variable, infer the type from the
+/// variable's initializer if it contains a sizeof. Beware, this is a heuristic
+/// that ignores whether the variable is later reassigned. For example:
+///
+///   size_t my_size = sizeof(MyType);
+///   void *x = malloc(my_size);  // infers 'MyType'
+///
+QualType inferPossibleTypeFromVarInitSizeofExpr(const Expr *E) {
+  const Expr *Arg = E->IgnoreParenImpCasts();
+  if (const auto *DRE = dyn_cast<DeclRefExpr>(Arg)) {
+    if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
+      if (const Expr *Init = VD->getInit())
+        return inferPossibleTypeFromArithSizeofExpr(Init);
+    }
+  }
+  return QualType();
+}
+
+/// Deduces the allocated type by checking if the allocation call's result
+/// is immediately used in a cast expression. For example:
+///
+///   MyType *x = (MyType *)malloc(4096);  // infers 'MyType'
+///
+QualType inferPossibleTypeFromCastExpr(const CallExpr *CallE,
+                                       const CastExpr *CastE) {
+  if (!CastE)
+    return QualType();
+  QualType PtrType = CastE->getType();
+  if (PtrType->isPointerType())
+    return PtrType->getPointeeType();
+  return QualType();
+}
+} // anonymous namespace
+
+namespace infer_alloc {
+
+QualType inferPossibleType(const CallExpr *E, const ASTContext &Ctx,
+                           const CastExpr *CastE) {
+  QualType AllocType;
+  // First check arguments.
+  for (const Expr *Arg : E->arguments()) {
+    AllocType = inferPossibleTypeFromArithSizeofExpr(Arg);
+    if (AllocType.isNull())
+      AllocType = inferPossibleTypeFromVarInitSizeofExpr(Arg);
+    if (!AllocType.isNull())
+      break;
+  }
+  // Then check later casts.
+  if (AllocType.isNull())
+    AllocType = inferPossibleTypeFromCastExpr(E, CastE);
+  return AllocType;
+}
+
+std::optional<llvm::AllocTokenMetadata>
+getAllocTokenMetadata(QualType T, const ASTContext &Ctx) {
+  llvm::AllocTokenMetadata ATMD;
+
+  // Get unique type name.
+  PrintingPolicy Policy(Ctx.getLangOpts());
+  Policy.SuppressTagKeyword = true;
+  Policy.FullyQualifiedName = true;
+  llvm::raw_svector_ostream TypeNameOS(ATMD.TypeName);
+  T.getCanonicalType().print(TypeNameOS, Policy);
+
+  // Check if the type contains a pointer: a simple DFS that recurses into
+  // array element types, record fields, and C++ base classes.
+  llvm::SmallPtrSet<const RecordDecl *, 4> VisitedRD;
+  bool IncompleteType = false;
+  ATMD.ContainsPointer = typeContainsPointer(T, VisitedRD, IncompleteType);
+  if (!ATMD.ContainsPointer && IncompleteType)
+    return std::nullopt;
+
+  return ATMD;
+}
+
+} // namespace infer_alloc
+} // namespace clang
diff --git a/clang/lib/CodeGen/BackendUtil.cpp b/clang/lib/CodeGen/BackendUtil.cpp
index f8e8086afc36f..23ad11ac9f792 100644
--- a/clang/lib/CodeGen/BackendUtil.cpp
+++ b/clang/lib/CodeGen/BackendUtil.cpp
@@ -234,9 +234,12 @@ class EmitAssemblyHelper {
 };
 } // namespace
 
-static AllocTokenOptions getAllocTokenOptions(const CodeGenOptions &CGOpts) {
+static AllocTokenOptions getAllocTokenOptions(const LangOptions &LangOpts,
+                                              const CodeGenOptions &CGOpts) {
   AllocTokenOptions Opts;
-  Opts.MaxTokens = CGOpts.AllocTokenMax;
+  if (LangOpts.AllocTokenMode)
+    Opts.Mode = *LangOpts.AllocTokenMode;
+  Opts.MaxTokens = LangOpts.AllocTokenMax;
   Opts.Extended = CGOpts.SanitizeAllocTokenExtended;
   Opts.FastABI = CGOpts.SanitizeAllocTokenFastABI;
   return Opts;
@@ -808,7 +811,7 @@ static void addSanitizers(const Triple &TargetTriple,
         // memory allocation function detection.
         MPM.addPass(InferFunctionAttrsPass());
       }
-      MPM.addPass(AllocTokenPass(getAllocTokenOptions(CodeGenOpts)));
+      MPM.addPass(AllocTokenPass(getAllocTokenOptions(LangOpts, CodeGenOpts)));
     }
   };
   if (ClSanitizeOnOptimizerEarlyEP) {
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index e8255b0554da8..0a2f57dc119b6 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -29,6 +29,7 @@
 #include "clang/AST/ASTLambda.h"
 #include "clang/AST/Attr.h"
 #include "clang/AST/DeclObjC.h"
+#include "clang/AST/InferAlloc.h"
 #include "clang/AST/NSAPI.h"
 #include "clang/AST/ParentMapContext.h"
 #include "clang/AST/StmtVisitor.h"
@@ -1273,194 +1274,39 @@ void CodeGenFunction::EmitBoundsCheckImpl(const Expr *E, llvm::Value *Bound,
   EmitCheck(std::make_pair(Check, CheckKind), CheckHandler, StaticData, Index);
 }
 
-static bool
-typeContainsPointer(QualType T,
-                    llvm::SmallPtrSet<const RecordDecl *, 4> &VisitedRD,
-                    bool &IncompleteType) {
-  QualType CanonicalType = T.getCanonicalType();
-  if (CanonicalType->isPointerType())
-    return true; // base case
-
-  // Look through typedef chain to check for special types.
-  for (QualType CurrentT = T; const auto *TT = CurrentT->getAs<TypedefType>();
-       CurrentT = TT->getDecl()->getUnderlyingType()) {
-    const IdentifierInfo *II = TT->getDecl()->getIdentifier();
-    // Special Case: Syntactically uintptr_t is not a pointer; semantically,
-    // however, very likely used as such. Therefore, classify uintptr_t as a
-    // pointer, too.
-    if (II && II->isStr("uintptr_t"))
-      return true;
-  }
-
-  // The type is an array; check the element type.
-  if (const ArrayType *AT = dyn_cast<ArrayType>(CanonicalType))
-    return typeContainsPointer(AT->getElementType(), VisitedRD, IncompleteType);
-  // The type is a struct, class, or union.
-  if (const RecordDecl *RD = CanonicalType->getAsRecordDecl()) {
-    if (!RD->isCompleteDefinition()) {
-      IncompleteType = true;
-      return false;
-    }
-    if (!VisitedRD.insert(RD).second)
-      return false; // already visited
-    // Check all fields.
-    for (const FieldDecl *Field : RD->fields()) {
-      if (typeContainsPointer(Field->getType(), VisitedRD, IncompleteType))
-        return true;
-    }
-    // For C++ classes, also check base classes.
-    if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
-      // Polymorphic types require a vptr.
-      if (CXXRD->isDynamicClass())
-        return true;
-      for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
-        if (typeContainsPointer(Base.getType(), VisitedRD, IncompleteType))
-          return true;
-      }
-    }
-  }
-  return false;
-}
-
-void CodeGenFunction::EmitAllocToken(llvm::CallBase *CB, QualType AllocType) {
-  assert(SanOpts.has(SanitizerKind::AllocToken) &&
-         "Only needed with -fsanitize=alloc-token");
+llvm::MDNode *CodeGenFunction::buildAllocToken(QualType AllocType) {
+  auto ATMD = infer_alloc::getAllocTokenMetadata(AllocType, getContext());
+  if (!ATMD)
+    return nullptr;
 
   llvm::MDBuilder MDB(getLLVMContext());
-
-  // Get unique type name.
-  PrintingPolicy Policy(CGM.getContext().getLangOpts());
-  Policy.SuppressTagKeyword = true;
-  Policy.FullyQualifiedName = true;
-  SmallString<64> TypeName;
-  llvm::raw_svector_ostream TypeNameOS(TypeName);
-  AllocType.getCanonicalType().print(TypeNameOS, Policy);
-  auto *TypeNameMD = MDB.createString(TypeNameOS.str());
-
-  // Check if QualType contains a pointer. Implements a simple DFS to
-  // recursively check if a type contains a pointer type.
-  llvm::SmallPtrSet<const RecordDecl *, 4> VisitedRD;
-  bool IncompleteType = false;
-  const bool ContainsPtr =
-      typeContainsPointer(AllocType, VisitedRD, IncompleteType);
-  if (!ContainsPtr && IncompleteType)
-    return;
-  auto *ContainsPtrC = Builder.getInt1(ContainsPtr);
+  auto *TypeNameMD = MDB.createString(ATMD->TypeName);
+  auto *ContainsPtrC = Builder.getInt1(ATMD->ContainsPointer);
   auto *ContainsPtrMD = MDB.createConstant(ContainsPtrC);
 
   // Format: !{<type-name>, <contains-pointer>}
-  auto *MDN =
-      llvm::MDNode::get(CGM.getLLVMContext(), {TypeNameMD, ContainsPtrMD});
-  CB->setMetadata(llvm::LLVMContext::MD_alloc_token, MDN);
-}
-
-namespace {
-/// Infer type from a simple sizeof expression.
-QualType inferTypeFromSizeofExpr(const Expr *E) {
-  const Expr *Arg = E->IgnoreParenImpCasts();
-  if (const auto *UET = dyn_cast<UnaryExprOrTypeTraitExpr>(Arg)) {
-    if (UET->getKind() == UETT_SizeOf) {
-      if (UET->isArgumentType())
-        return UET->getArgumentTypeInfo()->getType();
-      else
-        return UET->getArgumentExpr()->getType();
-    }
-  }
-  return QualType();
-}
-
-/// Infer type from an arithmetic expression involving a sizeof. For example:
-///
-///   malloc(sizeof(MyType) + padding);  // infers 'MyType'
-///   malloc(sizeof(MyType) * 32);       // infers 'MyType'
-///   malloc(32 * sizeof(MyType));       // infers 'MyType'
-///   malloc(sizeof(MyType) << 1);       // infers 'MyType'
-///   ...
-///
-/// More complex arithmetic expressions are supported, but are a heuristic, e.g.
-/// when considering allocations for structs with flexible array members:
-///
-///   malloc(sizeof(HasFlexArray) + sizeof(int) * 32);  // infers 'HasFlexArray'
-///
-QualType inferPossibleTypeFromArithSizeofExpr(const Expr *E) {
-  const Expr *Arg = E->IgnoreParenImpCasts();
-  // The argument is a lone sizeof expression.
-  if (QualType T = inferTypeFromSizeofExpr(Arg); !T.isNull())
-    return T;
-  if (const auto *BO = dyn_cast<BinaryOperator>(Arg)) {
-    // Argument is an arithmetic expression. Cover common arithmetic patterns
-    // involving sizeof.
-    switch (BO->getOpcode()) {
-    case BO_Add:
-    case BO_Div:
-    case BO_Mul:
-    case BO_Shl:
-    case BO_Shr:
-    case BO_Sub:
-      if (QualType T = inferPossibleTypeFromArithSizeofExpr(BO->getLHS());
-          !T.isNull())
-        return T;
-      if (QualType T = inferPossibleTypeFromArithSizeofExpr(BO->getRHS());
-          !T.isNull())
-        return T;
-      break;
-    default:
-      break;
-    }
-  }
-  return QualType();
+  return llvm::MDNode::get(CGM.getLLVMContext(), {TypeNameMD, ContainsPtrMD});
 }
 
-/// If the expression E is a reference to a variable, infer the type from a
-/// variable's initializer if it contains a sizeof. Beware, this is a heuristic
-/// and ignores if a variable is later reassigned. For example:
-///
-///   size_t my_size = sizeof(MyType);
-///   void *x = malloc(my_size);  // infers 'MyType'
-///
-QualType inferPossibleTypeFromVarInitSizeofExpr(const Expr *E) {
-  const Expr *Arg = E->IgnoreParenImpCasts();
-  if (const auto *DRE = dyn_cast<DeclRefExpr>(Arg)) {
-    if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
-      if (const Expr *Init = VD->getInit())
-        return inferPossibleTypeFromArithSizeofExpr(Init);
-    }
-  }
-  return QualType();
+void CodeGenFunction::EmitAllocToken(llvm::CallBase *CB, QualType AllocType) {
+  assert(SanOpts.has(SanitizerKind::AllocToken) &&
+         "Only needed with -fsanitize=alloc-token");
+  CB->setMetadata(llvm::LLVMContext::MD_alloc_token,
+                  buildAllocToken(AllocType));
 }
 
-/// Deduces the allocated type by checking if the allocation call's result
-/// is immediately used in a cast expression. For example:
-///
-///   MyType *x = (MyType *)malloc(4096);  // infers 'MyType'
-///
-QualType inferPossibleTypeFromCastExpr(const CallExpr *CallE,
-                                       const CastExpr *CastE) {
-  if (!CastE)
-    return QualType();
-  QualType PtrType = CastE->getType();
-  if (PtrType->isPointerType())
-    return PtrType->getPointeeType();
-  return QualType();
+llvm::MDNode *CodeGenFunction::buildAllocToken(const CallExpr *E) {
+  QualType AllocType = infer_alloc::inferPossibleType(E, getContext(), CurCast);
+  if (!AllocType.isNull())
+    return buildAllocToken(AllocType);
+  return nullptr;
 }
-} // end anonymous namespace
 
 void CodeGenFunction::EmitAllocToken(llvm::CallBase *CB, const CallExpr *E) {
-  QualType AllocType;
-  // First check arguments.
-  for (const Expr *Arg : E->arguments()) {
-    AllocType = inferPossibleTypeFromArithSizeofExpr(Arg);
-    if (AllocType.isNull())
-      AllocType = inferPossibleTypeFromVarInitSizeofExpr(Arg);
-    if (!AllocType.isNull())
-      break;
-  }
-  // Then check later casts.
-  if (AllocType.isNull())
-    AllocType = inferPossibleTypeFromCastExpr(E, CurCast);
-  // Emit if we were able to infer the type.
-  if (!AllocType.isNull())
-    EmitAllocToken(CB, AllocType);
+  assert(SanOpts.has(SanitizerKind::AllocToken) &&
+         "Only needed with -fsanitize=alloc-token");
+  if (llvm::MDNode *MDN = buildAllocToken(E))
+    CB->setMetadata(llvm::LLVMContext::MD_alloc_token, MDN);
 }
 
 CodeGenFunction::ComplexPairTy CodeGenFunction::
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 1f0be2d8756de..8c4c1c8c2dc95 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -3352,9 +3352,14 @@ class CodeGenFunction : public CodeGenTypeCache {
   SanitizerAnnotateDebugInfo(ArrayRef<SanitizerKind::SanitizerOrdinal> Ordinals,
                              SanitizerHandler Handler);
 
-  /// Emit additional metadata used by the AllocToken instrumentation.
+  /// Build metadata used by the AllocToken instrumentation.
+  llvm::MDNode *buildAllocToken(QualType AllocType);
+  /// Emit and set additional metadata used by the AllocToken instrumentation.
   void EmitAllocToken(llvm::CallBase *CB, QualType AllocType);
-  /// Emit additional metadata used by the AllocToken instrumentation,
+  /// Build additional metadata used by the AllocToken instrumentation,
+  /// inferring the type from an allocation call expression.
+  llvm::MDNode *buildAllocToken(const CallExpr *E);
+  /// Emit and set additional metadata used by the AllocToken instrumentation,
   /// inferring the type from an allocation call expression.
   void EmitAllocToken(llvm::CallBase *CB, const CallExpr *E);
 
diff --git a/clang/lib/Frontend/CompilerInvocation.cpp b/clang/lib/Frontend/CompilerInvocation.cpp
index 292adce8180bc..9ce1df728336e 100644
--- a/clang/lib/Frontend/CompilerInvocation.cpp
+++ b/clang/lib/Frontend/CompilerInvocation.cpp
@@ -1833,10 +1833,6 @@ void CompilerInvocationBase::GenerateCodeGenArgs(const CodeGenOptions &Opts,
        serializeSanitizerKinds(Opts.SanitizeAnnotateDebugInfo))
     GenerateArg(Consumer, OPT_fsanitize_annotate_debug_info_EQ, Sanitizer);
 
-  if (Opts.AllocTokenMax)
-    GenerateArg(Consumer, OPT_falloc_token_max_EQ,
-                std::to_string(*Opts.AllocTokenMax));
-
   if (!Opts.EmitVersionIdentMetadata)
     GenerateArg(Consumer, OPT_Qn);
 
@@ -2350,15 +2346,6 @@ bool CompilerInvocation::ParseCodeGenArgs(CodeGenOptions &Opts, ArgList &Args,
     }
   }
 
-  if (const auto *Arg = Args.getLastArg(options::OPT_falloc_token_max_EQ)) {
-    StringRef S = Arg->getValue();
-    uint64_t Value = 0;
-    if (S.getAsInteger(0, Value))
-      Diags.Report(diag::err_drv_invalid_value) << Arg->getAsString(Args) << S;
-    else
-      Opts.AllocTokenMax = Value;
-  }
-
   Opts.EmitVersionIdentMetadata = Args.hasFlag(OPT_Qy, OPT_Qn, true);
 
   if (!LangOpts->CUDAIsDevice)
@@ -3966,6 +3953,29 @@ void CompilerInvocationBase::GenerateLangArgs(const LangOptions &Opts,
 
   if (!Opts.RandstructSeed.empty())
     GenerateArg(Consumer, OPT_frandomize_layout_seed_EQ, Opts.RandstructSeed);
+
+  if (Opts.AllocTokenMax)
+    GenerateArg(Consumer, OPT_falloc_token_max_EQ,
+                std::to_string(*Opts.AllocTokenMax));
+
+  if (Opts.AllocTokenMode) {
+    StringRef S;
+    switch (*Opts.AllocTokenMode) {
+    case llvm::AllocTokenMode::Increment:
+      S = "increment";
+      break;
+    case llvm::AllocTokenMode::Random:
+      S = "random";
+      break;
+    case llvm::AllocTokenMode::TypeHash:
+      S = "typehash";
+      break;
+    case llvm::AllocTokenMode::TypeHashPointerSplit:
+      S = "typehashpointersplit";
+      break;
+    }
+    GenerateArg(Consumer, OPT_falloc_token_mode_EQ, S);
+  }
 }
 
 bool CompilerInvocation::ParseLangArgs(LangOptions &Opts, ArgList &Args,
@@ -4544,6 +4554,30 @@ bool CompilerInvocation::ParseLangArgs(LangOptions &Opts, ArgList &Args,
   if (const Arg *A = Args.getLastArg(OPT_frandomize_layout_seed_EQ))
     Opts.RandstructSeed = A->getValue(0);
 
+  if (const auto *Arg = Args.getLastArg(options::OPT_falloc_token_max_EQ)) {
+    StringRef S = Arg->getValue();
+    uint64_t Value = 0;
+    if (S.getAsInteger(0, Value))
+      Diags.Report(diag::err_drv_invalid_value) << Arg->getAsString(Args) << S;
+    else
+      Opts.AllocTokenMax = Value;
+  }
+
+  if (const auto *Arg = Args.getLastArg(options::OPT_falloc_token_mode_EQ)) {
+    StringRef S = Arg->getValue();
+    auto Mode = llvm::StringSwitch<std::optional<llvm::AllocTokenMode>>(S)
+                    .Case("increment", llvm::AllocTokenMode::Increment)
+                    .Case("random", llvm::AllocTokenMode::Random)
+                    .Case("typehash", llvm::AllocTokenMode::TypeHash)
+                    .Case("typehashpointersplit",
+                          llvm::AllocTokenMode::TypeHashPointerSplit)
+                    .Default(std::nullopt);
+    if (Mode)
+      Opts.AllocTokenMode = Mode;
+    else
+      Diags.Report(diag::err_drv_invalid_value) << Arg->getAsString(Args) << S;
+  }
+
   // Validate options for HLSL
   if (Opts.HLSL) {
     // TODO: Revisit restricting SPIR-V to logical once we've figured out how to
diff --git a/clang/test/Driver/fsanitize-alloc-token.c b/clang/test/Driver/fsanitize-alloc-token.c
index 2964f60c4f26f..6d8bda16dfb96 100644
--- a/clang/test/Driver/fsanitize-alloc-token.c
+++ b/clang/test/Driver/fsanitize-alloc-token.c
@@ -41,3 +41,14 @@
 // CHECK-MAX: "-falloc-token-max=42"
 // RUN: not %clang --target=x86_64-linux-gnu -fsanitize=alloc-token -falloc-token-max=-1 %s 2>&1 | FileCheck -check-prefix=CHECK-INVALID-MAX %s
 // CHECK-INVALID-MAX: error: invalid value
+
+// RUN: %clang --target=x86_64-linux-gnu -Xclang -falloc-token-mode=increment %s -### 2>&1 | FileCheck -check-prefix=CHECK-MODE-INCREMENT %s
+// CHECK-MODE-INCREMENT: "-falloc-token-mode=increment"
+// RUN: %clang --target=x86_64-linux-gnu -Xclang -falloc-token-mode=random %s -### 2>&1 | FileCheck -check-prefix=CHECK-MODE-RANDOM %s
+// CHECK-MODE-RANDOM: "-falloc-token-mode=random"
+// RUN: %clang --target=x86_64-linux-gnu -Xclang -falloc-token-mode=typehash %s -### 2>&1 | FileCheck -check-prefix=CHECK-MODE-TYPEHASH %s
+// CHECK-MODE-TYPEHASH: "-falloc-token-mode=typehash"
+// RUN: %clang --target=x86_64-linux-gnu -Xclang -falloc-token-mode=typehashpointersplit %s -### 2>&1 | FileCheck -check-prefix=CHECK-MODE-TYPEHASHPTRSPLIT %s
+// CHECK-MODE-TYPEHASHPTRSPLIT: "-falloc-token-mode=typehashpointersplit"
+// RUN: not %clang --target=x86_64-linux-gnu -Xclang -falloc-token-mode=asdf %s 2>&1 | FileCheck -check-prefix=CHECK-INVALID-MODE %s
+// CHECK-INVALID-MODE: error: invalid value 'asdf'
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 8856eda250ed6..0c13b059c4cd0 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -2853,7 +2853,15 @@ def int_ptrauth_blend :
 def int_ptrauth_sign_generic :
   DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i64_ty], [IntrNoMem]>;
 
+//===----------------- AllocToken Intrinsics ------------------------------===//
+
+// Return the token ID for the given !alloc_token metadata.
+def int_alloc_token_id :
+  DefaultAttrsIntrinsic<[llvm_anyint_ty], [llvm_metadata_ty],
+                        [IntrNoMem, NoUndef<RetIndex>]>;
+
 //===----------------------------------------------------------------------===//
+
 //===------- Convergence Intrinsics ---------------------------------------===//
 
 def int_experimental_convergence_entry
diff --git a/llvm/include/llvm/Support/AllocToken.h b/llvm/include/llvm/Support/AllocToken.h
new file mode 100644
index 0000000000000..48db026957443
--- /dev/null
+++ b/llvm/include/llvm/Support/AllocToken.h
@@ -0,0 +1,62 @@
+//===- llvm/Support/AllocToken.h - Allocation Token Calculation -----*- C++ -*//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Definition of AllocToken modes and shared calculation of stateless token IDs.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_SUPPORT_ALLOCTOKEN_H
+#define LLVM_SUPPORT_ALLOCTOKEN_H
+
+#include "llvm/ADT/SmallString.h"
+#include <cstdint>
+#include <optional>
+
+namespace llvm {
+
+/// Modes for generating allocation token IDs.
+enum class AllocTokenMode {
+  /// Incrementally increasing token ID.
+  Increment,
+
+  /// Simple mode that returns a statically-assigned random token ID.
+  Random,
+
+  /// Token ID based on allocated type hash.
+  TypeHash,
+
+  /// Token ID based on allocated type hash, where the top half ID-space is
+  /// reserved for types that contain pointers and the bottom half for types
+  /// that do not contain pointers.
+  TypeHashPointerSplit,
+};
+
+/// The default allocation token mode.
+inline constexpr AllocTokenMode DefaultAllocTokenMode =
+    AllocTokenMode::TypeHashPointerSplit;
+
+/// Metadata about an allocation used to generate a token ID.
+struct AllocTokenMetadata {
+  SmallString<64> TypeName;
+  bool ContainsPointer = false; ///< Set if the type contains a pointer.
+};
+
+/// Calculates stable allocation token ID. Returns std::nullopt for stateful
+/// modes that are only available in the AllocToken pass.
+///
+/// \param Mode The token generation mode.
+/// \param Metadata The metadata about the allocation.
+/// \param MaxTokens The maximum number of tokens (must not be 0)
+/// \return The calculated allocation token ID, or std::nullopt.
+std::optional<uint64_t> getAllocTokenHash(AllocTokenMode Mode,
+                                          const AllocTokenMetadata &Metadata,
+                                          uint64_t MaxTokens);
+
+} // end namespace llvm
+
+#endif // LLVM_SUPPORT_ALLOCTOKEN_H
diff --git a/llvm/include/llvm/Transforms/Instrumentation/AllocToken.h b/llvm/include/llvm/Transforms/Instrumentation/AllocToken.h
index b1391cb04302c..077703c214745 100644
--- a/llvm/include/llvm/Transforms/Instrumentation/AllocToken.h
+++ b/llvm/include/llvm/Transforms/Instrumentation/AllocToken.h
@@ -16,6 +16,7 @@
 
 #include "llvm/IR/Analysis.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/Support/AllocToken.h"
 #include <optional>
 
 namespace llvm {
@@ -23,6 +24,7 @@ namespace llvm {
 class Module;
 
 struct AllocTokenOptions {
+  AllocTokenMode Mode = DefaultAllocTokenMode;
   std::optional<uint64_t> MaxTokens;
   bool FastABI = false;
   bool Extended = false;
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index 53cf0046bd858..c3522a38eb9e1 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -1095,6 +1095,38 @@ Expected<MemorySanitizerOptions> parseMSanPassOptions(StringRef Params) {
   return Result;
 }
 
+Expected<AllocTokenOptions> parseAllocTokenPassOptions(StringRef Params) {
+  AllocTokenOptions Result;
+  while (!Params.empty()) {
+    StringRef ParamName;
+    std::tie(ParamName, Params) = Params.split(';');
+
+    if (ParamName.consume_front("mode=")) {
+      auto Mode = StringSwitch<std::optional<AllocTokenMode>>(ParamName)
+                      .Case("increment", AllocTokenMode::Increment)
+                      .Case("random", AllocTokenMode::Random)
+                      .Case("typehash", AllocTokenMode::TypeHash)
+                      .Case("typehashpointersplit",
+                            AllocTokenMode::TypeHashPointerSplit)
+                      .Default(std::nullopt);
+      if (Mode)
+        Result.Mode = *Mode;
+      else
+        return make_error<StringError>(
+            formatv("invalid argument to AllocToken pass mode "
+                    "parameter: '{}'",
+                    ParamName)
+                .str(),
+            inconvertibleErrorCode());
+    } else {
+      return make_error<StringError>(
+          formatv("invalid AllocToken pass parameter '{}'", ParamName).str(),
+          inconvertibleErrorCode());
+    }
+  }
+  return Result;
+}
+
 /// Parser of parameters for SimplifyCFG pass.
 Expected<SimplifyCFGOptions> parseSimplifyCFGOptions(StringRef Params) {
   SimplifyCFGOptions Result;
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index 1b1652555cd28..fdd99a43e360d 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -125,7 +125,6 @@ MODULE_PASS("openmp-opt", OpenMPOptPass())
 MODULE_PASS("openmp-opt-postlink",
             OpenMPOptPass(ThinOrFullLTOPhase::FullLTOPostLink))
 MODULE_PASS("partial-inliner", PartialInlinerPass())
-MODULE_PASS("alloc-token", AllocTokenPass())
 MODULE_PASS("pgo-icall-prom", PGOIndirectCallPromotion())
 MODULE_PASS("pgo-instr-gen", PGOInstrumentationGen())
 MODULE_PASS("pgo-instr-use", PGOInstrumentationUse())
@@ -181,6 +180,10 @@ MODULE_PASS("wholeprogramdevirt", WholeProgramDevirtPass())
 #ifndef MODULE_PASS_WITH_PARAMS
 #define MODULE_PASS_WITH_PARAMS(NAME, CLASS, CREATE_PASS, PARSER, PARAMS)
 #endif
+MODULE_PASS_WITH_PARAMS(
+    "alloc-token", "AllocTokenPass",
+    [](AllocTokenOptions Opts) { return AllocTokenPass(Opts); },
+    parseAllocTokenPassOptions, "mode=<mode>")
 MODULE_PASS_WITH_PARAMS(
     "asan", "AddressSanitizerPass",
     [](AddressSanitizerOptions Opts) { return AddressSanitizerPass(Opts); },
diff --git a/llvm/lib/Support/AllocToken.cpp b/llvm/lib/Support/AllocToken.cpp
new file mode 100644
index 0000000000000..6c6f80ac4997c
--- /dev/null
+++ b/llvm/lib/Support/AllocToken.cpp
@@ -0,0 +1,46 @@
+//===- AllocToken.cpp - Allocation Token Calculation ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Definition of AllocToken modes and shared calculation of stateless token IDs.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Support/AllocToken.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/SipHash.h"
+
+namespace llvm {
+std::optional<uint64_t> getAllocTokenHash(AllocTokenMode Mode,
+                                          const AllocTokenMetadata &Metadata,
+                                          uint64_t MaxTokens) {
+  assert(MaxTokens && "Must provide concrete max tokens");
+
+  switch (Mode) {
+  case AllocTokenMode::Increment:
+  case AllocTokenMode::Random:
+    // Stateful modes cannot be implemented as a pure function.
+    return std::nullopt;
+
+  case AllocTokenMode::TypeHash: {
+    return getStableSipHash(Metadata.TypeName) % MaxTokens;
+  }
+
+  case AllocTokenMode::TypeHashPointerSplit: {
+    if (MaxTokens == 1)
+      return 0;
+    const uint64_t HalfTokens = MaxTokens / 2;
+    uint64_t Hash = getStableSipHash(Metadata.TypeName) % HalfTokens;
+    if (Metadata.ContainsPointer)
+      Hash += HalfTokens;
+    return Hash;
+  }
+  }
+
+  llvm_unreachable("");
+}
+} // namespace llvm
diff --git a/llvm/lib/Support/CMakeLists.txt b/llvm/lib/Support/CMakeLists.txt
index 42b21b5e62029..671a5fe941cef 100644
--- a/llvm/lib/Support/CMakeLists.txt
+++ b/llvm/lib/Support/CMakeLists.txt
@@ -149,6 +149,7 @@ add_llvm_component_library(LLVMSupport
   AArch64BuildAttributes.cpp
   ARMAttributeParser.cpp
   ARMWinEH.cpp
+  AllocToken.cpp
   Allocator.cpp
   AutoConvert.cpp
   Base64.cpp
diff --git a/llvm/lib/Transforms/Instrumentation/AllocToken.cpp b/llvm/lib/Transforms/Instrumentation/AllocToken.cpp
index 40720ae4b39ae..bfda56b1f746d 100644
--- a/llvm/lib/Transforms/Instrumentation/AllocToken.cpp
+++ b/llvm/lib/Transforms/Instrumentation/AllocToken.cpp
@@ -31,10 +31,12 @@
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/Type.h"
+#include "llvm/Support/AllocToken.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Compiler.h"
@@ -53,47 +55,14 @@
 #include <variant>
 
 using namespace llvm;
+using TokenMode = AllocTokenMode;
 
 #define DEBUG_TYPE "alloc-token"
 
 namespace {
 
-//===--- Constants --------------------------------------------------------===//
-
-enum class TokenMode : unsigned {
-  /// Incrementally increasing token ID.
-  Increment = 0,
-
-  /// Simple mode that returns a statically-assigned random token ID.
-  Random = 1,
-
-  /// Token ID based on allocated type hash.
-  TypeHash = 2,
-
-  /// Token ID based on allocated type hash, where the top half ID-space is
-  /// reserved for types that contain pointers and the bottom half for types
-  /// that do not contain pointers.
-  TypeHashPointerSplit = 3,
-};
-
 //===--- Command-line options ---------------------------------------------===//
 
-cl::opt<TokenMode> ClMode(
-    "alloc-token-mode", cl::Hidden, cl::desc("Token assignment mode"),
-    cl::init(TokenMode::TypeHashPointerSplit),
-    cl::values(
-        clEnumValN(TokenMode::Increment, "increment",
-                   "Incrementally increasing token ID"),
-        clEnumValN(TokenMode::Random, "random",
-                   "Statically-assigned random token ID"),
-        clEnumValN(TokenMode::TypeHash, "typehash",
-                   "Token ID based on allocated type hash"),
-        clEnumValN(
-            TokenMode::TypeHashPointerSplit, "typehashpointersplit",
-            "Token ID based on allocated type hash, where the top half "
-            "ID-space is reserved for types that contain pointers and the "
-            "bottom half for types that do not contain pointers. ")));
-
 cl::opt<std::string> ClFuncPrefix("alloc-token-prefix",
                                   cl::desc("The allocation function prefix"),
                                   cl::Hidden, cl::init("__alloc_token_"));
@@ -131,7 +100,7 @@ cl::opt<uint64_t> ClFallbackToken(
 
 //===--- Statistics -------------------------------------------------------===//
 
-STATISTIC(NumFunctionsInstrumented, "Functions instrumented");
+STATISTIC(NumFunctionsModified, "Functions modified");
 STATISTIC(NumAllocationsInstrumented, "Allocations instrumented");
 
 //===----------------------------------------------------------------------===//
@@ -140,9 +109,19 @@ STATISTIC(NumAllocationsInstrumented, "Allocations instrumented");
 ///
 /// Expected format is: !{<type-name>, <contains-pointer>}
 MDNode *getAllocTokenMetadata(const CallBase &CB) {
-  MDNode *Ret = CB.getMetadata(LLVMContext::MD_alloc_token);
-  if (!Ret)
-    return nullptr;
+  MDNode *Ret = nullptr;
+  if (auto *II = dyn_cast<IntrinsicInst>(&CB);
+      II && II->getIntrinsicID() == Intrinsic::alloc_token_id) {
+    auto *MDV = cast<MetadataAsValue>(II->getArgOperand(0));
+    Ret = cast<MDNode>(MDV->getMetadata());
+    // If the intrinsic has an empty MDNode, type inference failed.
+    if (Ret->getNumOperands() == 0)
+      return nullptr;
+  } else {
+    Ret = CB.getMetadata(LLVMContext::MD_alloc_token);
+    if (!Ret)
+      return nullptr;
+  }
   assert(Ret->getNumOperands() == 2 && "bad !alloc_token");
   assert(isa<MDString>(Ret->getOperand(0)));
   assert(isa<ConstantAsMetadata>(Ret->getOperand(1)));
@@ -206,22 +185,20 @@ class TypeHashMode : public ModeBase {
   using ModeBase::ModeBase;
 
   uint64_t operator()(const CallBase &CB, OptimizationRemarkEmitter &ORE) {
-    const auto [N, H] = getHash(CB, ORE);
-    return N ? boundedToken(H) : H;
-  }
 
-protected:
-  std::pair<MDNode *, uint64_t> getHash(const CallBase &CB,
-                                        OptimizationRemarkEmitter &ORE) {
     if (MDNode *N = getAllocTokenMetadata(CB)) {
       MDString *S = cast<MDString>(N->getOperand(0));
-      return {N, getStableSipHash(S->getString())};
+      AllocTokenMetadata Metadata{S->getString(), containsPointer(N)};
+      if (auto Token =
+              getAllocTokenHash(TokenMode::TypeHash, Metadata, MaxTokens))
+        return *Token;
     }
     // Fallback.
     remarkNoMetadata(CB, ORE);
-    return {nullptr, ClFallbackToken};
+    return ClFallbackToken;
   }
 
+protected:
   /// Remark that there was no precise type information.
   static void remarkNoMetadata(const CallBase &CB,
                                OptimizationRemarkEmitter &ORE) {
@@ -242,20 +219,18 @@ class TypeHashPointerSplitMode : public TypeHashMode {
   using TypeHashMode::TypeHashMode;
 
   uint64_t operator()(const CallBase &CB, OptimizationRemarkEmitter &ORE) {
-    if (MaxTokens == 1)
-      return 0;
-    const uint64_t HalfTokens = MaxTokens / 2;
-    const auto [N, H] = getHash(CB, ORE);
-    if (!N) {
-      // Pick the fallback token (ClFallbackToken), which by default is 0,
-      // meaning it'll fall into the pointer-less bucket. Override by setting
-      // -alloc-token-fallback if that is the wrong choice.
-      return H;
+    if (MDNode *N = getAllocTokenMetadata(CB)) {
+      MDString *S = cast<MDString>(N->getOperand(0));
+      AllocTokenMetadata Metadata{S->getString(), containsPointer(N)};
+      if (auto Token = getAllocTokenHash(TokenMode::TypeHashPointerSplit,
+                                         Metadata, MaxTokens))
+        return *Token;
     }
-    uint64_t Hash = H % HalfTokens; // base hash
-    if (containsPointer(N))
-      Hash += HalfTokens;
-    return Hash;
+    // Pick the fallback token (ClFallbackToken), which by default is 0, meaning
+    // it'll fall into the pointer-less bucket. Override by setting
+    // -alloc-token-fallback if that is the wrong choice.
+    remarkNoMetadata(CB, ORE);
+    return ClFallbackToken;
   }
 };
 
@@ -275,7 +250,7 @@ class AllocToken {
       : Options(transformOptionsFromCl(std::move(Opts))), Mod(M),
         FAM(MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager()),
         Mode(IncrementMode(*IntPtrTy, *Options.MaxTokens)) {
-    switch (ClMode.getValue()) {
+    switch (Options.Mode) {
     case TokenMode::Increment:
       break;
     case TokenMode::Random:
@@ -315,6 +290,9 @@ class AllocToken {
   FunctionCallee getTokenAllocFunction(const CallBase &CB, uint64_t TokenID,
                                        LibFunc OriginalFunc);
 
+  /// Lower alloc_token_* intrinsics.
+  void replaceIntrinsicInst(IntrinsicInst *II, OptimizationRemarkEmitter &ORE);
+
   /// Return the token ID from metadata in the call.
   uint64_t getToken(const CallBase &CB, OptimizationRemarkEmitter &ORE) {
     return std::visit([&](auto &&Mode) { return Mode(CB, ORE); }, Mode);
@@ -336,21 +314,32 @@ bool AllocToken::instrumentFunction(Function &F) {
   // Do not apply any instrumentation for naked functions.
   if (F.hasFnAttribute(Attribute::Naked))
     return false;
-  if (F.hasFnAttribute(Attribute::DisableSanitizerInstrumentation))
-    return false;
   // Don't touch available_externally functions, their actual body is elsewhere.
   if (F.getLinkage() == GlobalValue::AvailableExternallyLinkage)
     return false;
-  // Only instrument functions that have the sanitize_alloc_token attribute.
-  if (!F.hasFnAttribute(Attribute::SanitizeAllocToken))
-    return false;
 
   auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
   auto &TLI = FAM.getResult<TargetLibraryAnalysis>(F);
   SmallVector<std::pair<CallBase *, LibFunc>, 4> AllocCalls;
+  SmallVector<IntrinsicInst *, 4> IntrinsicInsts;
+
+  // Only instrument functions that have the sanitize_alloc_token attribute.
+  const bool InstrumentFunction =
+      F.hasFnAttribute(Attribute::SanitizeAllocToken) &&
+      !F.hasFnAttribute(Attribute::DisableSanitizerInstrumentation);
 
   // Collect all allocation calls to avoid iterator invalidation.
   for (Instruction &I : instructions(F)) {
+    // Collect all alloc_token_* intrinsics.
+    if (auto *II = dyn_cast<IntrinsicInst>(&I);
+        II && II->getIntrinsicID() == Intrinsic::alloc_token_id) {
+      IntrinsicInsts.emplace_back(II);
+      continue;
+    }
+
+    if (!InstrumentFunction)
+      continue;
+
     auto *CB = dyn_cast<CallBase>(&I);
     if (!CB)
       continue;
@@ -359,11 +348,22 @@ bool AllocToken::instrumentFunction(Function &F) {
   }
 
   bool Modified = false;
-  for (auto &[CB, Func] : AllocCalls)
-    Modified |= replaceAllocationCall(CB, Func, ORE, TLI);
 
-  if (Modified)
-    NumFunctionsInstrumented++;
+  // Rewrite the collected allocation calls; a call may be left untouched,
+  // so accumulate the per-call result.
+  for (auto &[CB, Func] : AllocCalls)
+    Modified |= replaceAllocationCall(CB, Func, ORE, TLI);
+
+  // Collected alloc_token_id intrinsics are always replaced.
+  for (auto *II : IntrinsicInsts)
+    replaceIntrinsicInst(II, ORE);
+  if (!IntrinsicInsts.empty())
+    Modified = true;
+
+  // Bump the statistic once per function, not once per kind of rewrite.
+  if (Modified)
+    NumFunctionsModified++;
+
   return Modified;
 }
 
@@ -528,6 +528,16 @@ FunctionCallee AllocToken::getTokenAllocFunction(const CallBase &CB,
   return TokenAlloc;
 }
 
+void AllocToken::replaceIntrinsicInst(IntrinsicInst *II,
+                                      OptimizationRemarkEmitter &ORE) {
+  assert(II->getIntrinsicID() == Intrinsic::alloc_token_id);
+  // The result type is overloaded (anyint); build the constant in that type.
+  uint64_t TokenID = getToken(*II, ORE);
+  Value *V = ConstantInt::get(II->getType(), TokenID);
+  II->replaceAllUsesWith(V);
+  II->eraseFromParent();
+}
+
 } // namespace
 
 AllocTokenPass::AllocTokenPass(AllocTokenOptions Opts)
diff --git a/llvm/test/Instrumentation/AllocToken/basic.ll b/llvm/test/Instrumentation/AllocToken/basic.ll
index 099d37df264d6..0c34b1373cfa1 100644
--- a/llvm/test/Instrumentation/AllocToken/basic.ll
+++ b/llvm/test/Instrumentation/AllocToken/basic.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
-; RUN: opt < %s -passes=inferattrs,alloc-token -alloc-token-mode=increment -S | FileCheck %s
+; RUN: opt < %s -passes='inferattrs,alloc-token<mode=increment>' -S | FileCheck %s
 
 target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
 
diff --git a/llvm/test/Instrumentation/AllocToken/basic32.ll b/llvm/test/Instrumentation/AllocToken/basic32.ll
index 944a452f4b4d7..52d1d1446f0ef 100644
--- a/llvm/test/Instrumentation/AllocToken/basic32.ll
+++ b/llvm/test/Instrumentation/AllocToken/basic32.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
-; RUN: opt < %s -passes=inferattrs,alloc-token -alloc-token-mode=increment -S | FileCheck %s
+; RUN: opt < %s -passes='inferattrs,alloc-token<mode=increment>' -S | FileCheck %s
 
 target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-f32:32:32-f64:32:64-v64:64:64-v128:128:128-a0:0:64-f80:32:32-n8:16:32-S128"
 
diff --git a/llvm/test/Instrumentation/AllocToken/fast.ll b/llvm/test/Instrumentation/AllocToken/fast.ll
index 19a3ef6bb9ede..f6bf5ee0b7c97 100644
--- a/llvm/test/Instrumentation/AllocToken/fast.ll
+++ b/llvm/test/Instrumentation/AllocToken/fast.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
-; RUN: opt < %s -passes=inferattrs,alloc-token -alloc-token-mode=increment -alloc-token-fast-abi -alloc-token-max=3 -S | FileCheck %s
+; RUN: opt < %s -passes='inferattrs,alloc-token<mode=increment>' -alloc-token-fast-abi -alloc-token-max=3 -S | FileCheck %s
 
 target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
 
diff --git a/llvm/test/Instrumentation/AllocToken/intrinsic.ll b/llvm/test/Instrumentation/AllocToken/intrinsic.ll
new file mode 100644
index 0000000000000..5c6f2f147b3d0
--- /dev/null
+++ b/llvm/test/Instrumentation/AllocToken/intrinsic.ll
@@ -0,0 +1,32 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; Test that the alloc-token pass lowers the intrinsic to a constant token ID.
+;
+; RUN: opt < %s -passes='alloc-token<mode=typehashpointersplit>' -alloc-token-max=2 -S | FileCheck %s
+
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+declare i64 @llvm.alloc.token.id.i64(metadata)
+
+define i64 @test_intrinsic_lowering() {
+; CHECK-LABEL: define i64 @test_intrinsic_lowering() {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    ret i64 0
+;
+entry:
+  %token_no_ptr = call i64 @llvm.alloc.token.id.i64(metadata !0)
+  ret i64 %token_no_ptr
+}
+
+define i64 @test_intrinsic_lowering_ptr() {
+; CHECK-LABEL: define i64 @test_intrinsic_lowering_ptr() {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    ret i64 1
+;
+entry:
+  %token_with_ptr = call i64 @llvm.alloc.token.id.i64(metadata !1)
+  ret i64 %token_with_ptr
+}
+
+!0 = !{!"NoPointerType", i1 false}
+!1 = !{!"PointerType", i1 true}
diff --git a/llvm/test/Instrumentation/AllocToken/intrinsic32.ll b/llvm/test/Instrumentation/AllocToken/intrinsic32.ll
new file mode 100644
index 0000000000000..15f7c258e2a5d
--- /dev/null
+++ b/llvm/test/Instrumentation/AllocToken/intrinsic32.ll
@@ -0,0 +1,32 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; Test that the alloc-token pass lowers the intrinsic to a constant token ID.
+;
+; RUN: opt < %s -passes='alloc-token<mode=typehashpointersplit>' -alloc-token-max=2 -S | FileCheck %s
+
+target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-f32:32:32-f64:32:64-v64:64:64-v128:128:128-a0:0:64-f80:32:32-n8:16:32-S128"
+target triple = "i386-pc-linux-gnu"
+
+declare i32 @llvm.alloc.token.id.i32(metadata)
+
+define i32 @test_intrinsic_lowering() {
+; CHECK-LABEL: define i32 @test_intrinsic_lowering() {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    ret i32 0
+;
+entry:
+  %token_no_ptr = call i32 @llvm.alloc.token.id.i32(metadata !0)
+  ret i32 %token_no_ptr
+}
+
+define i32 @test_intrinsic_lowering_ptr() {
+; CHECK-LABEL: define i32 @test_intrinsic_lowering_ptr() {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    ret i32 1
+;
+entry:
+  %token_with_ptr = call i32 @llvm.alloc.token.id.i32(metadata !1)
+  ret i32 %token_with_ptr
+}
+
+!0 = !{!"NoPointerType", i1 false}
+!1 = !{!"PointerType", i1 true}
diff --git a/llvm/test/Instrumentation/AllocToken/invoke.ll b/llvm/test/Instrumentation/AllocToken/invoke.ll
index 347c99a2e8f8d..8e7ab3848dc05 100644
--- a/llvm/test/Instrumentation/AllocToken/invoke.ll
+++ b/llvm/test/Instrumentation/AllocToken/invoke.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
-; RUN: opt < %s -passes=inferattrs,alloc-token -alloc-token-mode=increment -S | FileCheck %s
+; RUN: opt < %s -passes='inferattrs,alloc-token<mode=increment>' -S | FileCheck %s
 
 target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
 
diff --git a/llvm/test/Instrumentation/AllocToken/nonlibcalls.ll b/llvm/test/Instrumentation/AllocToken/nonlibcalls.ll
index 19673da1bcfb6..45f573ee7b044 100644
--- a/llvm/test/Instrumentation/AllocToken/nonlibcalls.ll
+++ b/llvm/test/Instrumentation/AllocToken/nonlibcalls.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
-; RUN: opt < %s -passes=inferattrs,alloc-token -alloc-token-mode=increment -alloc-token-extended -S | FileCheck %s
+; RUN: opt < %s -passes='inferattrs,alloc-token<mode=increment>' -alloc-token-extended -S | FileCheck %s
 
 target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
 
diff --git a/llvm/test/Instrumentation/AllocToken/typehashpointersplit.ll b/llvm/test/Instrumentation/AllocToken/typehashpointersplit.ll
index 1f776480c5b3a..4d1be5eac8cd2 100644
--- a/llvm/test/Instrumentation/AllocToken/typehashpointersplit.ll
+++ b/llvm/test/Instrumentation/AllocToken/typehashpointersplit.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
-; RUN: opt < %s -passes=inferattrs,alloc-token -alloc-token-mode=typehashpointersplit -alloc-token-max=2 -S | FileCheck %s
+; RUN: opt < %s -passes='inferattrs,alloc-token<mode=typehashpointersplit>' -alloc-token-max=2 -S | FileCheck %s
 
 target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
 
diff --git a/llvm/utils/gn/secondary/clang/lib/AST/BUILD.gn b/llvm/utils/gn/secondary/clang/lib/AST/BUILD.gn
index 9981d100fd555..4da907cbdd938 100644
--- a/llvm/utils/gn/secondary/clang/lib/AST/BUILD.gn
+++ b/llvm/utils/gn/secondary/clang/lib/AST/BUILD.gn
@@ -121,6 +121,7 @@ static_library("AST") {
     "ExternalASTMerger.cpp",
     "ExternalASTSource.cpp",
     "FormatString.cpp",
+    "InferAlloc.cpp",
     "InheritViz.cpp",
     "ItaniumCXXABI.cpp",
     "ItaniumMangle.cpp",
diff --git a/llvm/utils/gn/secondary/llvm/lib/Support/BUILD.gn b/llvm/utils/gn/secondary/llvm/lib/Support/BUILD.gn
index 38ba4661daacc..df9ddf91f2c49 100644
--- a/llvm/utils/gn/secondary/llvm/lib/Support/BUILD.gn
+++ b/llvm/utils/gn/secondary/llvm/lib/Support/BUILD.gn
@@ -45,6 +45,7 @@ static_library("Support") {
     "ARMAttributeParser.cpp",
     "ARMBuildAttributes.cpp",
     "ARMWinEH.cpp",
+    "AllocToken.cpp",
     "Allocator.cpp",
     "AutoConvert.cpp",
     "BalancedPartitioning.cpp",
>From 99e17cc4cea8f5be8b937cca380ad5982db0ef12 Mon Sep 17 00:00:00 2001
From: Marco Elver <elver at google.com>
Date: Fri, 17 Oct 2025 20:14:39 +0200
Subject: [PATCH 2/2] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20ch?=
 =?UTF-8?q?anges=20introduced=20through=20rebase?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Created using spr 1.3.8-beta.1
[skip ci]
---
 clang/lib/AST/InferAlloc.cpp                  | 35 +++++++++----------
 clang/lib/Frontend/CompilerInvocation.cpp     |  9 +----
 llvm/include/llvm/Support/AllocToken.h        | 12 +++++--
 llvm/lib/Passes/PassBuilder.cpp               |  9 +----
 llvm/lib/Support/AllocToken.cpp               | 35 +++++++++++++------
 .../Transforms/Instrumentation/AllocToken.cpp | 12 +++----
 6 files changed, 57 insertions(+), 55 deletions(-)
diff --git a/clang/lib/AST/InferAlloc.cpp b/clang/lib/AST/InferAlloc.cpp
index c21fcfccaef0f..3ec55c26ac366 100644
--- a/clang/lib/AST/InferAlloc.cpp
+++ b/clang/lib/AST/InferAlloc.cpp
@@ -19,11 +19,13 @@
 #include "clang/Basic/IdentifierTable.h"
 #include "llvm/ADT/SmallPtrSet.h"
 
-namespace clang {
-namespace {
-bool typeContainsPointer(QualType T,
-                         llvm::SmallPtrSet<const RecordDecl *, 4> &VisitedRD,
-                         bool &IncompleteType) {
+using namespace clang;
+using namespace infer_alloc;
+
+static bool
+typeContainsPointer(QualType T,
+                    llvm::SmallPtrSet<const RecordDecl *, 4> &VisitedRD,
+                    bool &IncompleteType) {
   QualType CanonicalType = T.getCanonicalType();
   if (CanonicalType->isPointerType())
     return true; // base case
@@ -70,7 +72,7 @@ bool typeContainsPointer(QualType T,
 }
 
 /// Infer type from a simple sizeof expression.
-QualType inferTypeFromSizeofExpr(const Expr *E) {
+static QualType inferTypeFromSizeofExpr(const Expr *E) {
   const Expr *Arg = E->IgnoreParenImpCasts();
   if (const auto *UET = dyn_cast<UnaryExprOrTypeTraitExpr>(Arg)) {
     if (UET->getKind() == UETT_SizeOf) {
@@ -96,7 +98,7 @@ QualType inferTypeFromSizeofExpr(const Expr *E) {
 ///
 ///   malloc(sizeof(HasFlexArray) + sizeof(int) * 32);  // infers 'HasFlexArray'
 ///
-QualType inferPossibleTypeFromArithSizeofExpr(const Expr *E) {
+static QualType inferPossibleTypeFromArithSizeofExpr(const Expr *E) {
   const Expr *Arg = E->IgnoreParenImpCasts();
   // The argument is a lone sizeof expression.
   if (QualType T = inferTypeFromSizeofExpr(Arg); !T.isNull())
@@ -132,7 +134,7 @@ QualType inferPossibleTypeFromArithSizeofExpr(const Expr *E) {
 ///   size_t my_size = sizeof(MyType);
 ///   void *x = malloc(my_size);  // infers 'MyType'
 ///
-QualType inferPossibleTypeFromVarInitSizeofExpr(const Expr *E) {
+static QualType inferPossibleTypeFromVarInitSizeofExpr(const Expr *E) {
   const Expr *Arg = E->IgnoreParenImpCasts();
   if (const auto *DRE = dyn_cast<DeclRefExpr>(Arg)) {
     if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
@@ -148,8 +150,8 @@ QualType inferPossibleTypeFromVarInitSizeofExpr(const Expr *E) {
 ///
 ///   MyType *x = (MyType *)malloc(4096);  // infers 'MyType'
 ///
-QualType inferPossibleTypeFromCastExpr(const CallExpr *CallE,
-                                       const CastExpr *CastE) {
+static QualType inferPossibleTypeFromCastExpr(const CallExpr *CallE,
+                                              const CastExpr *CastE) {
   if (!CastE)
     return QualType();
   QualType PtrType = CastE->getType();
@@ -157,12 +159,10 @@ QualType inferPossibleTypeFromCastExpr(const CallExpr *CallE,
     return PtrType->getPointeeType();
   return QualType();
 }
-} // anonymous namespace
-
-namespace infer_alloc {
 
-QualType inferPossibleType(const CallExpr *E, const ASTContext &Ctx,
-                           const CastExpr *CastE) {
+QualType clang::infer_alloc::inferPossibleType(const CallExpr *E,
+                                               const ASTContext &Ctx,
+                                               const CastExpr *CastE) {
   QualType AllocType;
   // First check arguments.
   for (const Expr *Arg : E->arguments()) {
@@ -179,7 +179,7 @@ QualType inferPossibleType(const CallExpr *E, const ASTContext &Ctx,
 }
 
 std::optional<llvm::AllocTokenMetadata>
-getAllocTokenMetadata(QualType T, const ASTContext &Ctx) {
+clang::infer_alloc::getAllocTokenMetadata(QualType T, const ASTContext &Ctx) {
   llvm::AllocTokenMetadata ATMD;
 
   // Get unique type name.
@@ -199,6 +199,3 @@ getAllocTokenMetadata(QualType T, const ASTContext &Ctx) {
 
   return ATMD;
 }
-
-} // namespace infer_alloc
-} // namespace clang
diff --git a/clang/lib/Frontend/CompilerInvocation.cpp b/clang/lib/Frontend/CompilerInvocation.cpp
index 9ce1df728336e..85ba85099500a 100644
--- a/clang/lib/Frontend/CompilerInvocation.cpp
+++ b/clang/lib/Frontend/CompilerInvocation.cpp
@@ -4565,14 +4565,7 @@ bool CompilerInvocation::ParseLangArgs(LangOptions &Opts, ArgList &Args,
 
   if (const auto *Arg = Args.getLastArg(options::OPT_falloc_token_mode_EQ)) {
     StringRef S = Arg->getValue();
-    auto Mode = llvm::StringSwitch<std::optional<llvm::AllocTokenMode>>(S)
-                    .Case("increment", llvm::AllocTokenMode::Increment)
-                    .Case("random", llvm::AllocTokenMode::Random)
-                    .Case("typehash", llvm::AllocTokenMode::TypeHash)
-                    .Case("typehashpointersplit",
-                          llvm::AllocTokenMode::TypeHashPointerSplit)
-                    .Default(std::nullopt);
-    if (Mode)
+    if (auto Mode = getAllocTokenModeFromString(S))
       Opts.AllocTokenMode = Mode;
     else
       Diags.Report(diag::err_drv_invalid_value) << Arg->getAsString(Args) << S;
diff --git a/llvm/include/llvm/Support/AllocToken.h b/llvm/include/llvm/Support/AllocToken.h
index 48db026957443..e40d8163a9d7c 100644
--- a/llvm/include/llvm/Support/AllocToken.h
+++ b/llvm/include/llvm/Support/AllocToken.h
@@ -14,6 +14,7 @@
 #define LLVM_SUPPORT_ALLOCTOKEN_H
 
 #include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/StringRef.h"
 #include <cstdint>
 #include <optional>
 
@@ -40,6 +41,11 @@ enum class AllocTokenMode {
 inline constexpr AllocTokenMode DefaultAllocTokenMode =
     AllocTokenMode::TypeHashPointerSplit;
 
+/// Returns the AllocTokenMode matching the canonical string name \p Name, or
+/// std::nullopt if the name is not a recognized mode.
+LLVM_ABI std::optional<AllocTokenMode>
+getAllocTokenModeFromString(StringRef Name);
+
 /// Metadata about an allocation used to generate a token ID.
 struct AllocTokenMetadata {
   SmallString<64> TypeName;
@@ -53,9 +59,9 @@ struct AllocTokenMetadata {
 /// \param Metadata The metadata about the allocation.
 /// \param MaxTokens The maximum number of tokens (must not be 0)
 /// \return The calculated allocation token ID, or std::nullopt.
-std::optional<uint64_t> getAllocTokenHash(AllocTokenMode Mode,
-                                          const AllocTokenMetadata &Metadata,
-                                          uint64_t MaxTokens);
+LLVM_ABI std::optional<uint64_t>
+getAllocToken(AllocTokenMode Mode, const AllocTokenMetadata &Metadata,
+              uint64_t MaxTokens);
 
 } // end namespace llvm
 
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index c3522a38eb9e1..4cebb0bb32e4e 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -1102,14 +1102,7 @@ Expected<AllocTokenOptions> parseAllocTokenPassOptions(StringRef Params) {
     std::tie(ParamName, Params) = Params.split(';');
 
     if (ParamName.consume_front("mode=")) {
-      auto Mode = StringSwitch<std::optional<AllocTokenMode>>(ParamName)
-                      .Case("increment", AllocTokenMode::Increment)
-                      .Case("random", AllocTokenMode::Random)
-                      .Case("typehash", AllocTokenMode::TypeHash)
-                      .Case("typehashpointersplit",
-                            AllocTokenMode::TypeHashPointerSplit)
-                      .Default(std::nullopt);
-      if (Mode)
+      if (auto Mode = getAllocTokenModeFromString(ParamName))
         Result.Mode = *Mode;
       else
         return make_error<StringError>(
diff --git a/llvm/lib/Support/AllocToken.cpp b/llvm/lib/Support/AllocToken.cpp
index 6c6f80ac4997c..8e9e89f0df353 100644
--- a/llvm/lib/Support/AllocToken.cpp
+++ b/llvm/lib/Support/AllocToken.cpp
@@ -11,14 +11,31 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Support/AllocToken.h"
+#include "llvm/ADT/StringSwitch.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/SipHash.h"
 
-namespace llvm {
-std::optional<uint64_t> getAllocTokenHash(AllocTokenMode Mode,
-                                          const AllocTokenMetadata &Metadata,
-                                          uint64_t MaxTokens) {
-  assert(MaxTokens && "Must provide concrete max tokens");
+using namespace llvm;
+
+std::optional<AllocTokenMode>
+llvm::getAllocTokenModeFromString(StringRef Name) {
+  return StringSwitch<std::optional<AllocTokenMode>>(Name)
+      .Case("increment", AllocTokenMode::Increment)
+      .Case("random", AllocTokenMode::Random)
+      .Case("typehash", AllocTokenMode::TypeHash)
+      .Case("typehashpointersplit", AllocTokenMode::TypeHashPointerSplit)
+      .Default(std::nullopt);
+}
+
+static uint64_t getStableHash(const AllocTokenMetadata &Metadata,
+                              uint64_t MaxTokens) {
+  return getStableSipHash(Metadata.TypeName) % MaxTokens;
+}
+
+std::optional<uint64_t> llvm::getAllocToken(AllocTokenMode Mode,
+                                            const AllocTokenMetadata &Metadata,
+                                            uint64_t MaxTokens) {
+  assert(MaxTokens && "Must provide non-zero max tokens");
 
   switch (Mode) {
   case AllocTokenMode::Increment:
@@ -26,15 +43,14 @@ std::optional<uint64_t> getAllocTokenHash(AllocTokenMode Mode,
     // Stateful modes cannot be implemented as a pure function.
     return std::nullopt;
 
-  case AllocTokenMode::TypeHash: {
-    return getStableSipHash(Metadata.TypeName) % MaxTokens;
-  }
+  case AllocTokenMode::TypeHash:
+    return getStableHash(Metadata, MaxTokens);
 
   case AllocTokenMode::TypeHashPointerSplit: {
     if (MaxTokens == 1)
       return 0;
     const uint64_t HalfTokens = MaxTokens / 2;
-    uint64_t Hash = getStableSipHash(Metadata.TypeName) % HalfTokens;
+    uint64_t Hash = getStableHash(Metadata, HalfTokens);
     if (Metadata.ContainsPointer)
       Hash += HalfTokens;
     return Hash;
@@ -43,4 +59,3 @@ std::optional<uint64_t> getAllocTokenHash(AllocTokenMode Mode,
 
   llvm_unreachable("");
 }
-} // namespace llvm
diff --git a/llvm/lib/Transforms/Instrumentation/AllocToken.cpp b/llvm/lib/Transforms/Instrumentation/AllocToken.cpp
index bfda56b1f746d..8181e4ef1d74f 100644
--- a/llvm/lib/Transforms/Instrumentation/AllocToken.cpp
+++ b/llvm/lib/Transforms/Instrumentation/AllocToken.cpp
@@ -189,8 +189,7 @@ class TypeHashMode : public ModeBase {
     if (MDNode *N = getAllocTokenMetadata(CB)) {
       MDString *S = cast<MDString>(N->getOperand(0));
       AllocTokenMetadata Metadata{S->getString(), containsPointer(N)};
-      if (auto Token =
-              getAllocTokenHash(TokenMode::TypeHash, Metadata, MaxTokens))
+      if (auto Token = getAllocToken(TokenMode::TypeHash, Metadata, MaxTokens))
         return *Token;
     }
     // Fallback.
@@ -222,8 +221,8 @@ class TypeHashPointerSplitMode : public TypeHashMode {
     if (MDNode *N = getAllocTokenMetadata(CB)) {
       MDString *S = cast<MDString>(N->getOperand(0));
       AllocTokenMetadata Metadata{S->getString(), containsPointer(N)};
-      if (auto Token = getAllocTokenHash(TokenMode::TypeHashPointerSplit,
-                                         Metadata, MaxTokens))
+      if (auto Token = getAllocToken(TokenMode::TypeHashPointerSplit, Metadata,
+                                     MaxTokens))
         return *Token;
     }
     // Pick the fallback token (ClFallbackToken), which by default is 0, meaning
@@ -357,9 +356,8 @@ bool AllocToken::instrumentFunction(Function &F) {
   }
 
   if (!IntrinsicInsts.empty()) {
-    for (auto *II : IntrinsicInsts) {
+    for (auto *II : IntrinsicInsts)
       replaceIntrinsicInst(II, ORE);
-    }
     Modified = true;
     NumFunctionsModified++;
   }
@@ -381,7 +379,7 @@ AllocToken::shouldInstrumentCall(const CallBase &CB,
   if (TLI.getLibFunc(*Callee, Func)) {
     if (isInstrumentableLibFunc(Func, CB, TLI))
       return Func;
-  } else if (Options.Extended && getAllocTokenMetadata(CB)) {
+  } else if (Options.Extended && CB.getMetadata(LLVMContext::MD_alloc_token)) {
     return NotLibFunc;
   }
 
    
    
More information about the llvm-commits
mailing list