[clang] [llvm] [LLVM][Clang] Add and enable strict mode for `getTrailingObjects` (PR #144930)
Rahul Joshi via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 27 14:21:37 PDT 2025
https://github.com/jurahul updated https://github.com/llvm/llvm-project/pull/144930
>From ee982b8b2d14b1199f051db53aea4f26899d4d77 Mon Sep 17 00:00:00 2001
From: Rahul Joshi <rjoshi at nvidia.com>
Date: Thu, 19 Jun 2025 10:25:12 -0700
Subject: [PATCH] [LLVM][Clang] Add and enable strict mode for
`getTrailingObjects`
Under strict mode, the templated `getTrailingObjects` can be called
only when there is more than one trailing type. The strict mode can be
disabled on a per-call basis (via `getTrailingObjectsNonStrict`) when it's
not possible to know statically whether there will be a single or multiple
trailing types (as in OpenMPClause.h).
---
clang/include/clang/AST/OpenMPClause.h | 37 ++++++++-----
clang/lib/AST/Expr.cpp | 3 +-
llvm/include/llvm/Support/TrailingObjects.h | 55 ++++++++++++++++---
.../unittests/Support/TrailingObjectsTest.cpp | 7 ++-
4 files changed, 74 insertions(+), 28 deletions(-)
diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
index c6f99fb21a0f0..5b2206af75bee 100644
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -295,7 +295,8 @@ template <class T> class OMPVarListClause : public OMPClause {
/// Fetches list of variables associated with this clause.
MutableArrayRef<Expr *> getVarRefs() {
- return static_cast<T *>(this)->template getTrailingObjects<Expr *>(NumVars);
+ return static_cast<T *>(this)->template getTrailingObjectsNonStrict<Expr *>(
+ NumVars);
}
/// Sets the list of variables for this clause.
@@ -334,8 +335,8 @@ template <class T> class OMPVarListClause : public OMPClause {
/// Fetches list of all variables in the clause.
ArrayRef<const Expr *> getVarRefs() const {
- return static_cast<const T *>(this)->template getTrailingObjects<Expr *>(
- NumVars);
+ return static_cast<const T *>(this)
+ ->template getTrailingObjectsNonStrict<Expr *>(NumVars);
}
};
@@ -380,7 +381,7 @@ template <class T> class OMPDirectiveListClause : public OMPClause {
MutableArrayRef<OpenMPDirectiveKind> getDirectiveKinds() {
return static_cast<T *>(this)
- ->template getTrailingObjects<OpenMPDirectiveKind>(NumKinds);
+ ->template getTrailingObjectsNonStrict<OpenMPDirectiveKind>(NumKinds);
}
void setDirectiveKinds(ArrayRef<OpenMPDirectiveKind> DK) {
@@ -5921,15 +5922,17 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
/// Get the unique declarations that are in the trailing objects of the
/// class.
MutableArrayRef<ValueDecl *> getUniqueDeclsRef() {
- return static_cast<T *>(this)->template getTrailingObjects<ValueDecl *>(
- NumUniqueDeclarations);
+ return static_cast<T *>(this)
+ ->template getTrailingObjectsNonStrict<ValueDecl *>(
+ NumUniqueDeclarations);
}
/// Get the unique declarations that are in the trailing objects of the
/// class.
ArrayRef<ValueDecl *> getUniqueDeclsRef() const {
return static_cast<const T *>(this)
- ->template getTrailingObjects<ValueDecl *>(NumUniqueDeclarations);
+ ->template getTrailingObjectsNonStrict<ValueDecl *>(
+ NumUniqueDeclarations);
}
/// Set the unique declarations that are in the trailing objects of the
@@ -5943,15 +5946,15 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
/// Get the number of lists per declaration that are in the trailing
/// objects of the class.
MutableArrayRef<unsigned> getDeclNumListsRef() {
- return static_cast<T *>(this)->template getTrailingObjects<unsigned>(
- NumUniqueDeclarations);
+ return static_cast<T *>(this)
+ ->template getTrailingObjectsNonStrict<unsigned>(NumUniqueDeclarations);
}
/// Get the number of lists per declaration that are in the trailing
/// objects of the class.
ArrayRef<unsigned> getDeclNumListsRef() const {
- return static_cast<const T *>(this)->template getTrailingObjects<unsigned>(
- NumUniqueDeclarations);
+ return static_cast<const T *>(this)
+ ->template getTrailingObjectsNonStrict<unsigned>(NumUniqueDeclarations);
}
/// Set the number of lists per declaration that are in the trailing
@@ -5966,7 +5969,8 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
/// objects of the class. They are appended after the number of lists.
MutableArrayRef<unsigned> getComponentListSizesRef() {
return MutableArrayRef<unsigned>(
- static_cast<T *>(this)->template getTrailingObjects<unsigned>() +
+ static_cast<T *>(this)
+ ->template getTrailingObjectsNonStrict<unsigned>() +
NumUniqueDeclarations,
NumComponentLists);
}
@@ -5975,7 +5979,8 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
/// objects of the class. They are appended after the number of lists.
ArrayRef<unsigned> getComponentListSizesRef() const {
return ArrayRef<unsigned>(
- static_cast<const T *>(this)->template getTrailingObjects<unsigned>() +
+ static_cast<const T *>(this)
+ ->template getTrailingObjectsNonStrict<unsigned>() +
NumUniqueDeclarations,
NumComponentLists);
}
@@ -5991,13 +5996,15 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
/// Get the components that are in the trailing objects of the class.
MutableArrayRef<MappableComponent> getComponentsRef() {
return static_cast<T *>(this)
- ->template getTrailingObjects<MappableComponent>(NumComponents);
+ ->template getTrailingObjectsNonStrict<MappableComponent>(
+ NumComponents);
}
/// Get the components that are in the trailing objects of the class.
ArrayRef<MappableComponent> getComponentsRef() const {
return static_cast<const T *>(this)
- ->template getTrailingObjects<MappableComponent>(NumComponents);
+ ->template getTrailingObjectsNonStrict<MappableComponent>(
+ NumComponents);
}
/// Set the components that are in the trailing objects of the class.
diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index 149b274f36b63..642867c0942b5 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -2020,7 +2020,8 @@ CXXBaseSpecifier **CastExpr::path_buffer() {
#define ABSTRACT_STMT(x)
#define CASTEXPR(Type, Base) \
case Stmt::Type##Class: \
- return static_cast<Type *>(this)->getTrailingObjects<CXXBaseSpecifier *>();
+ return static_cast<Type *>(this) \
+ ->getTrailingObjectsNonStrict<CXXBaseSpecifier *>();
#define STMT(Type, Base)
#include "clang/AST/StmtNodes.inc"
default:
diff --git a/llvm/include/llvm/Support/TrailingObjects.h b/llvm/include/llvm/Support/TrailingObjects.h
index f25f2311a81a4..d7211a930ae49 100644
--- a/llvm/include/llvm/Support/TrailingObjects.h
+++ b/llvm/include/llvm/Support/TrailingObjects.h
@@ -228,12 +228,18 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl<
using ParentType::getTrailingObjectsImpl;
- // This function contains only a static_assert BaseTy is final. The
- // static_assert must be in a function, and not at class-level
- // because BaseTy isn't complete at class instantiation time, but
- // will be by the time this function is instantiated.
- static void verifyTrailingObjectsAssertions() {
+ template <bool Strict> static void verifyTrailingObjectsAssertions() {
+ // The static_assert for BaseTy must be in a function, and not at
+ // class-level because BaseTy isn't complete at class instantiation time,
+ // but will be by the time this function is instantiated.
static_assert(std::is_final<BaseTy>(), "BaseTy must be final.");
+
+ // Verify that templated getTrailingObjects() is used only with multiple
+ // trailing types. For a single trailing type use the non-templated form, or
+ // getTrailingObjectsNonStrict() if the count isn't known statically.
+ static_assert(!Strict || sizeof...(TrailingTys) > 1,
+ "Use templated getTrailingObjects() only when there are "
+ "multiple trailing types");
}
// These two methods are the base of the recursion for this method.
@@ -283,7 +289,7 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl<
/// (which must be one of those specified in the class template). The
/// array may have zero or more elements in it.
template <typename T> const T *getTrailingObjects() const {
- verifyTrailingObjectsAssertions();
+ verifyTrailingObjectsAssertions<true>();
// Forwards to an impl function with overloads, since member
// function templates can't be specialized.
return this->getTrailingObjectsImpl(
@@ -295,7 +301,7 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl<
/// (which must be one of those specified in the class template). The
/// array may have zero or more elements in it.
template <typename T> T *getTrailingObjects() {
- verifyTrailingObjectsAssertions();
+ verifyTrailingObjectsAssertions<true>();
// Forwards to an impl function with overloads, since member
// function templates can't be specialized.
return this->getTrailingObjectsImpl(
@@ -310,14 +316,20 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl<
static_assert(sizeof...(TrailingTys) == 1,
"Can use non-templated getTrailingObjects() only when there "
"is a single trailing type");
- return getTrailingObjects<FirstTrailingType>();
+ verifyTrailingObjectsAssertions<false>();
+ return this->getTrailingObjectsImpl(
+ static_cast<const BaseTy *>(this),
+ TrailingObjectsBase::OverloadToken<FirstTrailingType>());
}
FirstTrailingType *getTrailingObjects() {
static_assert(sizeof...(TrailingTys) == 1,
"Can use non-templated getTrailingObjects() only when there "
"is a single trailing type");
- return getTrailingObjects<FirstTrailingType>();
+ verifyTrailingObjectsAssertions<false>();
+ return this->getTrailingObjectsImpl(
+ static_cast<BaseTy *>(this),
+ TrailingObjectsBase::OverloadToken<FirstTrailingType>());
}
// Functions that return the trailing objects as ArrayRefs.
@@ -337,6 +349,31 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl<
return ArrayRef(getTrailingObjects(), N);
}
+ // Non-strict forms of templated `getTrailingObjects` that, unlike the strict
+ // forms, also work when there is only a single trailing type.
+ template <typename T> const T *getTrailingObjectsNonStrict() const {
+ verifyTrailingObjectsAssertions<false>();
+ return this->getTrailingObjectsImpl(
+ static_cast<const BaseTy *>(this),
+ TrailingObjectsBase::OverloadToken<T>());
+ }
+
+ template <typename T> T *getTrailingObjectsNonStrict() {
+ verifyTrailingObjectsAssertions<false>();
+ return this->getTrailingObjectsImpl(
+ static_cast<BaseTy *>(this), TrailingObjectsBase::OverloadToken<T>());
+ }
+
+ template <typename T>
+ MutableArrayRef<T> getTrailingObjectsNonStrict(size_t N) {
+ return MutableArrayRef(getTrailingObjectsNonStrict<T>(), N);
+ }
+
+ template <typename T>
+ ArrayRef<T> getTrailingObjectsNonStrict(size_t N) const {
+ return ArrayRef(getTrailingObjectsNonStrict<T>(), N);
+ }
+
/// Returns the size of the trailing data, if an object were
/// allocated with the given counts (The counts are in the same order
/// as the template arguments). This does not include the size of the
diff --git a/llvm/unittests/Support/TrailingObjectsTest.cpp b/llvm/unittests/Support/TrailingObjectsTest.cpp
index 2590f375b6598..9184a4dd0cc23 100644
--- a/llvm/unittests/Support/TrailingObjectsTest.cpp
+++ b/llvm/unittests/Support/TrailingObjectsTest.cpp
@@ -45,9 +45,10 @@ class Class1 final : private TrailingObjects<Class1, short> {
template <typename... Ty>
using FixedSizeStorage = TrailingObjects::FixedSizeStorage<Ty...>;
- using TrailingObjects::totalSizeToAlloc;
using TrailingObjects::additionalSizeToAlloc;
using TrailingObjects::getTrailingObjects;
+ using TrailingObjects::getTrailingObjectsNonStrict;
+ using TrailingObjects::totalSizeToAlloc;
};
// Here, there are two singular optional object types appended. Note
@@ -123,11 +124,11 @@ TEST(TrailingObjects, OneArg) {
EXPECT_EQ(Class1::totalSizeToAlloc<short>(3),
sizeof(Class1) + sizeof(short) * 3);
- EXPECT_EQ(C->getTrailingObjects<short>(), reinterpret_cast<short *>(C + 1));
+ EXPECT_EQ(C->getTrailingObjects(), reinterpret_cast<short *>(C + 1));
EXPECT_EQ(C->get(0), 1);
EXPECT_EQ(C->get(2), 3);
- EXPECT_EQ(C->getTrailingObjects(), C->getTrailingObjects<short>());
+ EXPECT_EQ(C->getTrailingObjects(), C->getTrailingObjectsNonStrict<short>());
delete C;
}
More information about the llvm-commits
mailing list