[llvm] [AllocToken] Make token mode a pass parameter (PR #163634)

Marco Elver via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 22 01:30:43 PDT 2025


https://github.com/melver updated https://github.com/llvm/llvm-project/pull/163634

>From 9e0a05dcae9f32c40ebd4f34f10b552438320e46 Mon Sep 17 00:00:00 2001
From: Marco Elver <elver at google.com>
Date: Wed, 15 Oct 2025 23:28:44 +0200
Subject: [PATCH 1/3] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20ch?=
 =?UTF-8?q?anges=20to=20main=20this=20commit=20is=20based=20on?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Created using spr 1.3.8-beta.1

[skip ci]
---
 llvm/include/llvm/IR/Intrinsics.td            |   8 ++
 llvm/include/llvm/Support/AllocToken.h        |  58 ++++++++
 llvm/lib/Support/AllocToken.cpp               |  46 +++++++
 llvm/lib/Support/CMakeLists.txt               |   1 +
 .../Transforms/Instrumentation/AllocToken.cpp | 130 +++++++++++-------
 .../Instrumentation/AllocToken/intrinsic.ll   |  32 +++++
 .../Instrumentation/AllocToken/intrinsic32.ll |  32 +++++
 7 files changed, 255 insertions(+), 52 deletions(-)
 create mode 100644 llvm/include/llvm/Support/AllocToken.h
 create mode 100644 llvm/lib/Support/AllocToken.cpp
 create mode 100644 llvm/test/Instrumentation/AllocToken/intrinsic.ll
 create mode 100644 llvm/test/Instrumentation/AllocToken/intrinsic32.ll

diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 8856eda250ed6..0c13b059c4cd0 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -2853,7 +2853,15 @@ def int_ptrauth_blend :
 def int_ptrauth_sign_generic :
   DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i64_ty], [IntrNoMem]>;
 
+//===----------------- AllocToken Intrinsics ------------------------------===//
+
+// Return the token ID for the given !alloc_token metadata.
+def int_alloc_token_id :
+  DefaultAttrsIntrinsic<[llvm_anyint_ty], [llvm_metadata_ty],
+                        [IntrNoMem, NoUndef<RetIndex>]>;
+
 //===----------------------------------------------------------------------===//
+
 //===------- Convergence Intrinsics ---------------------------------------===//
 
 def int_experimental_convergence_entry
