[llvm] 0fddaf0 - [Clang] Refactor allocation type inference logic (#163636)

via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 23 02:22:16 PDT 2025


Author: Marco Elver
Date: 2025-10-23T11:22:13+02:00
New Revision: 0fddaf058ac0c2627553b65ed7d057719d31aa7c

URL: https://github.com/llvm/llvm-project/commit/0fddaf058ac0c2627553b65ed7d057719d31aa7c
DIFF: https://github.com/llvm/llvm-project/commit/0fddaf058ac0c2627553b65ed7d057719d31aa7c.diff

LOG: [Clang] Refactor allocation type inference logic (#163636)

Refactor the logic for inferring allocated types out of `CodeGen` and
into a new reusable component in `clang/AST/InferAlloc.h`.

This is a preparatory step for implementing `__builtin_infer_alloc_token`.
By moving the type inference heuristics into a common place, it can be
shared between the existing allocation-call instrumentation and the new
builtin's implementation.

Added: 
    clang/include/clang/AST/InferAlloc.h
    clang/lib/AST/InferAlloc.cpp

Modified: 
    clang/lib/AST/CMakeLists.txt
    clang/lib/CodeGen/CGExpr.cpp
    clang/lib/CodeGen/CodeGenFunction.h
    llvm/utils/gn/secondary/clang/lib/AST/BUILD.gn

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/AST/InferAlloc.h b/clang/include/clang/AST/InferAlloc.h
new file mode 100644
index 0000000000000..c3dc30204feaf
--- /dev/null
+++ b/clang/include/clang/AST/InferAlloc.h
@@ -0,0 +1,35 @@
+//===--- InferAlloc.h - Allocation type inference ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines interfaces for allocation-related type inference.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_AST_INFERALLOC_H
+#define LLVM_CLANG_AST_INFERALLOC_H
+
+#include "clang/AST/ASTContext.h"
+#include "clang/AST/Expr.h"
+#include "llvm/Support/AllocToken.h"
+#include <optional>
+
+namespace clang {
+namespace infer_alloc {
+
+/// Infer the possible type of an allocation call; a null type if unknown.
+QualType inferPossibleType(const CallExpr *E, const ASTContext &Ctx,
+                           const CastExpr *CastE);
+
+/// Get the data needed to build an allocation token ID for T, if available.
+std::optional<llvm::AllocTokenMetadata>
+getAllocTokenMetadata(QualType T, const ASTContext &Ctx);
+
+} // namespace infer_alloc
+} // namespace clang
+
+#endif // LLVM_CLANG_AST_INFERALLOC_H

diff  --git a/clang/lib/AST/CMakeLists.txt b/clang/lib/AST/CMakeLists.txt
index d4fd7a7f16d53..fd50e956bb865 100644
--- a/clang/lib/AST/CMakeLists.txt
+++ b/clang/lib/AST/CMakeLists.txt
@@ -66,6 +66,7 @@ add_clang_library(clangAST
   ExternalASTMerger.cpp
   ExternalASTSource.cpp
   FormatString.cpp
+  InferAlloc.cpp
   InheritViz.cpp
   ByteCode/BitcastBuffer.cpp
   ByteCode/ByteCodeEmitter.cpp

diff  --git a/clang/lib/AST/InferAlloc.cpp b/clang/lib/AST/InferAlloc.cpp
new file mode 100644
index 0000000000000..e439ed4dbb386
--- /dev/null
+++ b/clang/lib/AST/InferAlloc.cpp
@@ -0,0 +1,201 @@
+//===--- InferAlloc.cpp - Allocation type inference -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements allocation-related type inference.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/AST/InferAlloc.h"
+#include "clang/AST/ASTContext.h"
+#include "clang/AST/Decl.h"
+#include "clang/AST/DeclCXX.h"
+#include "clang/AST/Expr.h"
+#include "clang/AST/Type.h"
+#include "clang/Basic/IdentifierTable.h"
+#include "llvm/ADT/SmallPtrSet.h"
+
+using namespace clang;
+using namespace infer_alloc;
+
+static bool
+typeContainsPointer(QualType T,
+                    llvm::SmallPtrSet<const RecordDecl *, 4> &VisitedRD,
+                    bool &IncompleteType) {
+  QualType CanonicalType = T.getCanonicalType();
+  if (CanonicalType->isPointerType())
+    return true; // base case: the type itself is a pointer
+
+  // Walk the typedef chain of the non-canonical type to find special names.
+  for (QualType CurrentT = T; const auto *TT = CurrentT->getAs<TypedefType>();
+       CurrentT = TT->getDecl()->getUnderlyingType()) {
+    const IdentifierInfo *II = TT->getDecl()->getIdentifier();
+    // Special Case: Syntactically uintptr_t is not a pointer; semantically,
+    // however, very likely used as such. Therefore, classify uintptr_t as a
+    // pointer, too.
+    if (II && II->isStr("uintptr_t"))
+      return true;
+  }
+
+  // The type is an array; check the element type.
+  if (const ArrayType *AT = dyn_cast<ArrayType>(CanonicalType))
+    return typeContainsPointer(AT->getElementType(), VisitedRD, IncompleteType);
+  // The type is a struct, class, or union; flag incomplete definitions.
+  if (const RecordDecl *RD = CanonicalType->getAsRecordDecl()) {
+    if (!RD->isCompleteDefinition()) {
+      IncompleteType = true;
+      return false;
+    }
+    if (!VisitedRD.insert(RD).second)
+      return false; // already visited; do not scan it again
+    // Check all fields.
+    for (const FieldDecl *Field : RD->fields()) {
+      if (typeContainsPointer(Field->getType(), VisitedRD, IncompleteType))
+        return true;
+    }
+    // For C++ classes, also check base classes.
+    if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+      // Dynamic classes (virtual functions or bases) require a vptr.
+      if (CXXRD->isDynamicClass())
+        return true;
+      for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
+        if (typeContainsPointer(Base.getType(), VisitedRD, IncompleteType))
+          return true;
+      }
+    }
+  }
+  return false;
+}
+
+/// Return the operand type of a lone sizeof expression, else a null QualType.
+static QualType inferTypeFromSizeofExpr(const Expr *E) {
+  const Expr *Arg = E->IgnoreParenImpCasts();
+  if (const auto *UET = dyn_cast<UnaryExprOrTypeTraitExpr>(Arg)) {
+    if (UET->getKind() == UETT_SizeOf) {
+      if (UET->isArgumentType())
+        return UET->getArgumentTypeInfo()->getType();
+      else
+        return UET->getArgumentExpr()->getType();
+    }
+  }
+  return QualType();
+}
+
+/// Infer type from an arithmetic expression involving a sizeof. For example:
+///
+///   malloc(sizeof(MyType) + padding);  // infers 'MyType'
+///   malloc(sizeof(MyType) * 32);       // infers 'MyType'
+///   malloc(32 * sizeof(MyType));       // infers 'MyType'
+///   malloc(sizeof(MyType) << 1);       // infers 'MyType'
+///   ...
+///
+/// More complex arithmetic expressions are handled heuristically (the leftmost
+/// sizeof wins), e.g. for allocations of structs with flexible array members:
+///
+///   malloc(sizeof(HasFlexArray) + sizeof(int) * 32);  // infers 'HasFlexArray'
+///
+static QualType inferPossibleTypeFromArithSizeofExpr(const Expr *E) {
+  const Expr *Arg = E->IgnoreParenImpCasts();
+  // The argument is a lone sizeof expression.
+  if (QualType T = inferTypeFromSizeofExpr(Arg); !T.isNull())
+    return T;
+  if (const auto *BO = dyn_cast<BinaryOperator>(Arg)) {
+    // Argument is an arithmetic expression. Cover common arithmetic patterns
+    // involving sizeof, searching the left operand before the right one.
+    switch (BO->getOpcode()) {
+    case BO_Add:
+    case BO_Div:
+    case BO_Mul:
+    case BO_Shl:
+    case BO_Shr:
+    case BO_Sub:
+      if (QualType T = inferPossibleTypeFromArithSizeofExpr(BO->getLHS());
+          !T.isNull())
+        return T;
+      if (QualType T = inferPossibleTypeFromArithSizeofExpr(BO->getRHS());
+          !T.isNull())
+        return T;
+      break;
+    default:
+      break;
+    }
+  }
+  return QualType();
+}
+
+/// If the expression E is a reference to a variable, infer the type from the
+/// variable's initializer if it contains a sizeof. Beware, this is a heuristic
+/// that ignores whether the variable is later reassigned. For example:
+///
+///   size_t my_size = sizeof(MyType);
+///   void *x = malloc(my_size);  // infers 'MyType'
+///
+static QualType inferPossibleTypeFromVarInitSizeofExpr(const Expr *E) {
+  const Expr *Arg = E->IgnoreParenImpCasts();
+  if (const auto *DRE = dyn_cast<DeclRefExpr>(Arg)) {
+    if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
+      if (const Expr *Init = VD->getInit())
+        return inferPossibleTypeFromArithSizeofExpr(Init);
+    }
+  }
+  return QualType();
+}
+
+/// Deduces the allocated type as the pointee type of the cast expression that
+/// immediately uses the call's result (CallE is not inspected). For example:
+///
+///   MyType *x = (MyType *)malloc(4096);  // infers 'MyType'
+///
+static QualType inferPossibleTypeFromCastExpr(const CallExpr *CallE,
+                                              const CastExpr *CastE) {
+  if (!CastE)
+    return QualType();
+  QualType PtrType = CastE->getType();
+  if (PtrType->isPointerType())
+    return PtrType->getPointeeType();
+  return QualType();
+}
+
+QualType infer_alloc::inferPossibleType(const CallExpr *E,
+                                        const ASTContext &Ctx,
+                                        const CastExpr *CastE) {
+  QualType AllocType;
+  // First check arguments; the first one that yields a type wins.
+  for (const Expr *Arg : E->arguments()) {
+    AllocType = inferPossibleTypeFromArithSizeofExpr(Arg);
+    if (AllocType.isNull())
+      AllocType = inferPossibleTypeFromVarInitSizeofExpr(Arg);
+    if (!AllocType.isNull())
+      break;
+  }
+  // Otherwise fall back to the type of the provided cast, if any.
+  if (AllocType.isNull())
+    AllocType = inferPossibleTypeFromCastExpr(E, CastE);
+  return AllocType;
+}
+
+std::optional<llvm::AllocTokenMetadata>
+infer_alloc::getAllocTokenMetadata(QualType T, const ASTContext &Ctx) {
+  llvm::AllocTokenMetadata ATMD;
+
+  // Print the canonical type, fully qualified, as the unique type name.
+  PrintingPolicy Policy(Ctx.getLangOpts());
+  Policy.SuppressTagKeyword = true;
+  Policy.FullyQualifiedName = true;
+  llvm::raw_svector_ostream TypeNameOS(ATMD.TypeName);
+  T.getCanonicalType().print(TypeNameOS, Policy);
+
+  // Recursively (DFS) check if T contains a pointer. If no pointer is found
+  // but an incomplete type was encountered, the answer is unknown: bail out.
+  llvm::SmallPtrSet<const RecordDecl *, 4> VisitedRD;
+  bool IncompleteType = false;
+  ATMD.ContainsPointer = typeContainsPointer(T, VisitedRD, IncompleteType);
+  if (!ATMD.ContainsPointer && IncompleteType)
+    return std::nullopt;
+
+  return ATMD;
+}

