[llvm] [ADT] Fix llvm::concat_iterator for `ValueT == common_base_class *` (PR #144744)

Javier Lopez-Gomez via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 17 07:48:41 PDT 2025


https://github.com/jalopezg-git updated https://github.com/llvm/llvm-project/pull/144744

>From 043b63a373e3b8d9e50530a37434f4987eb23c75 Mon Sep 17 00:00:00 2001
From: Javier Lopez-Gomez <javier.lopez.gomez at proton.me>
Date: Wed, 17 Sep 2025 16:48:27 +0200
Subject: [PATCH 1/2] [ADT] Simplify `llvm::concat_iterator` implementation

Simplify `llvm::concat_iterator::increment()` and `llvm::concat_iterator::get()`.
This makes it possible to also get rid of `handle_type`, which was causing
some additional complications.
---
 llvm/include/llvm/ADT/STLExtras.h | 64 ++++++++++---------------------
 1 file changed, 20 insertions(+), 44 deletions(-)

diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 02705fbc8f2c4..411520cd228e5 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -1000,10 +1000,6 @@ class concat_iterator
   using reference_type =
       typename std::conditional_t<ReturnsByValue, ValueT, ValueT &>;
 
-  using handle_type =
-      typename std::conditional_t<ReturnsByValue, std::optional<ValueT>,
-                                  ValueT *>;
-
   /// We store both the current and end iterators for each concatenated
   /// sequence in a tuple of pairs.
   ///
@@ -1013,49 +1009,38 @@ class concat_iterator
   std::tuple<IterTs...> Begins;
   std::tuple<IterTs...> Ends;
 
-  /// Attempts to increment a specific iterator.
-  ///
-  /// Returns true if it was able to increment the iterator. Returns false if
-  /// the iterator is already at the end iterator.
-  template <size_t Index> bool incrementHelper() {
+  /// Attempts to increment the `Index`-th iterator. If the iterator is already
+  /// at end, recurse over iterators in `Others...`.
+  template <size_t Index, size_t... Others> void incrementImpl() {
     auto &Begin = std::get<Index>(Begins);
     auto &End = std::get<Index>(Ends);
-    if (Begin == End)
-      return false;
-
+    if (Begin == End) {
+      if constexpr (sizeof...(Others) != 0)
+        return incrementImpl<Others...>();
+      llvm_unreachable("Attempted to increment an end concat iterator!");
+    }
     ++Begin;
-    return true;
   }
 
   /// Increments the first non-end iterator.
   ///
   /// It is an error to call this with all iterators at the end.
   template <size_t... Ns> void increment(std::index_sequence<Ns...>) {
-    // Build a sequence of functions to increment each iterator if possible.
-    bool (concat_iterator::*IncrementHelperFns[])() = {
-        &concat_iterator::incrementHelper<Ns>...};
-
-    // Loop over them, and stop as soon as we succeed at incrementing one.
-    for (auto &IncrementHelperFn : IncrementHelperFns)
-      if ((this->*IncrementHelperFn)())
-        return;
-
-    llvm_unreachable("Attempted to increment an end concat iterator!");
+    incrementImpl<Ns...>();
   }
 
-  /// Returns null if the specified iterator is at the end. Otherwise,
-  /// dereferences the iterator and returns the address of the resulting
-  /// reference.
-  template <size_t Index> handle_type getHelper() const {
+  /// Dereferences the `Index`-th iterator and returns the resulting reference.
+  /// If `Index` is at end, recurse over iterators in `Others...`.
+  template <size_t Index, size_t... Others> handle_type getImpl() const {
     auto &Begin = std::get<Index>(Begins);
     auto &End = std::get<Index>(Ends);
-    if (Begin == End)
-      return {};
-
-    if constexpr (ReturnsByValue)
-      return *Begin;
-    else
-      return &*Begin;
+    if (Begin == End) {
+      if constexpr (sizeof...(Others) != 0)
+        return getImpl<Others...>();
+      llvm_unreachable(
+          "Attempted to get a pointer from an end concat iterator!");
+    }
+    return *Begin;
   }
 
   /// Finds the first non-end iterator, dereferences, and returns the resulting
@@ -1063,16 +1048,7 @@ class concat_iterator
   ///
   /// It is an error to call this with all iterators at the end.
   template <size_t... Ns> reference_type get(std::index_sequence<Ns...>) const {
-    // Build a sequence of functions to get from iterator if possible.
-    handle_type (concat_iterator::*GetHelperFns[])()
-        const = {&concat_iterator::getHelper<Ns>...};
-
-    // Loop over them, and return the first result we find.
-    for (auto &GetHelperFn : GetHelperFns)
-      if (auto P = (this->*GetHelperFn)())
-        return *P;
-
-    llvm_unreachable("Attempted to get a pointer from an end concat iterator!");
+    return getImpl<Ns...>();
   }
 
 public:

>From 73384b5e91ad336410a829f594928d2b8ecf20af Mon Sep 17 00:00:00 2001
From: Javier Lopez-Gomez <javier.lopez.gomez at proton.me>
Date: Wed, 17 Sep 2025 16:48:27 +0200
Subject: [PATCH 2/2] [ADT] Fix llvm::concat_iterator for `ValueT ==
 common_base_class *`

Fix llvm::concat_iterator for the case where `ValueT` is a pointer
to a common base class to which the result of dereferencing any
iterator in `IterTs` can be converted.
---
 llvm/include/llvm/ADT/STLExtras.h    | 21 ++++++++++--
 llvm/unittests/ADT/STLExtrasTest.cpp | 48 ++++++++++++++++++++++++++++
 2 files changed, 66 insertions(+), 3 deletions(-)

diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 411520cd228e5..04e7717bc0b6b 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -114,6 +114,13 @@ using is_one_of = std::disjunction<std::is_same<T, Ts>...>;
 template <typename T, typename... Ts>
 using are_base_of = std::conjunction<std::is_base_of<T, Ts>...>;
 
+/// Traits class for checking whether type `T` is the same as all other types
+/// in `Ts`.
+template <typename T = void, typename... Ts>
+using all_types_equal = std::conjunction<std::is_same<T, Ts>...>;
+template <typename T = void, typename... Ts>
+constexpr bool all_types_equal_v = all_types_equal<T, Ts...>::value;
+
 /// Determine if all types in Ts are distinct.
 ///
 /// Useful to statically assert when Ts is intended to describe a non-multi set
@@ -996,9 +1003,17 @@ class concat_iterator
 
   static constexpr bool ReturnsByValue =
       !(std::is_reference_v<decltype(*std::declval<IterTs>())> && ...);
-
+  static constexpr bool ReturnsConvertibleType =
+      !all_types_equal_v<std::remove_cv_t<ValueT>,
+                         std::remove_cv_t<std::remove_reference_t<
+                             decltype(*std::declval<IterTs>())>>...> &&
+      (std::is_convertible_v<decltype(*std::declval<IterTs>()), ValueT> && ...);
+
+  // Cannot return a reference type if a conversion takes place, provided that
+  // the result of dereferencing all `IterTs...` is convertible to `ValueT`.
   using reference_type =
-      typename std::conditional_t<ReturnsByValue, ValueT, ValueT &>;
+      typename std::conditional_t<ReturnsByValue || ReturnsConvertibleType,
+                                  ValueT, ValueT &>;
 
   /// We store both the current and end iterators for each concatenated
   /// sequence in a tuple of pairs.
@@ -1031,7 +1046,7 @@ class concat_iterator
 
   /// Dereferences the `Index`-th iterator and returns the resulting reference.
   /// If `Index` is at end, recurse over iterators in `Others...`.
-  template <size_t Index, size_t... Others> handle_type getImpl() const {
+  template <size_t Index, size_t... Others> reference_type getImpl() const {
     auto &Begin = std::get<Index>(Begins);
     auto &End = std::get<Index>(Ends);
     if (Begin == End) {
diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp
index d355a59973f9a..5020acda95b0b 100644
--- a/llvm/unittests/ADT/STLExtrasTest.cpp
+++ b/llvm/unittests/ADT/STLExtrasTest.cpp
@@ -398,6 +398,8 @@ struct some_struct {
   std::string swap_val;
 };
 
+struct derives_from_some_struct : some_struct {};
+
 std::vector<int>::const_iterator begin(const some_struct &s) {
   return s.data.begin();
 }
@@ -500,6 +502,15 @@ TEST(STLExtrasTest, ToVector) {
   }
 }
 
+TEST(STLExtrasTest, AllTypesEqual) {
+  static_assert(all_types_equal_v<>);
+  static_assert(all_types_equal_v<int>);
+  static_assert(all_types_equal_v<int, int, int>);
+
+  static_assert(!all_types_equal_v<int, int, unsigned int>);
+  static_assert(!all_types_equal_v<int, int, float>);
+}
+
 TEST(STLExtrasTest, ConcatRange) {
   std::vector<int> Expected = {1, 2, 3, 4, 5, 6, 7, 8};
   std::vector<int> Test;
@@ -532,6 +543,43 @@ TEST(STLExtrasTest, ConcatRangeADL) {
   EXPECT_THAT(concat<const int>(S0, S1), ElementsAre(1, 2, 3, 4));
 }
 
+TEST(STLExtrasTest, ConcatRangePtrToSameClass) {
+  some_namespace::some_struct S0{};
+  some_namespace::some_struct S1{};
+  SmallVector<some_namespace::some_struct *> V0{&S0};
+  SmallVector<some_namespace::some_struct *> V1{&S1, &S1};
+
+  // Dereferencing all iterators yields `some_namespace::some_struct *&`; no
+  // conversion takes place, `reference_type` is
+  // `some_namespace::some_struct *&`.
+  auto C = concat<some_namespace::some_struct *>(V0, V1);
+  static_assert(
+      std::is_same_v<decltype(*C.begin()), some_namespace::some_struct *&>);
+  EXPECT_THAT(C, ElementsAre(&S0, &S1, &S1));
+  // `reference_type` should still allow container modification.
+  for (auto &i : C)
+    if (i == &S0)
+      i = nullptr;
+  EXPECT_THAT(C, ElementsAre(nullptr, &S1, &S1));
+}
+
+TEST(STLExtrasTest, ConcatRangePtrToDerivedClass) {
+  some_namespace::some_struct S0{};
+  some_namespace::derives_from_some_struct S1{};
+  SmallVector<some_namespace::some_struct *> V0{&S0};
+  SmallVector<some_namespace::derives_from_some_struct *> V1{&S1, &S1};
+
+  // Dereferencing all iterators yields different (but convertible) types;
+  // conversion takes place, `reference_type` is
+  // `some_namespace::some_struct *`.
+  auto C = concat<some_namespace::some_struct *>(V0, V1);
+  static_assert(
+      std::is_same_v<decltype(*C.begin()), some_namespace::some_struct *>);
+  EXPECT_THAT(C,
+              ElementsAre(&S0, static_cast<some_namespace::some_struct *>(&S1),
+                          static_cast<some_namespace::some_struct *>(&S1)));
+}
+
 TEST(STLExtrasTest, MakeFirstSecondRangeADL) {
   // Make sure that we use the `begin`/`end` functions from `some_namespace`,
   // using ADL.



More information about the llvm-commits mailing list