[clang] Make getAttr, getSpecificAttr, hasAttr, and hasSpecificAttr variadic (PR #78518)

Erich Keane via cfe-commits cfe-commits at lists.llvm.org
Wed Jan 17 15:04:19 PST 2024


https://github.com/erichkeane created https://github.com/llvm/llvm-project/pull/78518

As discussed elsewhere, it makes sense for these functions to be variadic. They act as an 'or': 'has*Attr' returns true if any attribute in the set is present, and 'get*Attr' returns the first attribute that matches any type in the set.

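This is accomplished by moving the implementation of specific_attr_iterator into specific_attr_iterator_impl, and then defining specific_attr_iterator (a primary template plus a single-argument partial specialization) that inherits from the _impl class. As a result, nearly all existing code 'just works'. Note that the template parameter order changes from <SpecificAttr, Container> to <Container, SpecificAttrs...>, so any explicit two-argument spelling of specific_attr_iterator must be updated.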

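One exception (see the change in SemaOverload.cpp): Decl::specific_attr_begin<T>() now returns specific_attr_iterator<AttrVec, T>, which is a different type from specific_attr_iterator<T>, so the explicitly typed loop variable no longer matches. This could be solved with a conversion function, but I used 'auto' instead. Alternatively, I could spell out 'AttrVec' explicitly if that is preferred.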

>From e3d667796b903d35405014db67b130ae20816bab Mon Sep 17 00:00:00 2001
From: erichkeane <ekeane at nvidia.com>
Date: Wed, 17 Jan 2024 10:26:22 -0800
Subject: [PATCH] Make getAttr, getSpecificAttr, hasAttr, and hasSpecificAttr
 variadic

As discussed elsewhere, it makes sense for these functions to be
variadic. They act as an 'or': 'has*Attr' returns true if any attribute
in the set is present, and 'get*Attr' returns the first attribute that
matches any type in the set.

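This is accomplished by moving the implementation of
specific_attr_iterator into specific_attr_iterator_impl, and then
defining specific_attr_iterator (a primary template plus a
single-argument partial specialization) that inherits from the _impl
class. As a result, nearly all existing code 'just works'. Note that
the template parameter order changes from <SpecificAttr, Container> to
<Container, SpecificAttrs...>, so any explicit two-argument spelling of
specific_attr_iterator must be updated.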

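One exception (see the change in SemaOverload.cpp):
Decl::specific_attr_begin<T>() now returns
specific_attr_iterator<AttrVec, T>, which is a different type from
specific_attr_iterator<T>, so the explicitly typed loop variable no
longer matches. This could be solved with a conversion function, but I
used 'auto' instead. Alternatively, I could spell out 'AttrVec'
explicitly if that is preferred.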
---
 clang/include/clang/AST/AttrIterator.h | 108 ++++++++++++++++---------
 clang/include/clang/AST/DeclBase.h     |  28 ++++---
 clang/lib/Sema/SemaOverload.cpp        |   9 +--
 3 files changed, 90 insertions(+), 55 deletions(-)

diff --git a/clang/include/clang/AST/AttrIterator.h b/clang/include/clang/AST/AttrIterator.h
index 66571e1cf0b8ec4..51732d9a3334b2e 100644
--- a/clang/include/clang/AST/AttrIterator.h
+++ b/clang/include/clang/AST/AttrIterator.h
@@ -27,12 +27,24 @@ class Attr;
 /// AttrVec - A vector of Attr, which is how they are stored on the AST.
 using AttrVec = SmallVector<Attr *, 4>;
 
-/// specific_attr_iterator - Iterates over a subrange of an AttrVec, only
-/// providing attributes that are of a specific type.
-template <typename SpecificAttr, typename Container = AttrVec>
-class specific_attr_iterator {
+/// Iterates over a subrange of a container, only providing attributes that are
+/// of one of the specified types.
+template <typename Container, typename... SpecificAttrs>
+class specific_attr_iterator_impl {
   using Iterator = typename Container::const_iterator;
 
+  /// Helper that yields the single specified attribute type when exactly one
+  /// is given, and 'Attr' otherwise.
+  template <typename... Ts> struct type_helper {
+    using type = Attr;
+  };
+  template <typename T> struct type_helper<T> {
+    using type = T;
+  };
+
+  /// The pointee type of the value_type, used for internal implementation.
+  using base_type = typename type_helper<SpecificAttrs...>::type;
+
   /// Current - The current, underlying iterator.
   /// In order to ensure we don't dereference an invalid iterator unless
   /// specifically requested, we don't necessarily advance this all the
@@ -43,46 +55,46 @@ class specific_attr_iterator {
   mutable Iterator Current;
 
   void AdvanceToNext() const {
-    while (!isa<SpecificAttr>(*Current))
+    while (!isa<SpecificAttrs...>(*Current))
       ++Current;
   }
 
   void AdvanceToNext(Iterator I) const {
-    while (Current != I && !isa<SpecificAttr>(*Current))
+    while (Current != I && !isa<SpecificAttrs...>(*Current))
       ++Current;
   }
 
 public:
-  using value_type = SpecificAttr *;
-  using reference = SpecificAttr *;
-  using pointer = SpecificAttr *;
+  using value_type = base_type *;
+  using reference = value_type;
+  using pointer = value_type;
   using iterator_category = std::forward_iterator_tag;
   using difference_type = std::ptrdiff_t;
 
-  specific_attr_iterator() = default;
-  explicit specific_attr_iterator(Iterator i) : Current(i) {}
+  specific_attr_iterator_impl() = default;
+  explicit specific_attr_iterator_impl(Iterator i) : Current(i) {}
 
   reference operator*() const {
     AdvanceToNext();
-    return cast<SpecificAttr>(*Current);
+    return cast<base_type>(*Current);
   }
   pointer operator->() const {
     AdvanceToNext();
-    return cast<SpecificAttr>(*Current);
+    return cast<base_type>(*Current);
   }
 
-  specific_attr_iterator& operator++() {
+  specific_attr_iterator_impl &operator++() {
     ++Current;
     return *this;
   }
-  specific_attr_iterator operator++(int) {
-    specific_attr_iterator Tmp(*this);
+  specific_attr_iterator_impl operator++(int) {
+    specific_attr_iterator_impl Tmp(*this);
     ++(*this);
     return Tmp;
   }
 
-  friend bool operator==(specific_attr_iterator Left,
-                         specific_attr_iterator Right) {
+  friend bool operator==(specific_attr_iterator_impl Left,
+                         specific_attr_iterator_impl Right) {
     assert((Left.Current == nullptr) == (Right.Current == nullptr));
     if (Left.Current < Right.Current)
       Left.AdvanceToNext(Right.Current);
@@ -90,33 +102,55 @@ class specific_attr_iterator {
       Right.AdvanceToNext(Left.Current);
     return Left.Current == Right.Current;
   }
-  friend bool operator!=(specific_attr_iterator Left,
-                         specific_attr_iterator Right) {
+  friend bool operator!=(specific_attr_iterator_impl Left,
+                         specific_attr_iterator_impl Right) {
     return !(Left == Right);
   }
 };
 
-template <typename SpecificAttr, typename Container>
-inline specific_attr_iterator<SpecificAttr, Container>
-          specific_attr_begin(const Container& container) {
-  return specific_attr_iterator<SpecificAttr, Container>(container.begin());
+/// Iterates over a subrange of a collection, only providing attributes that are
+/// of one of the specified types.
+template <typename Container, typename... SpecificAttrs>
+class specific_attr_iterator;
+
+template <typename SpecificAttr>
+class specific_attr_iterator<SpecificAttr>
+    : public specific_attr_iterator_impl<AttrVec, SpecificAttr> {
+  using specific_attr_iterator_impl<AttrVec,
+                                    SpecificAttr>::specific_attr_iterator_impl;
+};
+
+template <typename Container, typename... SpecificAttrs>
+class specific_attr_iterator
+    : public specific_attr_iterator_impl<Container, SpecificAttrs...> {
+  using specific_attr_iterator_impl<
+      Container, SpecificAttrs...>::specific_attr_iterator_impl;
+};
+
+template <typename... SpecificAttrs, typename Container>
+inline specific_attr_iterator<Container, SpecificAttrs...>
+specific_attr_begin(const Container &container) {
+  return specific_attr_iterator<Container, SpecificAttrs...>(container.begin());
 }
-template <typename SpecificAttr, typename Container>
-inline specific_attr_iterator<SpecificAttr, Container>
-          specific_attr_end(const Container& container) {
-  return specific_attr_iterator<SpecificAttr, Container>(container.end());
+
+template <typename... SpecificAttrs, typename Container>
+inline specific_attr_iterator<Container, SpecificAttrs...>
+specific_attr_end(const Container &container) {
+  return specific_attr_iterator<Container, SpecificAttrs...>(container.end());
 }
 
-template <typename SpecificAttr, typename Container>
-inline bool hasSpecificAttr(const Container& container) {
-  return specific_attr_begin<SpecificAttr>(container) !=
-          specific_attr_end<SpecificAttr>(container);
+template <typename... SpecificAttrs, typename Container>
+inline bool hasSpecificAttr(const Container &container) {
+  return specific_attr_begin<SpecificAttrs...>(container) !=
+         specific_attr_end<SpecificAttrs...>(container);
 }
-template <typename SpecificAttr, typename Container>
-inline SpecificAttr *getSpecificAttr(const Container& container) {
-  specific_attr_iterator<SpecificAttr, Container> i =
-      specific_attr_begin<SpecificAttr>(container);
-  if (i != specific_attr_end<SpecificAttr>(container))
+
+template <typename... SpecificAttrs, typename Container>
+inline typename specific_attr_iterator_impl<Container,
+                                            SpecificAttrs...>::value_type
+getSpecificAttr(const Container &container) {
+  auto i = specific_attr_begin<SpecificAttrs...>(container);
+  if (i != specific_attr_end<SpecificAttrs...>(container))
     return *i;
   else
     return nullptr;
diff --git a/clang/include/clang/AST/DeclBase.h b/clang/include/clang/AST/DeclBase.h
index d957ea24f6394a1..cb111af9b17fc39 100644
--- a/clang/include/clang/AST/DeclBase.h
+++ b/clang/include/clang/AST/DeclBase.h
@@ -560,27 +560,29 @@ class alignas(8) Decl {
 
   template <typename T> void dropAttr() { dropAttrs<T>(); }
 
-  template <typename T>
-  llvm::iterator_range<specific_attr_iterator<T>> specific_attrs() const {
-    return llvm::make_range(specific_attr_begin<T>(), specific_attr_end<T>());
+  template <typename... Ts>
+  llvm::iterator_range<specific_attr_iterator<AttrVec, Ts...>>
+  specific_attrs() const {
+    return llvm::make_range(specific_attr_begin<Ts...>(),
+                            specific_attr_end<Ts...>());
   }
 
-  template <typename T>
-  specific_attr_iterator<T> specific_attr_begin() const {
-    return specific_attr_iterator<T>(attr_begin());
+  template <typename... Ts>
+  specific_attr_iterator<AttrVec, Ts...> specific_attr_begin() const {
+    return specific_attr_iterator<AttrVec, Ts...>(attr_begin());
   }
 
-  template <typename T>
-  specific_attr_iterator<T> specific_attr_end() const {
-    return specific_attr_iterator<T>(attr_end());
+  template <typename... Ts>
+  specific_attr_iterator<AttrVec, Ts...> specific_attr_end() const {
+    return specific_attr_iterator<AttrVec, Ts...>(attr_end());
   }
 
-  template<typename T> T *getAttr() const {
-    return hasAttrs() ? getSpecificAttr<T>(getAttrs()) : nullptr;
+  template <typename... Ts> auto *getAttr() const {
+    return hasAttrs() ? getSpecificAttr<Ts...>(getAttrs()) : nullptr;
   }
 
-  template<typename T> bool hasAttr() const {
-    return hasAttrs() && hasSpecificAttr<T>(getAttrs());
+  template <typename... Ts> bool hasAttr() const {
+    return hasAttrs() && hasSpecificAttr<Ts...>(getAttrs());
   }
 
   /// getMaxAlignment - return the maximum alignment specified by attributes
diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp
index 37c62b306b3cd3f..c4c713efa7f5ab5 100644
--- a/clang/lib/Sema/SemaOverload.cpp
+++ b/clang/lib/Sema/SemaOverload.cpp
@@ -1478,11 +1478,10 @@ static bool IsOverloadOrOverrideImpl(Sema &SemaRef, FunctionDecl *New,
     return true;
 
   // enable_if attributes are an order-sensitive part of the signature.
-  for (specific_attr_iterator<EnableIfAttr>
-         NewI = New->specific_attr_begin<EnableIfAttr>(),
-         NewE = New->specific_attr_end<EnableIfAttr>(),
-         OldI = Old->specific_attr_begin<EnableIfAttr>(),
-         OldE = Old->specific_attr_end<EnableIfAttr>();
+  for (auto NewI = New->specific_attr_begin<EnableIfAttr>(),
+            NewE = New->specific_attr_end<EnableIfAttr>(),
+            OldI = Old->specific_attr_begin<EnableIfAttr>(),
+            OldE = Old->specific_attr_end<EnableIfAttr>();
        NewI != NewE || OldI != OldE; ++NewI, ++OldI) {
     if (NewI == NewE || OldI == OldE)
       return true;



More information about the cfe-commits mailing list