[llvm] 0fddaf0 - [Clang] Refactor allocation type inference logic (#163636)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Oct 23 02:22:16 PDT 2025
Author: Marco Elver
Date: 2025-10-23T11:22:13+02:00
New Revision: 0fddaf058ac0c2627553b65ed7d057719d31aa7c
URL: https://github.com/llvm/llvm-project/commit/0fddaf058ac0c2627553b65ed7d057719d31aa7c
DIFF: https://github.com/llvm/llvm-project/commit/0fddaf058ac0c2627553b65ed7d057719d31aa7c.diff
LOG: [Clang] Refactor allocation type inference logic (#163636)
Refactor the logic for inferring allocated types out of `CodeGen` and
into a new reusable component in `clang/AST/InferAlloc.h`.
This is a preparatory step for implementing `__builtin_infer_alloc_token`.
By moving the type inference heuristics into a common place, they can be
shared between the existing allocation-call instrumentation and the new
builtin's implementation.
Added:
clang/include/clang/AST/InferAlloc.h
clang/lib/AST/InferAlloc.cpp
Modified:
clang/lib/AST/CMakeLists.txt
clang/lib/CodeGen/CGExpr.cpp
clang/lib/CodeGen/CodeGenFunction.h
llvm/utils/gn/secondary/clang/lib/AST/BUILD.gn
Removed:
################################################################################
diff --git a/clang/include/clang/AST/InferAlloc.h b/clang/include/clang/AST/InferAlloc.h
new file mode 100644
index 0000000000000..c3dc30204feaf
--- /dev/null
+++ b/clang/include/clang/AST/InferAlloc.h
@@ -0,0 +1,35 @@
+//===--- InferAlloc.h - Allocation type inference ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines interfaces for allocation-related type inference.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_AST_INFERALLOC_H
+#define LLVM_CLANG_AST_INFERALLOC_H
+
+#include "clang/AST/ASTContext.h"
+#include "clang/AST/Expr.h"
+#include "llvm/Support/AllocToken.h"
+#include <optional>
+
+namespace clang::infer_alloc {
+
+/// Heuristically determines the type allocated by the allocation call \p E.
+///
+/// The call's arguments are examined first, in order, for a `sizeof` operand:
+/// on its own, nested in simple arithmetic (`+`, `-`, `*`, `/`, `<<`, `>>`),
+/// or in the initializer of a variable passed as the argument. If none is
+/// found and \p CastE (which may be null) casts the result to a pointer type,
+/// that pointer's pointee type is used. Returns a null QualType if no type
+/// could be inferred.
+QualType inferPossibleType(const CallExpr *E, const ASTContext &Ctx,
+                           const CastExpr *CastE);
+
+/// Computes what is needed to construct an allocation token ID for \p T: the
+/// fully qualified name of its canonical type and whether it contains a
+/// pointer. Returns std::nullopt if no pointer was found but an incomplete
+/// type was encountered, since the classification is then unreliable.
+std::optional<llvm::AllocTokenMetadata>
+getAllocTokenMetadata(QualType T, const ASTContext &Ctx);
+
+} // namespace clang::infer_alloc
+
+#endif // LLVM_CLANG_AST_INFERALLOC_H
diff --git a/clang/lib/AST/CMakeLists.txt b/clang/lib/AST/CMakeLists.txt
index d4fd7a7f16d53..fd50e956bb865 100644
--- a/clang/lib/AST/CMakeLists.txt
+++ b/clang/lib/AST/CMakeLists.txt
@@ -66,6 +66,7 @@ add_clang_library(clangAST
ExternalASTMerger.cpp
ExternalASTSource.cpp
FormatString.cpp
+ InferAlloc.cpp
InheritViz.cpp
ByteCode/BitcastBuffer.cpp
ByteCode/ByteCodeEmitter.cpp
diff --git a/clang/lib/AST/InferAlloc.cpp b/clang/lib/AST/InferAlloc.cpp
new file mode 100644
index 0000000000000..e439ed4dbb386
--- /dev/null
+++ b/clang/lib/AST/InferAlloc.cpp
@@ -0,0 +1,201 @@
+//===--- InferAlloc.cpp - Allocation type inference -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements allocation-related type inference.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/AST/InferAlloc.h"
+#include "clang/AST/ASTContext.h"
+#include "clang/AST/Decl.h"
+#include "clang/AST/DeclCXX.h"
+#include "clang/AST/Expr.h"
+#include "clang/AST/Type.h"
+#include "clang/Basic/IdentifierTable.h"
+#include "llvm/ADT/SmallPtrSet.h"
+
+using namespace clang;
+using namespace infer_alloc;
+
+/// Returns true if \p T is, or transitively contains, something that is
+/// represented as a pointer. This covers:
+/// - C/C++ pointers, Objective-C object pointers and block pointers;
+/// - `uintptr_t`, matched by typedef name;
+/// - dynamic C++ classes, which carry a vptr.
+/// Arrays are checked through their element type. Records are checked
+/// through their fields and, for C++ classes, their bases.
+///
+/// \p VisitedRD records the records already examined, so each is walked at
+/// most once. \p IncompleteType is set to true if an incomplete record was
+/// encountered; a `false` result is then not conclusive.
+static bool
+typeContainsPointer(QualType T,
+                    llvm::SmallPtrSet<const RecordDecl *, 4> &VisitedRD,
+                    bool &IncompleteType) {
+  QualType CanonicalType = T.getCanonicalType();
+  // Base case. Objective-C object pointers and block pointers are not
+  // PointerTypes, but they are pointers in memory all the same.
+  if (CanonicalType->isAnyPointerType() || CanonicalType->isBlockPointerType())
+    return true;
+
+  // Look through typedef chain to check for special types.
+  for (QualType CurrentT = T; const auto *TT = CurrentT->getAs<TypedefType>();
+       CurrentT = TT->getDecl()->getUnderlyingType()) {
+    const IdentifierInfo *II = TT->getDecl()->getIdentifier();
+    // Special Case: Syntactically uintptr_t is not a pointer; semantically,
+    // however, very likely used as such. Therefore, classify uintptr_t as a
+    // pointer, too.
+    if (II && II->isStr("uintptr_t"))
+      return true;
+  }
+
+  // The type is an array; check the element type. Query the (possibly sugared)
+  // type T rather than CanonicalType: canonicalization also strips typedef
+  // sugar from the element type. That would hide e.g. `uintptr_t[N]` from the
+  // uintptr_t check above.
+  if (const ArrayType *AT = T->getAsArrayTypeUnsafe())
+    return typeContainsPointer(AT->getElementType(), VisitedRD, IncompleteType);
+
+  // The type is a struct, class, or union.
+  if (const RecordDecl *RD = CanonicalType->getAsRecordDecl()) {
+    if (!RD->isCompleteDefinition()) {
+      IncompleteType = true;
+      return false;
+    }
+    if (!VisitedRD.insert(RD).second)
+      return false; // already visited
+    // Check all fields.
+    for (const FieldDecl *Field : RD->fields()) {
+      if (typeContainsPointer(Field->getType(), VisitedRD, IncompleteType))
+        return true;
+    }
+    // For C++ classes, also check base classes.
+    if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+      // Polymorphic types require a vptr.
+      if (CXXRD->isDynamicClass())
+        return true;
+      for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
+        if (typeContainsPointer(Base.getType(), VisitedRD, IncompleteType))
+          return true;
+      }
+    }
+  }
+  return false;
+}
+
+/// If \p E is a `sizeof` expression once parentheses and implicit casts are
+/// stripped, returns the type it measures. That is the written type for
+/// `sizeof(T)`, or the operand's type for `sizeof expr`. Otherwise returns a
+/// null QualType.
+static QualType inferTypeFromSizeofExpr(const Expr *E) {
+  const auto *SizeofE =
+      dyn_cast<UnaryExprOrTypeTraitExpr>(E->IgnoreParenImpCasts());
+  if (!SizeofE || SizeofE->getKind() != UETT_SizeOf)
+    return QualType();
+  return SizeofE->getTypeOfArgument();
+}
+
+/// Looks for a `sizeof` in \p E. It may appear on its own or nested in the
+/// operands of `+`, `-`, `*`, `/`, `<<` or `>>`. Returns the type it measures;
+/// left operands are searched before right ones. For example:
+///
+///   malloc(sizeof(MyType) + padding);   // infers 'MyType'
+///   malloc(sizeof(MyType) * 32);        // infers 'MyType'
+///   malloc(32 * sizeof(MyType));        // infers 'MyType'
+///   malloc(sizeof(MyType) << 1);        // infers 'MyType'
+///
+/// This is only a heuristic. For a struct with a flexible array member,
+///
+///   malloc(sizeof(HasFlexArray) + sizeof(int) * 32);
+///
+/// infers 'HasFlexArray', because the leftmost sizeof is found first.
+/// Returns a null QualType if no such sizeof is found.
+static QualType inferPossibleTypeFromArithSizeofExpr(const Expr *E) {
+  const Expr *Stripped = E->IgnoreParenImpCasts();
+
+  // A lone sizeof expression.
+  QualType Direct = inferTypeFromSizeofExpr(Stripped);
+  if (!Direct.isNull())
+    return Direct;
+
+  // Otherwise only the common arithmetic operators are looked through.
+  const auto *BO = dyn_cast<BinaryOperator>(Stripped);
+  if (!BO)
+    return QualType();
+  switch (BO->getOpcode()) {
+  case BO_Add:
+  case BO_Sub:
+  case BO_Mul:
+  case BO_Div:
+  case BO_Shl:
+  case BO_Shr:
+    break;
+  default:
+    return QualType();
+  }
+
+  const Expr *Operands[] = {BO->getLHS(), BO->getRHS()};
+  for (const Expr *Operand : Operands) {
+    QualType Found = inferPossibleTypeFromArithSizeofExpr(Operand);
+    if (!Found.isNull())
+      return Found;
+  }
+  return QualType();
+}
+
+/// When \p E names a variable, returns the type inferred from that variable's
+/// initializer, using the same sizeof/arithmetic search as
+/// inferPossibleTypeFromArithSizeofExpr. For example:
+///
+///   size_t my_size = sizeof(MyType);
+///   void *x = malloc(my_size); // infers 'MyType'
+///
+/// This is a heuristic: later reassignments of the variable are ignored.
+/// Returns a null QualType if \p E is not such a variable reference.
+static QualType inferPossibleTypeFromVarInitSizeofExpr(const Expr *E) {
+  const auto *Ref = dyn_cast<DeclRefExpr>(E->IgnoreParenImpCasts());
+  if (!Ref)
+    return QualType();
+  const auto *Var = dyn_cast<VarDecl>(Ref->getDecl());
+  const Expr *Init = Var ? Var->getInit() : nullptr;
+  if (!Init)
+    return QualType();
+  return inferPossibleTypeFromArithSizeofExpr(Init);
+}
+
+/// Returns the pointee type when the allocation call's result is immediately
+/// cast to a pointer type. For example:
+///
+///   MyType *x = (MyType *)malloc(4096); // infers 'MyType'
+///
+/// \p CastE may be null. \p CallE is currently not consulted. Returns a null
+/// QualType if there is no cast to a pointer type.
+static QualType
+inferPossibleTypeFromCastExpr([[maybe_unused]] const CallExpr *CallE,
+                              const CastExpr *CastE) {
+  if (!CastE)
+    return QualType();
+  const auto *PT = CastE->getType()->getAs<PointerType>();
+  return PT ? PT->getPointeeType() : QualType();
+}
+
+QualType infer_alloc::inferPossibleType(const CallExpr *E,
+                                        const ASTContext &Ctx,
+                                        const CastExpr *CastE) {
+  // Evidence from the arguments wins; the first argument that yields a type
+  // decides.
+  for (const Expr *Arg : E->arguments()) {
+    QualType FromArg = inferPossibleTypeFromArithSizeofExpr(Arg);
+    if (FromArg.isNull())
+      FromArg = inferPossibleTypeFromVarInitSizeofExpr(Arg);
+    if (!FromArg.isNull())
+      return FromArg;
+  }
+  // Fall back to the type the call's result is immediately cast to, if any.
+  return inferPossibleTypeFromCastExpr(E, CastE);
+}
+
+std::optional<llvm::AllocTokenMetadata>
+infer_alloc::getAllocTokenMetadata(QualType T, const ASTContext &Ctx) {
+  llvm::AllocTokenMetadata ATMD;
+
+  // Depth-first walk over T's structure to decide whether it contains a
+  // pointer. If none was found but an incomplete type was seen, the answer is
+  // unreliable, so no metadata is produced.
+  bool SawIncomplete = false;
+  llvm::SmallPtrSet<const RecordDecl *, 4> Seen;
+  ATMD.ContainsPointer = typeContainsPointer(T, Seen, SawIncomplete);
+  if (SawIncomplete && !ATMD.ContainsPointer)
+    return std::nullopt;
+
+  // Unique type name: canonical, fully qualified, without tag keywords.
+  PrintingPolicy Policy(Ctx.getLangOpts());
+  Policy.SuppressTagKeyword = true;
+  Policy.FullyQualifiedName = true;
+  llvm::raw_svector_ostream NameOS(ATMD.TypeName);
+  T.getCanonicalType().print(NameOS, Policy);
+
+  return ATMD;
+}
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index fd73314c9f84c..301d5770cf78f 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -29,6 +29,7 @@
#include "clang/AST/ASTLambda.h"
#include "clang/AST/Attr.h"
#include "clang/AST/DeclObjC.h"
+#include "clang/AST/InferAlloc.h"
#include "clang/AST/NSAPI.h"
#include "clang/AST/ParentMapContext.h"
#include "clang/AST/StmtVisitor.h"
@@ -1273,194 +1274,39 @@ void CodeGenFunction::EmitBoundsCheckImpl(const Expr *E, llvm::Value *Bound,
EmitCheck(std::make_pair(Check, CheckKind), CheckHandler, StaticData, Index);
}
-static bool
-typeContainsPointer(QualType T,
- llvm::SmallPtrSet<const RecordDecl *, 4> &VisitedRD,
- bool &IncompleteType) {
- QualType CanonicalType = T.getCanonicalType();
- if (CanonicalType->isPointerType())
- return true; // base case
-
- // Look through typedef chain to check for special types.
- for (QualType CurrentT = T; const auto *TT = CurrentT->getAs<TypedefType>();
- CurrentT = TT->getDecl()->getUnderlyingType()) {
- const IdentifierInfo *II = TT->getDecl()->getIdentifier();
- // Special Case: Syntactically uintptr_t is not a pointer; semantically,
- // however, very likely used as such. Therefore, classify uintptr_t as a
- // pointer, too.
- if (II && II->isStr("uintptr_t"))
- return true;
- }
-
- // The type is an array; check the element type.
- if (const ArrayType *AT = dyn_cast<ArrayType>(CanonicalType))
- return typeContainsPointer(AT->getElementType(), VisitedRD, IncompleteType);
- // The type is a struct, class, or union.
- if (const RecordDecl *RD = CanonicalType->getAsRecordDecl()) {
- if (!RD->isCompleteDefinition()) {
- IncompleteType = true;
- return false;
- }
- if (!VisitedRD.insert(RD).second)
- return false; // already visited
- // Check all fields.
- for (const FieldDecl *Field : RD->fields()) {
- if (typeContainsPointer(Field->getType(), VisitedRD, IncompleteType))
- return true;
- }
- // For C++ classes, also check base classes.
- if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
- // Polymorphic types require a vptr.
- if (CXXRD->isDynamicClass())
- return true;
- for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
- if (typeContainsPointer(Base.getType(), VisitedRD, IncompleteType))
- return true;
- }
- }
- }
- return false;
-}
-
-void CodeGenFunction::EmitAllocToken(llvm::CallBase *CB, QualType AllocType) {
- assert(SanOpts.has(SanitizerKind::AllocToken) &&
- "Only needed with -fsanitize=alloc-token");
+/// Builds the !alloc_token metadata node for \p AllocType, formatted as
+/// !{<type-name>, <contains-pointer>}. Returns null if
+/// infer_alloc::getAllocTokenMetadata() yields no metadata for the type.
+llvm::MDNode *CodeGenFunction::buildAllocToken(QualType AllocType) {
+ auto ATMD = infer_alloc::getAllocTokenMetadata(AllocType, getContext());
+ // No token metadata could be derived for this type; build nothing.
+ if (!ATMD)
+ return nullptr;
llvm::MDBuilder MDB(getLLVMContext());
-
- // Get unique type name.
- PrintingPolicy Policy(CGM.getContext().getLangOpts());
- Policy.SuppressTagKeyword = true;
- Policy.FullyQualifiedName = true;
- SmallString<64> TypeName;
- llvm::raw_svector_ostream TypeNameOS(TypeName);
- AllocType.getCanonicalType().print(TypeNameOS, Policy);
- auto *TypeNameMD = MDB.createString(TypeNameOS.str());
-
- // Check if QualType contains a pointer. Implements a simple DFS to
- // recursively check if a type contains a pointer type.
- llvm::SmallPtrSet<const RecordDecl *, 4> VisitedRD;
- bool IncompleteType = false;
- const bool ContainsPtr =
- typeContainsPointer(AllocType, VisitedRD, IncompleteType);
- if (!ContainsPtr && IncompleteType)
- return;
- auto *ContainsPtrC = Builder.getInt1(ContainsPtr);
+ auto *TypeNameMD = MDB.createString(ATMD->TypeName);
+ auto *ContainsPtrC = Builder.getInt1(ATMD->ContainsPointer);
auto *ContainsPtrMD = MDB.createConstant(ContainsPtrC);
// Format: !{<type-name>, <contains-pointer>}
- auto *MDN =
- llvm::MDNode::get(CGM.getLLVMContext(), {TypeNameMD, ContainsPtrMD});
- CB->setMetadata(llvm::LLVMContext::MD_alloc_token, MDN);
-}
-
-namespace {
-/// Infer type from a simple sizeof expression.
-QualType inferTypeFromSizeofExpr(const Expr *E) {
- const Expr *Arg = E->IgnoreParenImpCasts();
- if (const auto *UET = dyn_cast<UnaryExprOrTypeTraitExpr>(Arg)) {
- if (UET->getKind() == UETT_SizeOf) {
- if (UET->isArgumentType())
- return UET->getArgumentTypeInfo()->getType();
- else
- return UET->getArgumentExpr()->getType();
- }
- }
- return QualType();
-}
-
-/// Infer type from an arithmetic expression involving a sizeof. For example:
-///
-/// malloc(sizeof(MyType) + padding); // infers 'MyType'
-/// malloc(sizeof(MyType) * 32); // infers 'MyType'
-/// malloc(32 * sizeof(MyType)); // infers 'MyType'
-/// malloc(sizeof(MyType) << 1); // infers 'MyType'
-/// ...
-///
-/// More complex arithmetic expressions are supported, but are a heuristic, e.g.
-/// when considering allocations for structs with flexible array members:
-///
-/// malloc(sizeof(HasFlexArray) + sizeof(int) * 32); // infers 'HasFlexArray'
-///
-QualType inferPossibleTypeFromArithSizeofExpr(const Expr *E) {
- const Expr *Arg = E->IgnoreParenImpCasts();
- // The argument is a lone sizeof expression.
- if (QualType T = inferTypeFromSizeofExpr(Arg); !T.isNull())
- return T;
- if (const auto *BO = dyn_cast<BinaryOperator>(Arg)) {
- // Argument is an arithmetic expression. Cover common arithmetic patterns
- // involving sizeof.
- switch (BO->getOpcode()) {
- case BO_Add:
- case BO_Div:
- case BO_Mul:
- case BO_Shl:
- case BO_Shr:
- case BO_Sub:
- if (QualType T = inferPossibleTypeFromArithSizeofExpr(BO->getLHS());
- !T.isNull())
- return T;
- if (QualType T = inferPossibleTypeFromArithSizeofExpr(BO->getRHS());
- !T.isNull())
- return T;
- break;
- default:
- break;
- }
- }
- return QualType();
+ return llvm::MDNode::get(CGM.getLLVMContext(), {TypeNameMD, ContainsPtrMD});
}
-/// If the expression E is a reference to a variable, infer the type from a
-/// variable's initializer if it contains a sizeof. Beware, this is a heuristic
-/// and ignores if a variable is later reassigned. For example:
-///
-/// size_t my_size = sizeof(MyType);
-/// void *x = malloc(my_size); // infers 'MyType'
-///
-QualType inferPossibleTypeFromVarInitSizeofExpr(const Expr *E) {
- const Expr *Arg = E->IgnoreParenImpCasts();
- if (const auto *DRE = dyn_cast<DeclRefExpr>(Arg)) {
- if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
- if (const Expr *Init = VD->getInit())
- return inferPossibleTypeFromArithSizeofExpr(Init);
- }
- }
- return QualType();
+void CodeGenFunction::EmitAllocToken(llvm::CallBase *CB, QualType AllocType) {
+  assert(SanOpts.has(SanitizerKind::AllocToken) &&
+         "Only needed with -fsanitize=alloc-token");
+  // buildAllocToken() returns null when no metadata can be derived for the
+  // type. Leave CB untouched in that case, matching the CallExpr overload.
+  // Calling setMetadata() with a null node would instead remove any existing
+  // !alloc_token attachment.
+  if (llvm::MDNode *MDN = buildAllocToken(AllocType))
+    CB->setMetadata(llvm::LLVMContext::MD_alloc_token, MDN);
}
-/// Deduces the allocated type by checking if the allocation call's result
-/// is immediately used in a cast expression. For example:
-///
-/// MyType *x = (MyType *)malloc(4096); // infers 'MyType'
-///
-QualType inferPossibleTypeFromCastExpr(const CallExpr *CallE,
- const CastExpr *CastE) {
- if (!CastE)
- return QualType();
- QualType PtrType = CastE->getType();
- if (PtrType->isPointerType())
- return PtrType->getPointeeType();
- return QualType();
+/// Builds !alloc_token metadata for the allocation call \p E. The allocated
+/// type is inferred from the call and the enclosing cast (CurCast). Returns
+/// null if no type could be inferred or no metadata could be built for it.
+llvm::MDNode *CodeGenFunction::buildAllocToken(const CallExpr *E) {
+  QualType Inferred = infer_alloc::inferPossibleType(E, getContext(), CurCast);
+  return Inferred.isNull() ? nullptr : buildAllocToken(Inferred);
}
-} // end anonymous namespace
void CodeGenFunction::EmitAllocToken(llvm::CallBase *CB, const CallExpr *E) {
- QualType AllocType;
- // First check arguments.
- for (const Expr *Arg : E->arguments()) {
- AllocType = inferPossibleTypeFromArithSizeofExpr(Arg);
- if (AllocType.isNull())
- AllocType = inferPossibleTypeFromVarInitSizeofExpr(Arg);
- if (!AllocType.isNull())
- break;
- }
- // Then check later casts.
- if (AllocType.isNull())
- AllocType = inferPossibleTypeFromCastExpr(E, CurCast);
- // Emit if we were able to infer the type.
- if (!AllocType.isNull())
- EmitAllocToken(CB, AllocType);
+  assert(SanOpts.has(SanitizerKind::AllocToken) &&
+         "Only needed with -fsanitize=alloc-token");
+  // Attach metadata only when some could be built for the inferred type.
+  llvm::MDNode *TokenMD = buildAllocToken(E);
+  if (TokenMD)
+    CB->setMetadata(llvm::LLVMContext::MD_alloc_token, TokenMD);
}
CodeGenFunction::ComplexPairTy CodeGenFunction::
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 1f0be2d8756de..8c4c1c8c2dc95 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -3352,9 +3352,14 @@ class CodeGenFunction : public CodeGenTypeCache {
SanitizerAnnotateDebugInfo(ArrayRef<SanitizerKind::SanitizerOrdinal> Ordinals,
SanitizerHandler Handler);
- /// Emit additional metadata used by the AllocToken instrumentation.
+ /// Build metadata used by the AllocToken instrumentation.
+ /// Returns null if no metadata could be derived for \p AllocType.
+ llvm::MDNode *buildAllocToken(QualType AllocType);
+ /// Emit and set additional metadata used by the AllocToken instrumentation.
+ /// The metadata built for \p AllocType is attached to \p CB as !alloc_token.
void EmitAllocToken(llvm::CallBase *CB, QualType AllocType);
- /// Emit additional metadata used by the AllocToken instrumentation,
+ /// Build additional metadata used by the AllocToken instrumentation,
+ /// inferring the type from an allocation call expression.
+ /// Returns null if no type could be inferred or no metadata could be derived
+ /// for it.
+ llvm::MDNode *buildAllocToken(const CallExpr *E);
+ /// Emit and set additional metadata used by the AllocToken instrumentation,
/// inferring the type from an allocation call expression.
+ /// Metadata is attached to \p CB only if some could be built.
void EmitAllocToken(llvm::CallBase *CB, const CallExpr *E);
diff --git a/llvm/utils/gn/secondary/clang/lib/AST/BUILD.gn b/llvm/utils/gn/secondary/clang/lib/AST/BUILD.gn
index 9981d100fd555..4da907cbdd938 100644
--- a/llvm/utils/gn/secondary/clang/lib/AST/BUILD.gn
+++ b/llvm/utils/gn/secondary/clang/lib/AST/BUILD.gn
@@ -121,6 +121,7 @@ static_library("AST") {
"ExternalASTMerger.cpp",
"ExternalASTSource.cpp",
"FormatString.cpp",
+ "InferAlloc.cpp",
"InheritViz.cpp",
"ItaniumCXXABI.cpp",
"ItaniumMangle.cpp",
More information about the llvm-commits mailing list