diff --git a/llvm/include/llvm/Support/AllocToken.h b/llvm/include/llvm/Support/AllocToken.h
new file mode 100644
index 0000000000000..6617b7d1f7668
--- /dev/null
+++ b/llvm/include/llvm/Support/AllocToken.h
@@ -0,0 +1,58 @@
+//===- llvm/Support/AllocToken.h - Allocation Token Calculation -----*- C++ -*//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Definition of AllocToken modes and shared calculation of stateless token IDs.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_SUPPORT_ALLOCTOKEN_H
+#define LLVM_SUPPORT_ALLOCTOKEN_H
+
+#include "llvm/ADT/SmallString.h"
+#include <cstdint>
+#include <optional>
+
+namespace llvm {
+
+/// Modes for generating allocation token IDs.
+enum class AllocTokenMode {
+  /// Incrementally increasing token ID.
+  Increment,
+
+  /// Simple mode that returns a statically-assigned random token ID.
+  Random,
+
+  /// Token ID based on allocated type hash.
+  TypeHash,
+
+  /// Token ID based on allocated type hash, where the top half ID-space is
+  /// reserved for types that contain pointers and the bottom half for types
+  /// that do not contain pointers.
+  TypeHashPointerSplit,
+};
+
+/// Metadata about an allocation used to generate a token ID.
+struct AllocTokenMetadata {
+  SmallString<64> TypeName;     ///< Allocated type's name (hashed).
+  bool ContainsPointer = false; ///< True if the type contains pointers.
+};
+
+/// Calculates stable allocation token ID. Returns std::nullopt for stateful
+/// modes that are only available in the AllocToken pass.
+///
+/// \param Mode The token generation mode.
+/// \param Metadata The metadata about the allocation.
+/// \param MaxTokens The maximum number of tokens (must not be 0)
+/// \return The calculated allocation token ID, or std::nullopt.
+std::optional<uint64_t> getAllocTokenHash(AllocTokenMode Mode,
+                                          const AllocTokenMetadata &Metadata,
+                                          uint64_t MaxTokens);
+
+} // end namespace llvm
+
+#endif // LLVM_SUPPORT_ALLOCTOKEN_H
diff --git a/llvm/lib/Support/AllocToken.cpp b/llvm/lib/Support/AllocToken.cpp
new file mode 100644
index 0000000000000..6c6f80ac4997c
--- /dev/null
+++ b/llvm/lib/Support/AllocToken.cpp
@@ -0,0 +1,46 @@
+//===- AllocToken.cpp - Allocation Token Calculation ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Definition of AllocToken modes and shared calculation of stateless token IDs.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Support/AllocToken.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/SipHash.h"
+
+namespace llvm {
+std::optional<uint64_t> getAllocTokenHash(AllocTokenMode Mode,
+                                          const AllocTokenMetadata &Metadata,
+                                          uint64_t MaxTokens) {
+  assert(MaxTokens && "Must provide concrete max tokens");
+
+  switch (Mode) {
+  case AllocTokenMode::Increment:
+  case AllocTokenMode::Random:
+    // Stateful modes cannot be implemented as a pure function.
+    return std::nullopt;
+
+  case AllocTokenMode::TypeHash: {
+    return getStableSipHash(Metadata.TypeName) % MaxTokens;
+  }
+
+  case AllocTokenMode::TypeHashPointerSplit: {
+    if (MaxTokens == 1)
+      return 0;
+    const uint64_t HalfTokens = MaxTokens / 2;
+    uint64_t Hash = getStableSipHash(Metadata.TypeName) % HalfTokens;
+    if (Metadata.ContainsPointer)
+      Hash += HalfTokens;
+    return Hash;
+  }
+  }
+
+  llvm_unreachable("");
+}
+} // namespace llvm
diff --git a/llvm/lib/Support/CMakeLists.txt b/llvm/lib/Support/CMakeLists.txt
index 42b21b5e62029..f06b460a0113d 100644
--- a/llvm/lib/Support/CMakeLists.txt
+++ b/llvm/lib/Support/CMakeLists.txt
@@ -147,6 +147,7 @@ add_llvm_component_library(LLVMSupport
   ARMBuildAttributes.cpp
   AArch64AttributeParser.cpp
   AArch64BuildAttributes.cpp
+  AllocToken.cpp
   ARMAttributeParser.cpp
   ARMWinEH.cpp
   Allocator.cpp
diff --git a/llvm/lib/Transforms/Instrumentation/AllocToken.cpp b/llvm/lib/Transforms/Instrumentation/AllocToken.cpp
index 40720ae4b39ae..7c488ec96120e 100644
--- a/llvm/lib/Transforms/Instrumentation/AllocToken.cpp
+++ b/llvm/lib/Transforms/Instrumentation/AllocToken.cpp
@@ -31,10 +31,12 @@
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/Type.h"
+#include "llvm/Support/AllocToken.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Compiler.h"
@@ -53,29 +55,12 @@
 #include <variant>
 
 using namespace llvm;
+using TokenMode = AllocTokenMode;
 
 #define DEBUG_TYPE "alloc-token"
 
 namespace {
 
-//===--- Constants --------------------------------------------------------===//
-
-enum class TokenMode : unsigned {
-  /// Incrementally increasing token ID.
-  Increment = 0,
-
-  /// Simple mode that returns a statically-assigned random token ID.
-  Random = 1,
-
-  /// Token ID based on allocated type hash.
-  TypeHash = 2,
-
-  /// Token ID based on allocated type hash, where the top half ID-space is
-  /// reserved for types that contain pointers and the bottom half for types
-  /// that do not contain pointers.
-  TypeHashPointerSplit = 3,
-};
-
 //===--- Command-line options ---------------------------------------------===//
 
 cl::opt<TokenMode> ClMode(
@@ -131,7 +116,7 @@ cl::opt<uint64_t> ClFallbackToken(
 
 //===--- Statistics -------------------------------------------------------===//
 
-STATISTIC(NumFunctionsInstrumented, "Functions instrumented");
+STATISTIC(NumFunctionsModified, "Functions modified");
 STATISTIC(NumAllocationsInstrumented, "Allocations instrumented");
 
 //===----------------------------------------------------------------------===//
@@ -140,9 +125,19 @@ STATISTIC(NumAllocationsInstrumented, "Allocations instrumented");
 ///
 /// Expected format is: !{<type-name>, <contains-pointer>}
 MDNode *getAllocTokenMetadata(const CallBase &CB) {
-  MDNode *Ret = CB.getMetadata(LLVMContext::MD_alloc_token);
-  if (!Ret)
-    return nullptr;
+  MDNode *Ret = nullptr;
+  if (auto *II = dyn_cast<IntrinsicInst>(&CB);
+      II && II->getIntrinsicID() == Intrinsic::alloc_token_id) {
+    auto *MDV = cast<MetadataAsValue>(II->getArgOperand(0));
+    Ret = cast<MDNode>(MDV->getMetadata());
+    // If the intrinsic has an empty MDNode, type inference failed.
+    if (Ret->getNumOperands() == 0)
+      return nullptr;
+  } else {
+    Ret = CB.getMetadata(LLVMContext::MD_alloc_token);
+    if (!Ret)
+      return nullptr;
+  }
   assert(Ret->getNumOperands() == 2 && "bad !alloc_token");
   assert(isa<MDString>(Ret->getOperand(0)));
   assert(isa<ConstantAsMetadata>(Ret->getOperand(1)));
@@ -206,22 +201,20 @@ class TypeHashMode : public ModeBase {
   using ModeBase::ModeBase;
 
   uint64_t operator()(const CallBase &CB, OptimizationRemarkEmitter &ORE) {
-    const auto [N, H] = getHash(CB, ORE);
-    return N ? boundedToken(H) : H;
-  }
 
-protected:
-  std::pair<MDNode *, uint64_t> getHash(const CallBase &CB,
-                                        OptimizationRemarkEmitter &ORE) {
     if (MDNode *N = getAllocTokenMetadata(CB)) {
       MDString *S = cast<MDString>(N->getOperand(0));
-      return {N, getStableSipHash(S->getString())};
+      AllocTokenMetadata Metadata{S->getString(), containsPointer(N)};
+      if (auto Token =
+              getAllocTokenHash(TokenMode::TypeHash, Metadata, MaxTokens))
+        return *Token;
     }
     // Fallback.
     remarkNoMetadata(CB, ORE);
-    return {nullptr, ClFallbackToken};
+    return ClFallbackToken;
   }
 
+protected:
   /// Remark that there was no precise type information.
   static void remarkNoMetadata(const CallBase &CB,
                                OptimizationRemarkEmitter &ORE) {
@@ -242,20 +235,18 @@ class TypeHashPointerSplitMode : public TypeHashMode {
   using TypeHashMode::TypeHashMode;
 
   uint64_t operator()(const CallBase &CB, OptimizationRemarkEmitter &ORE) {
-    if (MaxTokens == 1)
-      return 0;
-    const uint64_t HalfTokens = MaxTokens / 2;
-    const auto [N, H] = getHash(CB, ORE);
-    if (!N) {
-      // Pick the fallback token (ClFallbackToken), which by default is 0,
-      // meaning it'll fall into the pointer-less bucket. Override by setting
-      // -alloc-token-fallback if that is the wrong choice.
-      return H;
+    if (MDNode *N = getAllocTokenMetadata(CB)) {
+      MDString *S = cast<MDString>(N->getOperand(0));
+      AllocTokenMetadata Metadata{S->getString(), containsPointer(N)};
+      if (auto Token = getAllocTokenHash(TokenMode::TypeHashPointerSplit,
+                                         Metadata, MaxTokens))
+        return *Token;
     }
-    uint64_t Hash = H % HalfTokens; // base hash
-    if (containsPointer(N))
-      Hash += HalfTokens;
-    return Hash;
+    // Pick the fallback token (ClFallbackToken), which by default is 0, meaning
+    // it'll fall into the pointer-less bucket. Override by setting
+    // -alloc-token-fallback if that is the wrong choice.
+    remarkNoMetadata(CB, ORE);
+    return ClFallbackToken;
   }
 };
 
@@ -315,6 +306,9 @@ class AllocToken {
   FunctionCallee getTokenAllocFunction(const CallBase &CB, uint64_t TokenID,
                                        LibFunc OriginalFunc);
 
+  /// Lower alloc_token_* intrinsics.
+  void replaceIntrinsicInst(IntrinsicInst *II, OptimizationRemarkEmitter &ORE);
+
   /// Return the token ID from metadata in the call.
   uint64_t getToken(const CallBase &CB, OptimizationRemarkEmitter &ORE) {
     return std::visit([&](auto &&Mode) { return Mode(CB, ORE); }, Mode);
@@ -336,21 +330,32 @@ bool AllocToken::instrumentFunction(Function &F) {
   // Do not apply any instrumentation for naked functions.
   if (F.hasFnAttribute(Attribute::Naked))
     return false;
-  if (F.hasFnAttribute(Attribute::DisableSanitizerInstrumentation))
-    return false;
   // Don't touch available_externally functions, their actual body is elsewhere.
   if (F.getLinkage() == GlobalValue::AvailableExternallyLinkage)
     return false;
-  // Only instrument functions that have the sanitize_alloc_token attribute.
-  if (!F.hasFnAttribute(Attribute::SanitizeAllocToken))
-    return false;
 
   auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
   auto &TLI = FAM.getResult<TargetLibraryAnalysis>(F);
   SmallVector<std::pair<CallBase *, LibFunc>, 4> AllocCalls;
+  SmallVector<IntrinsicInst *, 4> IntrinsicInsts;
+
+  // Only instrument functions that have the sanitize_alloc_token attribute.
+  const bool InstrumentFunction =
+      F.hasFnAttribute(Attribute::SanitizeAllocToken) &&
+      !F.hasFnAttribute(Attribute::DisableSanitizerInstrumentation);
 
   // Collect all allocation calls to avoid iterator invalidation.
   for (Instruction &I : instructions(F)) {
+    // Collect all alloc_token_* intrinsics.
+    if (auto *II = dyn_cast<IntrinsicInst>(&I);
+        II && II->getIntrinsicID() == Intrinsic::alloc_token_id) {
+      IntrinsicInsts.emplace_back(II);
+      continue;
+    }
+
+    if (!InstrumentFunction)
+      continue;
+
     auto *CB = dyn_cast<CallBase>(&I);
     if (!CB)
       continue;
@@ -359,11 +364,22 @@ bool AllocToken::instrumentFunction(Function &F) {
   }
 
   bool Modified = false;
-  for (auto &[CB, Func] : AllocCalls)
-    Modified |= replaceAllocationCall(CB, Func, ORE, TLI);
 
-  if (Modified)
-    NumFunctionsInstrumented++;
+  if (!AllocCalls.empty()) {
+    for (auto &[CB, Func] : AllocCalls)
+      Modified |= replaceAllocationCall(CB, Func, ORE, TLI);
+    if (Modified && IntrinsicInsts.empty()) // Else counted below.
+      NumFunctionsModified++;
+  }
+
+  if (!IntrinsicInsts.empty()) {
+    for (auto *II : IntrinsicInsts) {
+      replaceIntrinsicInst(II, ORE);
+    }
+    Modified = true;
+    NumFunctionsModified++;
+  }
+
   return Modified;
 }
 
@@ -528,6 +544,16 @@ FunctionCallee AllocToken::getTokenAllocFunction(const CallBase &CB,
   return TokenAlloc;
 }
 
+void AllocToken::replaceIntrinsicInst(IntrinsicInst *II,
+                                      OptimizationRemarkEmitter &ORE) {
+  assert(II->getIntrinsicID() == Intrinsic::alloc_token_id);
+
+  uint64_t TokenID = getToken(*II, ORE);
+  Value *V = ConstantInt::get(II->getType(), TokenID); // Overloaded int type.
+  II->replaceAllUsesWith(V);
+  II->eraseFromParent();
+}
+
 } // namespace
 
 AllocTokenPass::AllocTokenPass(AllocTokenOptions Opts)
diff --git a/llvm/test/Instrumentation/AllocToken/intrinsic.ll b/llvm/test/Instrumentation/AllocToken/intrinsic.ll
new file mode 100644
index 0000000000000..13aaa90008a7c
--- /dev/null
+++ b/llvm/test/Instrumentation/AllocToken/intrinsic.ll
@@ -0,0 +1,32 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; Test that the alloc-token pass lowers the intrinsic to a constant token ID.
+;
+; RUN: opt < %s -passes=alloc-token -alloc-token-mode=typehashpointersplit -alloc-token-max=2 -S | FileCheck %s
+
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+declare i64 @llvm.alloc.token.id.i64(metadata)
+
+define i64 @test_intrinsic_lowering() {
+; CHECK-LABEL: define i64 @test_intrinsic_lowering() {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    ret i64 0
+;
+entry:
+  %token_no_ptr = call i64 @llvm.alloc.token.id.i64(metadata !0)
+  ret i64 %token_no_ptr
+}
+
+define i64 @test_intrinsic_lowering_ptr() {
+; CHECK-LABEL: define i64 @test_intrinsic_lowering_ptr() {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    ret i64 1
+;
+entry:
+  %token_with_ptr = call i64 @llvm.alloc.token.id.i64(metadata !1)
+  ret i64 %token_with_ptr
+}
+
+!0 = !{!"NoPointerType", i1 false}
+!1 = !{!"PointerType", i1 true}
diff --git a/llvm/test/Instrumentation/AllocToken/intrinsic32.ll b/llvm/test/Instrumentation/AllocToken/intrinsic32.ll
new file mode 100644
index 0000000000000..eb5dbbe91a83e
--- /dev/null
+++ b/llvm/test/Instrumentation/AllocToken/intrinsic32.ll
@@ -0,0 +1,32 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; Test that the alloc-token pass lowers the intrinsic to a constant token ID.
+;
+; RUN: opt < %s -passes=alloc-token -alloc-token-mode=typehashpointersplit -alloc-token-max=2 -S | FileCheck %s
+
+target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-f32:32:32-f64:32:64-v64:64:64-v128:128:128-a0:0:64-f80:32:32-n8:16:32-S128"
+target triple = "i386-pc-linux-gnu"
+
+declare i32 @llvm.alloc.token.id.i32(metadata)
+
+define i32 @test_intrinsic_lowering() {
+; CHECK-LABEL: define i32 @test_intrinsic_lowering() {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    ret i32 0
+;
+entry:
+  %token_no_ptr = call i32 @llvm.alloc.token.id.i32(metadata !0)
+  ret i32 %token_no_ptr
+}
+
+define i32 @test_intrinsic_lowering_ptr() {
+; CHECK-LABEL: define i32 @test_intrinsic_lowering_ptr() {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    ret i32 1
+;
+entry:
+  %token_with_ptr = call i32 @llvm.alloc.token.id.i32(metadata !1)
+  ret i32 %token_with_ptr
+}
+
+!0 = !{!"NoPointerType", i1 false}
+!1 = !{!"PointerType", i1 true}

>From 49f212489565183cfc8e234fa265ca79ff61948a Mon Sep 17 00:00:00 2001
From: Marco Elver <elver at google.com>
Date: Wed, 15 Oct 2025 23:46:51 +0200
Subject: [PATCH 2/3] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20ch?=
 =?UTF-8?q?anges=20introduced=20through=20rebase?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Created using spr 1.3.8-beta.1

[skip ci]
---
 llvm/lib/Support/CMakeLists.txt                   | 2 +-
 llvm/utils/gn/secondary/llvm/lib/Support/BUILD.gn | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Support/CMakeLists.txt b/llvm/lib/Support/CMakeLists.txt
index f06b460a0113d..671a5fe941cef 100644
--- a/llvm/lib/Support/CMakeLists.txt
+++ b/llvm/lib/Support/CMakeLists.txt
@@ -147,9 +147,9 @@ add_llvm_component_library(LLVMSupport
   ARMBuildAttributes.cpp
   AArch64AttributeParser.cpp
   AArch64BuildAttributes.cpp
-  AllocToken.cpp
   ARMAttributeParser.cpp
   ARMWinEH.cpp
+  AllocToken.cpp
   Allocator.cpp
   AutoConvert.cpp
   Base64.cpp
diff --git a/llvm/utils/gn/secondary/llvm/lib/Support/BUILD.gn b/llvm/utils/gn/secondary/llvm/lib/Support/BUILD.gn
index 38ba4661daacc..df9ddf91f2c49 100644
--- a/llvm/utils/gn/secondary/llvm/lib/Support/BUILD.gn
+++ b/llvm/utils/gn/secondary/llvm/lib/Support/BUILD.gn
@@ -45,6 +45,7 @@ static_library("Support") {
     "ARMAttributeParser.cpp",
     "ARMBuildAttributes.cpp",
     "ARMWinEH.cpp",
+    "AllocToken.cpp",
     "Allocator.cpp",
     "AutoConvert.cpp",
     "BalancedPartitioning.cpp",

>From 165cbbcb26651c527f0924f8f7d419cdfe4fe3fa Mon Sep 17 00:00:00 2001
From: Marco Elver <elver at google.com>
Date: Fri, 17 Oct 2025 20:14:21 +0200
Subject: [PATCH 3/3] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20ch?=
 =?UTF-8?q?anges=20introduced=20through=20rebase?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Created using spr 1.3.8-beta.1

[skip ci]
---
 llvm/include/llvm/Support/AllocToken.h        |  6 ++---
 llvm/lib/Support/AllocToken.cpp               | 24 +++++++++++--------
 .../Transforms/Instrumentation/AllocToken.cpp | 12 ++++------
 3 files changed, 22 insertions(+), 20 deletions(-)

diff --git a/llvm/include/llvm/Support/AllocToken.h b/llvm/include/llvm/Support/AllocToken.h
index 6617b7d1f7668..8d82670eaeb8d 100644
--- a/llvm/include/llvm/Support/AllocToken.h
+++ b/llvm/include/llvm/Support/AllocToken.h
@@ -49,9 +49,9 @@ struct AllocTokenMetadata {
 /// \param Metadata The metadata about the allocation.
 /// \param MaxTokens The maximum number of tokens (must not be 0)
 /// \return The calculated allocation token ID, or std::nullopt.
-std::optional<uint64_t> getAllocTokenHash(AllocTokenMode Mode,
-                                          const AllocTokenMetadata &Metadata,
-                                          uint64_t MaxTokens);
+LLVM_ABI std::optional<uint64_t>
+getAllocToken(AllocTokenMode Mode, const AllocTokenMetadata &Metadata,
+              uint64_t MaxTokens);
 
 } // end namespace llvm
 
diff --git a/llvm/lib/Support/AllocToken.cpp b/llvm/lib/Support/AllocToken.cpp
index 6c6f80ac4997c..95ecda2ffd8ba 100644
--- a/llvm/lib/Support/AllocToken.cpp
+++ b/llvm/lib/Support/AllocToken.cpp
@@ -14,11 +14,17 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/SipHash.h"
 
-namespace llvm {
-std::optional<uint64_t> getAllocTokenHash(AllocTokenMode Mode,
-                                          const AllocTokenMetadata &Metadata,
-                                          uint64_t MaxTokens) {
-  assert(MaxTokens && "Must provide concrete max tokens");
+using namespace llvm;
+
+static uint64_t getStableHash(const AllocTokenMetadata &Metadata,
+                              uint64_t MaxTokens) {
+  return getStableSipHash(Metadata.TypeName) % MaxTokens;
+}
+
+std::optional<uint64_t> llvm::getAllocToken(AllocTokenMode Mode,
+                                            const AllocTokenMetadata &Metadata,
+                                            uint64_t MaxTokens) {
+  assert(MaxTokens && "Must provide non-zero max tokens");
 
   switch (Mode) {
   case AllocTokenMode::Increment:
@@ -26,15 +32,14 @@ std::optional<uint64_t> getAllocTokenHash(AllocTokenMode Mode,
     // Stateful modes cannot be implemented as a pure function.
     return std::nullopt;
 
-  case AllocTokenMode::TypeHash: {
-    return getStableSipHash(Metadata.TypeName) % MaxTokens;
-  }
+  case AllocTokenMode::TypeHash:
+    return getStableHash(Metadata, MaxTokens);
 
   case AllocTokenMode::TypeHashPointerSplit: {
     if (MaxTokens == 1)
       return 0;
     const uint64_t HalfTokens = MaxTokens / 2;
-    uint64_t Hash = getStableSipHash(Metadata.TypeName) % HalfTokens;
+    uint64_t Hash = getStableHash(Metadata, HalfTokens);
     if (Metadata.ContainsPointer)
       Hash += HalfTokens;
     return Hash;
@@ -43,4 +48,3 @@ std::optional<uint64_t> getAllocTokenHash(AllocTokenMode Mode,
 
   llvm_unreachable("");
 }
-} // namespace llvm
diff --git a/llvm/lib/Transforms/Instrumentation/AllocToken.cpp b/llvm/lib/Transforms/Instrumentation/AllocToken.cpp
index 7c488ec96120e..08738450e4b74 100644
--- a/llvm/lib/Transforms/Instrumentation/AllocToken.cpp
+++ b/llvm/lib/Transforms/Instrumentation/AllocToken.cpp
@@ -205,8 +205,7 @@ class TypeHashMode : public ModeBase {
     if (MDNode *N = getAllocTokenMetadata(CB)) {
       MDString *S = cast<MDString>(N->getOperand(0));
       AllocTokenMetadata Metadata{S->getString(), containsPointer(N)};
-      if (auto Token =
-              getAllocTokenHash(TokenMode::TypeHash, Metadata, MaxTokens))
+      if (auto Token = getAllocToken(TokenMode::TypeHash, Metadata, MaxTokens))
         return *Token;
     }
     // Fallback.
@@ -238,8 +237,8 @@ class TypeHashPointerSplitMode : public TypeHashMode {
     if (MDNode *N = getAllocTokenMetadata(CB)) {
       MDString *S = cast<MDString>(N->getOperand(0));
       AllocTokenMetadata Metadata{S->getString(), containsPointer(N)};
-      if (auto Token = getAllocTokenHash(TokenMode::TypeHashPointerSplit,
-                                         Metadata, MaxTokens))
+      if (auto Token = getAllocToken(TokenMode::TypeHashPointerSplit, Metadata,
+                                     MaxTokens))
         return *Token;
     }
     // Pick the fallback token (ClFallbackToken), which by default is 0, meaning
@@ -373,9 +372,8 @@ bool AllocToken::instrumentFunction(Function &F) {
   }
 
   if (!IntrinsicInsts.empty()) {
-    for (auto *II : IntrinsicInsts) {
+    for (auto *II : IntrinsicInsts)
       replaceIntrinsicInst(II, ORE);
-    }
     Modified = true;
     NumFunctionsModified++;
   }
@@ -397,7 +395,7 @@ AllocToken::shouldInstrumentCall(const CallBase &CB,
   if (TLI.getLibFunc(*Callee, Func)) {
     if (isInstrumentableLibFunc(Func, CB, TLI))
       return Func;
-  } else if (Options.Extended && getAllocTokenMetadata(CB)) {
+  } else if (Options.Extended && CB.getMetadata(LLVMContext::MD_alloc_token)) {
     return NotLibFunc;
   }
 



More information about the llvm-commits mailing list