[clang] Make getAttr, getSpecificAttr, hasAttr, and hasSpecificAttr variadic (PR #78518)

Erich Keane via cfe-commits cfe-commits at lists.llvm.org
Thu Jan 18 08:13:17 PST 2024


https://github.com/erichkeane updated https://github.com/llvm/llvm-project/pull/78518

>From e3d667796b903d35405014db67b130ae20816bab Mon Sep 17 00:00:00 2001
From: erichkeane <ekeane at nvidia.com>
Date: Wed, 17 Jan 2024 10:26:22 -0800
Subject: [PATCH 1/2] Make getAttr, getSpecificAttr, hasAttr, and
 hasSpecificAttr variadic

As discussed elsewhere, it makes sense for these functions to be
variadic. They behave as an 'or': 'has*Attr' means "any one of the set",
and 'get*Attr' means "the first of the set".

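This is accomplished by extracting specific_attr_iterator into an _impl
version, and creating a primary template plus a single-attribute partial
specialization of specific_attr_iterator, both of which inherit from the
_impl.  The result is that all of the 'existing' code 'just works'.  (The
one spelling whose meaning changes is the explicit two-argument form
specific_attr_iterator<SomeAttr, Container>: the template parameters are
now in the opposite order, with the Container first.)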

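One exception (see the change in SemaOverload.cpp) could be solved with
a conversion function, but instead I used 'auto'.  The exception arises
because the single-argument specific_attr_iterator<EnableIfAttr> and the
specific_attr_iterator<AttrVec, EnableIfAttr> now returned by
Decl::specific_attr_begin are distinct types.  Alternatively, I could
spell out the 'AttrVec' argument explicitly, if that is preferable.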
---
 clang/include/clang/AST/AttrIterator.h | 108 ++++++++++++++++---------
 clang/include/clang/AST/DeclBase.h     |  28 ++++---
 clang/lib/Sema/SemaOverload.cpp        |   9 +--
 3 files changed, 90 insertions(+), 55 deletions(-)

diff --git a/clang/include/clang/AST/AttrIterator.h b/clang/include/clang/AST/AttrIterator.h
index 66571e1cf0b8ec..51732d9a3334b2 100644
--- a/clang/include/clang/AST/AttrIterator.h
+++ b/clang/include/clang/AST/AttrIterator.h
@@ -27,12 +27,24 @@ class Attr;
 /// AttrVec - A vector of Attr, which is how they are stored on the AST.
 using AttrVec = SmallVector<Attr *, 4>;
 
-/// specific_attr_iterator - Iterates over a subrange of an AttrVec, only
-/// providing attributes that are of a specific type.
-template <typename SpecificAttr, typename Container = AttrVec>
-class specific_attr_iterator {
+/// Iterates over a subrange of a container, only providing attributes that are
+/// of one of the specified types.
+template <typename Container, typename... SpecificAttrs>
+class specific_attr_iterator_impl {
   using Iterator = typename Container::const_iterator;
 
+  /// Helper that yields the attribute type when exactly one is specified, and
+  /// the common base class Attr otherwise.
+  template <typename... Ts> struct type_helper {
+    using type = Attr;
+  };
+  template <typename T> struct type_helper<T> {
+    using type = T;
+  };
+
+  /// The pointee type of the value_type, used for internal implementation.
+  using base_type = typename type_helper<SpecificAttrs...>::type;
+
   /// Current - The current, underlying iterator.
   /// In order to ensure we don't dereference an invalid iterator unless
   /// specifically requested, we don't necessarily advance this all the
@@ -43,46 +55,46 @@ class specific_attr_iterator {
   mutable Iterator Current;
 
   void AdvanceToNext() const {
-    while (!isa<SpecificAttr>(*Current))
+    while (!isa<SpecificAttrs...>(*Current))
       ++Current;
   }
 
   void AdvanceToNext(Iterator I) const {
-    while (Current != I && !isa<SpecificAttr>(*Current))
+    while (Current != I && !isa<SpecificAttrs...>(*Current))
       ++Current;
   }
 
 public:
-  using value_type = SpecificAttr *;
-  using reference = SpecificAttr *;
-  using pointer = SpecificAttr *;
+  using value_type = base_type *;
+  using reference = value_type;
+  using pointer = value_type;
   using iterator_category = std::forward_iterator_tag;
   using difference_type = std::ptrdiff_t;
 
-  specific_attr_iterator() = default;
-  explicit specific_attr_iterator(Iterator i) : Current(i) {}
+  specific_attr_iterator_impl() = default;
+  explicit specific_attr_iterator_impl(Iterator i) : Current(i) {}
 
   reference operator*() const {
     AdvanceToNext();
-    return cast<SpecificAttr>(*Current);
+    return cast<base_type>(*Current);
   }
   pointer operator->() const {
     AdvanceToNext();
-    return cast<SpecificAttr>(*Current);
+    return cast<base_type>(*Current);
   }
 
-  specific_attr_iterator& operator++() {
+  specific_attr_iterator_impl &operator++() {
     ++Current;
     return *this;
   }
-  specific_attr_iterator operator++(int) {
-    specific_attr_iterator Tmp(*this);
+  specific_attr_iterator_impl operator++(int) {
+    specific_attr_iterator_impl Tmp(*this);
     ++(*this);
     return Tmp;
   }
 
-  friend bool operator==(specific_attr_iterator Left,
-                         specific_attr_iterator Right) {
+  friend bool operator==(specific_attr_iterator_impl Left,
+                         specific_attr_iterator_impl Right) {
     assert((Left.Current == nullptr) == (Right.Current == nullptr));
     if (Left.Current < Right.Current)
       Left.AdvanceToNext(Right.Current);
@@ -90,33 +102,55 @@ class specific_attr_iterator {
       Right.AdvanceToNext(Left.Current);
     return Left.Current == Right.Current;
   }
-  friend bool operator!=(specific_attr_iterator Left,
-                         specific_attr_iterator Right) {
+  friend bool operator!=(specific_attr_iterator_impl Left,
+                         specific_attr_iterator_impl Right) {
     return !(Left == Right);
   }
 };
 
-template <typename SpecificAttr, typename Container>
-inline specific_attr_iterator<SpecificAttr, Container>
-          specific_attr_begin(const Container& container) {
-  return specific_attr_iterator<SpecificAttr, Container>(container.begin());
+/// Iterates over a subrange of a collection, only providing attributes that are
+/// of one of the specified types.
+template <typename Container, typename... SpecificAttrs>
+class specific_attr_iterator;
+
+template <typename SpecificAttr>
+class specific_attr_iterator<SpecificAttr>
+    : public specific_attr_iterator_impl<AttrVec, SpecificAttr> {
+  using specific_attr_iterator_impl<AttrVec,
+                                    SpecificAttr>::specific_attr_iterator_impl;
+};
+
+template <typename Container, typename... SpecificAttrs>
+class specific_attr_iterator
+    : public specific_attr_iterator_impl<Container, SpecificAttrs...> {
+  using specific_attr_iterator_impl<
+      Container, SpecificAttrs...>::specific_attr_iterator_impl;
+};
+
+template <typename... SpecificAttrs, typename Container>
+inline specific_attr_iterator<Container, SpecificAttrs...>
+specific_attr_begin(const Container &container) {
+  return specific_attr_iterator<Container, SpecificAttrs...>(container.begin());
 }
-template <typename SpecificAttr, typename Container>
-inline specific_attr_iterator<SpecificAttr, Container>
-          specific_attr_end(const Container& container) {
-  return specific_attr_iterator<SpecificAttr, Container>(container.end());
+
+template <typename... SpecificAttrs, typename Container>
+inline specific_attr_iterator<Container, SpecificAttrs...>
+specific_attr_end(const Container &container) {
+  return specific_attr_iterator<Container, SpecificAttrs...>(container.end());
 }
 
-template <typename SpecificAttr, typename Container>
-inline bool hasSpecificAttr(const Container& container) {
-  return specific_attr_begin<SpecificAttr>(container) !=
-          specific_attr_end<SpecificAttr>(container);
+template <typename... SpecificAttrs, typename Container>
+inline bool hasSpecificAttr(const Container &container) {
+  return specific_attr_begin<SpecificAttrs...>(container) !=
+         specific_attr_end<SpecificAttrs...>(container);
 }
-template <typename SpecificAttr, typename Container>
-inline SpecificAttr *getSpecificAttr(const Container& container) {
-  specific_attr_iterator<SpecificAttr, Container> i =
-      specific_attr_begin<SpecificAttr>(container);
-  if (i != specific_attr_end<SpecificAttr>(container))
+
+template <typename... SpecificAttrs, typename Container>
+inline typename specific_attr_iterator_impl<Container,
+                                            SpecificAttrs...>::value_type
+getSpecificAttr(const Container &container) {
+  auto i = specific_attr_begin<SpecificAttrs...>(container);
+  if (i != specific_attr_end<SpecificAttrs...>(container))
     return *i;
   else
     return nullptr;
diff --git a/clang/include/clang/AST/DeclBase.h b/clang/include/clang/AST/DeclBase.h
index d957ea24f6394a..cb111af9b17fc3 100644
--- a/clang/include/clang/AST/DeclBase.h
+++ b/clang/include/clang/AST/DeclBase.h
@@ -560,27 +560,29 @@ class alignas(8) Decl {
 
   template <typename T> void dropAttr() { dropAttrs<T>(); }
 
-  template <typename T>
-  llvm::iterator_range<specific_attr_iterator<T>> specific_attrs() const {
-    return llvm::make_range(specific_attr_begin<T>(), specific_attr_end<T>());
+  template <typename... Ts>
+  llvm::iterator_range<specific_attr_iterator<AttrVec, Ts...>>
+  specific_attrs() const {
+    return llvm::make_range(specific_attr_begin<Ts...>(),
+                            specific_attr_end<Ts...>());
   }
 
-  template <typename T>
-  specific_attr_iterator<T> specific_attr_begin() const {
-    return specific_attr_iterator<T>(attr_begin());
+  template <typename... Ts>
+  specific_attr_iterator<AttrVec, Ts...> specific_attr_begin() const {
+    return specific_attr_iterator<AttrVec, Ts...>(attr_begin());
   }
 
-  template <typename T>
-  specific_attr_iterator<T> specific_attr_end() const {
-    return specific_attr_iterator<T>(attr_end());
+  template <typename... Ts>
+  specific_attr_iterator<AttrVec, Ts...> specific_attr_end() const {
+    return specific_attr_iterator<AttrVec, Ts...>(attr_end());
   }
 
-  template<typename T> T *getAttr() const {
-    return hasAttrs() ? getSpecificAttr<T>(getAttrs()) : nullptr;
+  template <typename... Ts> auto *getAttr() const {
+    return hasAttrs() ? getSpecificAttr<Ts...>(getAttrs()) : nullptr;
   }
 
-  template<typename T> bool hasAttr() const {
-    return hasAttrs() && hasSpecificAttr<T>(getAttrs());
+  template <typename... Ts> bool hasAttr() const {
+    return hasAttrs() && hasSpecificAttr<Ts...>(getAttrs());
   }
 
   /// getMaxAlignment - return the maximum alignment specified by attributes
diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp
index 37c62b306b3cd3..c4c713efa7f5ab 100644
--- a/clang/lib/Sema/SemaOverload.cpp
+++ b/clang/lib/Sema/SemaOverload.cpp
@@ -1478,11 +1478,10 @@ static bool IsOverloadOrOverrideImpl(Sema &SemaRef, FunctionDecl *New,
     return true;
 
   // enable_if attributes are an order-sensitive part of the signature.
-  for (specific_attr_iterator<EnableIfAttr>
-         NewI = New->specific_attr_begin<EnableIfAttr>(),
-         NewE = New->specific_attr_end<EnableIfAttr>(),
-         OldI = Old->specific_attr_begin<EnableIfAttr>(),
-         OldE = Old->specific_attr_end<EnableIfAttr>();
+  for (auto NewI = New->specific_attr_begin<EnableIfAttr>(),
+            NewE = New->specific_attr_end<EnableIfAttr>(),
+            OldI = Old->specific_attr_begin<EnableIfAttr>(),
+            OldE = Old->specific_attr_end<EnableIfAttr>();
        NewI != NewE || OldI != OldE; ++NewI, ++OldI) {
     if (NewI == NewE || OldI == OldE)
       return true;

>From 1958942585cc1792cbc58d159bcb1d02b9ed00b8 Mon Sep 17 00:00:00 2001
From: erichkeane <ekeane at nvidia.com>
Date: Thu, 18 Jan 2024 07:35:56 -0800
Subject: [PATCH 2/2] Modify uses in SemaDecl.cpp to use variadic hasAttr.

---
 clang/lib/Sema/SemaDecl.cpp | 21 ++++++++-------------
 1 file changed, 8 insertions(+), 13 deletions(-)

diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 5472b43aafd4f3..88f71e26c9abbf 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -2026,8 +2026,7 @@ static bool ShouldDiagnoseUnusedDecl(const LangOptions &LangOpts,
   if (D->isPlaceholderVar(LangOpts))
     return false;
 
-  if (D->hasAttr<UnusedAttr>() || D->hasAttr<ObjCPreciseLifetimeAttr>() ||
-      D->hasAttr<CleanupAttr>())
+  if (D->hasAttr<UnusedAttr, ObjCPreciseLifetimeAttr, CleanupAttr>())
     return false;
 
   if (isa<LabelDecl>(D))
@@ -7331,8 +7330,8 @@ static bool isIncompleteDeclExternC(Sema &S, const T *D) {
       return false;
 
     // So do CUDA's host/device attributes.
-    if (S.getLangOpts().CUDA && (D->template hasAttr<CUDADeviceAttr>() ||
-                                 D->template hasAttr<CUDAHostAttr>()))
+    if (S.getLangOpts().CUDA &&
+        (D->template hasAttr<CUDADeviceAttr, CUDAHostAttr>()))
       return false;
   }
   return D->isExternC();
@@ -8035,8 +8034,7 @@ NamedDecl *Sema::ActOnVariableDeclarator(
     // CUDA B.2.5: "__shared__ and __constant__ variables have implied static
     // storage [duration]."
     if (SC == SC_None && S->getFnParent() != nullptr &&
-        (NewVD->hasAttr<CUDASharedAttr>() ||
-         NewVD->hasAttr<CUDAConstantAttr>())) {
+        (NewVD->hasAttr<CUDASharedAttr, CUDAConstantAttr>())) {
       NewVD->setStorageClass(SC_Static);
     }
   }
@@ -8805,8 +8803,7 @@ void Sema::CheckVariableDeclarationType(VarDecl *NewVD) {
   }
 
   bool isVM = T->isVariablyModifiedType();
-  if (isVM || NewVD->hasAttr<CleanupAttr>() ||
-      NewVD->hasAttr<BlocksAttr>())
+  if (isVM || NewVD->hasAttr<CleanupAttr, BlocksAttr>())
     setFunctionHasBranchProtectedScope();
 
   if ((isVM && NewVD->hasLinkage()) ||
@@ -10797,8 +10794,7 @@ Sema::ActOnFunctionDeclarator(Scope *S, Declarator &D, DeclContext *DC,
     // in device-side CUDA code, unless someone passed
     // -fcuda-allow-variadic-functions.
     if (!getLangOpts().CUDAAllowVariadicFunctions && NewFD->isVariadic() &&
-        (NewFD->hasAttr<CUDADeviceAttr>() ||
-         NewFD->hasAttr<CUDAGlobalAttr>()) &&
+        (NewFD->hasAttr<CUDADeviceAttr, CUDAGlobalAttr>()) &&
         !(II && II->isStr("printf") && NewFD->isExternC() &&
           !D.isFunctionDefinition())) {
       Diag(NewFD->getLocation(), diag::err_variadic_device_fn);
@@ -14617,8 +14613,7 @@ void Sema::CheckStaticLocalForDllExport(VarDecl *VD) {
 
   // Find outermost function when VD is in lambda function.
   while (FD && !getDLLAttr(FD) &&
-         !FD->hasAttr<DLLExportStaticLocalAttr>() &&
-         !FD->hasAttr<DLLImportStaticLocalAttr>()) {
+         !FD->hasAttr<DLLExportStaticLocalAttr, DLLImportStaticLocalAttr>()) {
     FD = dyn_cast_or_null<FunctionDecl>(FD->getParentFunctionOrMethod());
   }
 
@@ -16675,7 +16670,7 @@ void Sema::AddKnownFunctionAttributes(FunctionDecl *FD) {
     if (Context.BuiltinInfo.isConst(BuiltinID) && !FD->hasAttr<ConstAttr>())
       FD->addAttr(ConstAttr::CreateImplicit(Context, FD->getLocation()));
     if (getLangOpts().CUDA && Context.BuiltinInfo.isTSBuiltin(BuiltinID) &&
-        !FD->hasAttr<CUDADeviceAttr>() && !FD->hasAttr<CUDAHostAttr>()) {
+        !FD->hasAttr<CUDADeviceAttr, CUDAHostAttr>()) {
       // Add the appropriate attribute, depending on the CUDA compilation mode
       // and which target the builtin belongs to. For example, during host
       // compilation, aux builtins are __device__, while the rest are __host__.



More information about the cfe-commits mailing list