diff  --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index fd73314c9f84c..301d5770cf78f 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -29,6 +29,7 @@
 #include "clang/AST/ASTLambda.h"
 #include "clang/AST/Attr.h"
 #include "clang/AST/DeclObjC.h"
+#include "clang/AST/InferAlloc.h"
 #include "clang/AST/NSAPI.h"
 #include "clang/AST/ParentMapContext.h"
 #include "clang/AST/StmtVisitor.h"
@@ -1273,194 +1274,39 @@ void CodeGenFunction::EmitBoundsCheckImpl(const Expr *E, llvm::Value *Bound,
   EmitCheck(std::make_pair(Check, CheckKind), CheckHandler, StaticData, Index);
 }
 
-static bool
-typeContainsPointer(QualType T,
-                    llvm::SmallPtrSet<const RecordDecl *, 4> &VisitedRD,
-                    bool &IncompleteType) {
-  QualType CanonicalType = T.getCanonicalType();
-  if (CanonicalType->isPointerType())
-    return true; // base case
-
-  // Look through typedef chain to check for special types.
-  for (QualType CurrentT = T; const auto *TT = CurrentT->getAs<TypedefType>();
-       CurrentT = TT->getDecl()->getUnderlyingType()) {
-    const IdentifierInfo *II = TT->getDecl()->getIdentifier();
-    // Special Case: Syntactically uintptr_t is not a pointer; semantically,
-    // however, very likely used as such. Therefore, classify uintptr_t as a
-    // pointer, too.
-    if (II && II->isStr("uintptr_t"))
-      return true;
-  }
-
-  // The type is an array; check the element type.
-  if (const ArrayType *AT = dyn_cast<ArrayType>(CanonicalType))
-    return typeContainsPointer(AT->getElementType(), VisitedRD, IncompleteType);
-  // The type is a struct, class, or union.
-  if (const RecordDecl *RD = CanonicalType->getAsRecordDecl()) {
-    if (!RD->isCompleteDefinition()) {
-      IncompleteType = true;
-      return false;
-    }
-    if (!VisitedRD.insert(RD).second)
-      return false; // already visited
-    // Check all fields.
-    for (const FieldDecl *Field : RD->fields()) {
-      if (typeContainsPointer(Field->getType(), VisitedRD, IncompleteType))
-        return true;
-    }
-    // For C++ classes, also check base classes.
-    if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
-      // Polymorphic types require a vptr.
-      if (CXXRD->isDynamicClass())
-        return true;
-      for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
-        if (typeContainsPointer(Base.getType(), VisitedRD, IncompleteType))
-          return true;
-      }
-    }
-  }
-  return false;
-}
-
-void CodeGenFunction::EmitAllocToken(llvm::CallBase *CB, QualType AllocType) {
-  assert(SanOpts.has(SanitizerKind::AllocToken) &&
-         "Only needed with -fsanitize=alloc-token");
+llvm::MDNode *CodeGenFunction::buildAllocToken(QualType AllocType) {
+  auto ATMD = infer_alloc::getAllocTokenMetadata(AllocType, getContext());
+  if (!ATMD)
+    return nullptr;
 
   llvm::MDBuilder MDB(getLLVMContext());
-
-  // Get unique type name.
-  PrintingPolicy Policy(CGM.getContext().getLangOpts());
-  Policy.SuppressTagKeyword = true;
-  Policy.FullyQualifiedName = true;
-  SmallString<64> TypeName;
-  llvm::raw_svector_ostream TypeNameOS(TypeName);
-  AllocType.getCanonicalType().print(TypeNameOS, Policy);
-  auto *TypeNameMD = MDB.createString(TypeNameOS.str());
-
-  // Check if QualType contains a pointer. Implements a simple DFS to
-  // recursively check if a type contains a pointer type.
-  llvm::SmallPtrSet<const RecordDecl *, 4> VisitedRD;
-  bool IncompleteType = false;
-  const bool ContainsPtr =
-      typeContainsPointer(AllocType, VisitedRD, IncompleteType);
-  if (!ContainsPtr && IncompleteType)
-    return;
-  auto *ContainsPtrC = Builder.getInt1(ContainsPtr);
+  auto *TypeNameMD = MDB.createString(ATMD->TypeName);
+  auto *ContainsPtrC = Builder.getInt1(ATMD->ContainsPointer);
   auto *ContainsPtrMD = MDB.createConstant(ContainsPtrC);
 
   // Format: !{<type-name>, <contains-pointer>}
-  auto *MDN =
-      llvm::MDNode::get(CGM.getLLVMContext(), {TypeNameMD, ContainsPtrMD});
-  CB->setMetadata(llvm::LLVMContext::MD_alloc_token, MDN);
-}
-
-namespace {
-/// Infer type from a simple sizeof expression.
-QualType inferTypeFromSizeofExpr(const Expr *E) {
-  const Expr *Arg = E->IgnoreParenImpCasts();
-  if (const auto *UET = dyn_cast<UnaryExprOrTypeTraitExpr>(Arg)) {
-    if (UET->getKind() == UETT_SizeOf) {
-      if (UET->isArgumentType())
-        return UET->getArgumentTypeInfo()->getType();
-      else
-        return UET->getArgumentExpr()->getType();
-    }
-  }
-  return QualType();
-}
-
-/// Infer type from an arithmetic expression involving a sizeof. For example:
-///
-///   malloc(sizeof(MyType) + padding);  // infers 'MyType'
-///   malloc(sizeof(MyType) * 32);       // infers 'MyType'
-///   malloc(32 * sizeof(MyType));       // infers 'MyType'
-///   malloc(sizeof(MyType) << 1);       // infers 'MyType'
-///   ...
-///
-/// More complex arithmetic expressions are supported, but are a heuristic, e.g.
-/// when considering allocations for structs with flexible array members:
-///
-///   malloc(sizeof(HasFlexArray) + sizeof(int) * 32);  // infers 'HasFlexArray'
-///
-QualType inferPossibleTypeFromArithSizeofExpr(const Expr *E) {
-  const Expr *Arg = E->IgnoreParenImpCasts();
-  // The argument is a lone sizeof expression.
-  if (QualType T = inferTypeFromSizeofExpr(Arg); !T.isNull())
-    return T;
-  if (const auto *BO = dyn_cast<BinaryOperator>(Arg)) {
-    // Argument is an arithmetic expression. Cover common arithmetic patterns
-    // involving sizeof.
-    switch (BO->getOpcode()) {
-    case BO_Add:
-    case BO_Div:
-    case BO_Mul:
-    case BO_Shl:
-    case BO_Shr:
-    case BO_Sub:
-      if (QualType T = inferPossibleTypeFromArithSizeofExpr(BO->getLHS());
-          !T.isNull())
-        return T;
-      if (QualType T = inferPossibleTypeFromArithSizeofExpr(BO->getRHS());
-          !T.isNull())
-        return T;
-      break;
-    default:
-      break;
-    }
-  }
-  return QualType();
+  return llvm::MDNode::get(CGM.getLLVMContext(), {TypeNameMD, ContainsPtrMD});
 }
 
-/// If the expression E is a reference to a variable, infer the type from a
-/// variable's initializer if it contains a sizeof. Beware, this is a heuristic
-/// and ignores if a variable is later reassigned. For example:
-///
-///   size_t my_size = sizeof(MyType);
-///   void *x = malloc(my_size);  // infers 'MyType'
-///
-QualType inferPossibleTypeFromVarInitSizeofExpr(const Expr *E) {
-  const Expr *Arg = E->IgnoreParenImpCasts();
-  if (const auto *DRE = dyn_cast<DeclRefExpr>(Arg)) {
-    if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
-      if (const Expr *Init = VD->getInit())
-        return inferPossibleTypeFromArithSizeofExpr(Init);
-    }
-  }
-  return QualType();
+void CodeGenFunction::EmitAllocToken(llvm::CallBase *CB, QualType AllocType) {
+  assert(SanOpts.has(SanitizerKind::AllocToken) &&
+         "Only needed with -fsanitize=alloc-token");
+  CB->setMetadata(llvm::LLVMContext::MD_alloc_token,
+                  buildAllocToken(AllocType));
 }
 
-/// Deduces the allocated type by checking if the allocation call's result
-/// is immediately used in a cast expression. For example:
-///
-///   MyType *x = (MyType *)malloc(4096);  // infers 'MyType'
-///
-QualType inferPossibleTypeFromCastExpr(const CallExpr *CallE,
-                                       const CastExpr *CastE) {
-  if (!CastE)
-    return QualType();
-  QualType PtrType = CastE->getType();
-  if (PtrType->isPointerType())
-    return PtrType->getPointeeType();
-  return QualType();
+llvm::MDNode *CodeGenFunction::buildAllocToken(const CallExpr *E) {
+  QualType AllocType = infer_alloc::inferPossibleType(E, getContext(), CurCast);
+  if (!AllocType.isNull())
+    return buildAllocToken(AllocType);
+  return nullptr;
 }
-} // end anonymous namespace
 
 void CodeGenFunction::EmitAllocToken(llvm::CallBase *CB, const CallExpr *E) {
-  QualType AllocType;
-  // First check arguments.
-  for (const Expr *Arg : E->arguments()) {
-    AllocType = inferPossibleTypeFromArithSizeofExpr(Arg);
-    if (AllocType.isNull())
-      AllocType = inferPossibleTypeFromVarInitSizeofExpr(Arg);
-    if (!AllocType.isNull())
-      break;
-  }
-  // Then check later casts.
-  if (AllocType.isNull())
-    AllocType = inferPossibleTypeFromCastExpr(E, CurCast);
-  // Emit if we were able to infer the type.
-  if (!AllocType.isNull())
-    EmitAllocToken(CB, AllocType);
+  assert(SanOpts.has(SanitizerKind::AllocToken) &&
+         "Only needed with -fsanitize=alloc-token");
+  if (llvm::MDNode *MDN = buildAllocToken(E))
+    CB->setMetadata(llvm::LLVMContext::MD_alloc_token, MDN);
 }
 
 CodeGenFunction::ComplexPairTy CodeGenFunction::

diff  --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 1f0be2d8756de..8c4c1c8c2dc95 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -3352,9 +3352,14 @@ class CodeGenFunction : public CodeGenTypeCache {
   SanitizerAnnotateDebugInfo(ArrayRef<SanitizerKind::SanitizerOrdinal> Ordinals,
                              SanitizerHandler Handler);
 
-  /// Emit additional metadata used by the AllocToken instrumentation.
+  /// Build metadata used by the AllocToken instrumentation.
+  llvm::MDNode *buildAllocToken(QualType AllocType);
+  /// Emit and set additional metadata used by the AllocToken instrumentation.
   void EmitAllocToken(llvm::CallBase *CB, QualType AllocType);
-  /// Emit additional metadata used by the AllocToken instrumentation,
+  /// Build additional metadata used by the AllocToken instrumentation,
+  /// inferring the type from an allocation call expression.
+  llvm::MDNode *buildAllocToken(const CallExpr *E);
+  /// Emit and set additional metadata used by the AllocToken instrumentation,
   /// inferring the type from an allocation call expression.
   void EmitAllocToken(llvm::CallBase *CB, const CallExpr *E);
 

diff  --git a/llvm/utils/gn/secondary/clang/lib/AST/BUILD.gn b/llvm/utils/gn/secondary/clang/lib/AST/BUILD.gn
index 9981d100fd555..4da907cbdd938 100644
--- a/llvm/utils/gn/secondary/clang/lib/AST/BUILD.gn
+++ b/llvm/utils/gn/secondary/clang/lib/AST/BUILD.gn
@@ -121,6 +121,7 @@ static_library("AST") {
     "ExternalASTMerger.cpp",
     "ExternalASTSource.cpp",
     "FormatString.cpp",
+    "InferAlloc.cpp",
     "InheritViz.cpp",
     "ItaniumCXXABI.cpp",
     "ItaniumMangle.cpp",


        


More information about the llvm-commits mailing list