[clang] [llvm] [LLVM][Clang] Add and enable strict mode for `getTrailingObjects` (PR #144930)
Rahul Joshi via llvm-commits
llvm-commits at lists.llvm.org
Thu Jun 19 10:28:49 PDT 2025
https://github.com/jurahul created https://github.com/llvm/llvm-project/pull/144930
Under strict mode, the templated `getTrailingObjects` can be called only when there is more than one trailing type. Strict mode can be disabled on a per-call basis when it is not possible to know statically whether there will be a single trailing type or multiple trailing types (as in OpenMPClause.h).
>From 3648af7d6dd8d92f7a6549c101b0710ae014c57d Mon Sep 17 00:00:00 2001
From: Rahul Joshi <rjoshi at nvidia.com>
Date: Thu, 19 Jun 2025 10:25:12 -0700
Subject: [PATCH] [LLVM][Clang] Add and enable strict mode for
`getTrailingObjects`
Under strict mode, the templated `getTrailingObjects` can be called
only when there is more than one trailing type. Strict mode can be
disabled on a per-call basis when it is not possible to know statically
whether there will be a single or multiple trailing types (as in OpenMPClause.h).
---
clang/include/clang/AST/OpenMPClause.h | 46 +++++++++++++--------
clang/lib/AST/Expr.cpp | 3 +-
llvm/include/llvm/Support/TrailingObjects.h | 21 ++++++----
3 files changed, 44 insertions(+), 26 deletions(-)
diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
index 2fa8fa529741e..b62ebd614e4c7 100644
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -295,7 +295,8 @@ template <class T> class OMPVarListClause : public OMPClause {
/// Fetches list of variables associated with this clause.
MutableArrayRef<Expr *> getVarRefs() {
- return static_cast<T *>(this)->template getTrailingObjects<Expr *>(NumVars);
+ return static_cast<T *>(this)
+ ->template getTrailingObjects<Expr *, /*Strict=*/false>(NumVars);
}
/// Sets the list of variables for this clause.
@@ -334,8 +335,8 @@ template <class T> class OMPVarListClause : public OMPClause {
/// Fetches list of all variables in the clause.
ArrayRef<const Expr *> getVarRefs() const {
- return static_cast<const T *>(this)->template getTrailingObjects<Expr *>(
- NumVars);
+ return static_cast<const T *>(this)
+ ->template getTrailingObjects<Expr *, /*Strict=*/false>(NumVars);
}
};
@@ -380,7 +381,8 @@ template <class T> class OMPDirectiveListClause : public OMPClause {
MutableArrayRef<OpenMPDirectiveKind> getDirectiveKinds() {
return static_cast<T *>(this)
- ->template getTrailingObjects<OpenMPDirectiveKind>(NumKinds);
+ ->template getTrailingObjects<OpenMPDirectiveKind, /*Strict=*/false>(
+ NumKinds);
}
void setDirectiveKinds(ArrayRef<OpenMPDirectiveKind> DK) {
@@ -5901,15 +5903,17 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
/// Get the unique declarations that are in the trailing objects of the
/// class.
MutableArrayRef<ValueDecl *> getUniqueDeclsRef() {
- return static_cast<T *>(this)->template getTrailingObjects<ValueDecl *>(
- NumUniqueDeclarations);
+ return static_cast<T *>(this)
+ ->template getTrailingObjects<ValueDecl *, /*Strict=*/false>(
+ NumUniqueDeclarations);
}
/// Get the unique declarations that are in the trailing objects of the
/// class.
ArrayRef<ValueDecl *> getUniqueDeclsRef() const {
return static_cast<const T *>(this)
- ->template getTrailingObjects<ValueDecl *>(NumUniqueDeclarations);
+ ->template getTrailingObjects<ValueDecl *, /*Strict=*/false>(
+ NumUniqueDeclarations);
}
/// Set the unique declarations that are in the trailing objects of the
@@ -5923,15 +5927,17 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
/// Get the number of lists per declaration that are in the trailing
/// objects of the class.
MutableArrayRef<unsigned> getDeclNumListsRef() {
- return static_cast<T *>(this)->template getTrailingObjects<unsigned>(
- NumUniqueDeclarations);
+ return static_cast<T *>(this)
+ ->template getTrailingObjects<unsigned, /*Strict=*/false>(
+ NumUniqueDeclarations);
}
/// Get the number of lists per declaration that are in the trailing
/// objects of the class.
ArrayRef<unsigned> getDeclNumListsRef() const {
- return static_cast<const T *>(this)->template getTrailingObjects<unsigned>(
- NumUniqueDeclarations);
+ return static_cast<const T *>(this)
+ ->template getTrailingObjects<unsigned, /*Strict=*/false>(
+ NumUniqueDeclarations);
}
/// Set the number of lists per declaration that are in the trailing
@@ -5946,7 +5952,8 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
/// objects of the class. They are appended after the number of lists.
MutableArrayRef<unsigned> getComponentListSizesRef() {
return MutableArrayRef<unsigned>(
- static_cast<T *>(this)->template getTrailingObjects<unsigned>() +
+ static_cast<T *>(this)
+ ->template getTrailingObjects<unsigned, /*Strict=*/false>() +
NumUniqueDeclarations,
NumComponentLists);
}
@@ -5955,7 +5962,8 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
/// objects of the class. They are appended after the number of lists.
ArrayRef<unsigned> getComponentListSizesRef() const {
return ArrayRef<unsigned>(
- static_cast<const T *>(this)->template getTrailingObjects<unsigned>() +
+ static_cast<const T *>(this)
+ ->template getTrailingObjects<unsigned, /*Strict=*/false>() +
NumUniqueDeclarations,
NumComponentLists);
}
@@ -5971,13 +5979,15 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
/// Get the components that are in the trailing objects of the class.
MutableArrayRef<MappableComponent> getComponentsRef() {
return static_cast<T *>(this)
- ->template getTrailingObjects<MappableComponent>(NumComponents);
+ ->template getTrailingObjects<MappableComponent, /*Strict=*/false>(
+ NumComponents);
}
/// Get the components that are in the trailing objects of the class.
ArrayRef<MappableComponent> getComponentsRef() const {
return static_cast<const T *>(this)
- ->template getTrailingObjects<MappableComponent>(NumComponents);
+ ->template getTrailingObjects<MappableComponent, /*Strict=*/false>(
+ NumComponents);
}
/// Set the components that are in the trailing objects of the class.
@@ -6084,7 +6094,8 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
assert(SupportsMapper &&
"Must be a clause that is possible to have user-defined mappers");
return llvm::MutableArrayRef<Expr *>(
- static_cast<T *>(this)->template getTrailingObjects<Expr *>() +
+ static_cast<T *>(this)
+ ->template getTrailingObjects<Expr *, /*Strict=*/false>() +
OMPVarListClause<T>::varlist_size(),
OMPVarListClause<T>::varlist_size());
}
@@ -6095,7 +6106,8 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
assert(SupportsMapper &&
"Must be a clause that is possible to have user-defined mappers");
return llvm::ArrayRef<Expr *>(
- static_cast<const T *>(this)->template getTrailingObjects<Expr *>() +
+ static_cast<const T *>(this)
+ ->template getTrailingObjects<Expr *, /*Strict=*/false>() +
OMPVarListClause<T>::varlist_size(),
OMPVarListClause<T>::varlist_size());
}
diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index c3722c65abf6e..b93a31ca4ed36 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -2024,7 +2024,8 @@ CXXBaseSpecifier **CastExpr::path_buffer() {
#define ABSTRACT_STMT(x)
#define CASTEXPR(Type, Base) \
case Stmt::Type##Class: \
- return static_cast<Type *>(this)->getTrailingObjects<CXXBaseSpecifier *>();
+ return static_cast<Type *>(this) \
+ ->getTrailingObjects<CXXBaseSpecifier *, /*Strict=*/false>();
#define STMT(Type, Base)
#include "clang/AST/StmtNodes.inc"
default:
diff --git a/llvm/include/llvm/Support/TrailingObjects.h b/llvm/include/llvm/Support/TrailingObjects.h
index f25f2311a81a4..3d701de93b4f1 100644
--- a/llvm/include/llvm/Support/TrailingObjects.h
+++ b/llvm/include/llvm/Support/TrailingObjects.h
@@ -282,7 +282,9 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl<
/// Returns a pointer to the trailing object array of the given type
/// (which must be one of those specified in the class template). The
/// array may have zero or more elements in it.
- template <typename T> const T *getTrailingObjects() const {
+ template <typename T, bool Strict = true>
+ const T *getTrailingObjects() const {
+ static_assert(!Strict || sizeof...(TrailingTys) > 1);
verifyTrailingObjectsAssertions();
// Forwards to an impl function with overloads, since member
// function templates can't be specialized.
@@ -294,7 +296,8 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl<
/// Returns a pointer to the trailing object array of the given type
/// (which must be one of those specified in the class template). The
/// array may have zero or more elements in it.
- template <typename T> T *getTrailingObjects() {
+ template <typename T, bool Strict = true> T *getTrailingObjects() {
+ static_assert(!Strict || sizeof...(TrailingTys) > 1);
verifyTrailingObjectsAssertions();
// Forwards to an impl function with overloads, since member
// function templates can't be specialized.
@@ -310,23 +313,25 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl<
static_assert(sizeof...(TrailingTys) == 1,
"Can use non-templated getTrailingObjects() only when there "
"is a single trailing type");
- return getTrailingObjects<FirstTrailingType>();
+ return getTrailingObjects<FirstTrailingType, /*Strict=*/false>();
}
FirstTrailingType *getTrailingObjects() {
static_assert(sizeof...(TrailingTys) == 1,
"Can use non-templated getTrailingObjects() only when there "
"is a single trailing type");
- return getTrailingObjects<FirstTrailingType>();
+ return getTrailingObjects<FirstTrailingType, /*Strict=*/false>();
}
// Functions that return the trailing objects as ArrayRefs.
- template <typename T> MutableArrayRef<T> getTrailingObjects(size_t N) {
- return MutableArrayRef(getTrailingObjects<T>(), N);
+ template <typename T, bool Strict = true>
+ MutableArrayRef<T> getTrailingObjects(size_t N) {
+ return MutableArrayRef(getTrailingObjects<T, Strict>(), N);
}
- template <typename T> ArrayRef<T> getTrailingObjects(size_t N) const {
- return ArrayRef(getTrailingObjects<T>(), N);
+ template <typename T, bool Strict = true>
+ ArrayRef<T> getTrailingObjects(size_t N) const {
+ return ArrayRef(getTrailingObjects<T, Strict>(), N);
}
MutableArrayRef<FirstTrailingType> getTrailingObjects(size_t N) {
More information about the llvm-commits
mailing list