[llvm] [ADT] Fix llvm::concat_iterator for `ValueT == common_base_class *` (PR #144744)

Javier Lopez-Gomez via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 18 09:19:42 PDT 2025


https://github.com/jalopezg-git created https://github.com/llvm/llvm-project/pull/144744

Fix `llvm::concat_iterator` for the case where `ValueT` is a pointer to a common base class to which the result of dereferencing any iterator in `IterTs` can be implicitly converted.

In particular, the case below was not working before this patch, but I see no particular reason why it shouldn't be supported.
```
namespace some_namespace {
struct some_struct {
  std::vector<int> data;
  std::string swap_val;
};

// Derives from `some_struct`, so a `derives_from_some_struct *` converts
// implicitly to a `some_struct *`.
struct derives_from_some_struct : some_struct {};
} // namespace some_namespace

TEST(STLExtrasTest, ConcatRangePtrToDerivedClass) {
  auto S0 = std::make_unique<some_namespace::some_struct>();
  auto S1 = std::make_unique<some_namespace::derives_from_some_struct>();
  // The expected elements, as seen through the common base class.
  some_namespace::some_struct *S1AsBase = S1.get();
  SmallVector<some_namespace::some_struct *> V0{S0.get()};
  SmallVector<some_namespace::derives_from_some_struct *> V1(2, S1.get());

  // Use concat over ranges of pointers to different (but related) types.
  EXPECT_THAT(concat<some_namespace::some_struct *>(V0, V1),
              ElementsAre(S0.get(), S1AsBase, S1AsBase));
}
```

>From ebdc2a3c4d61a164351892f22431d630f42054ca Mon Sep 17 00:00:00 2001
From: Javier Lopez-Gomez <javier.lopez.gomez at proton.me>
Date: Wed, 18 Jun 2025 18:14:37 +0200
Subject: [PATCH] [ADT] Fix llvm::concat_iterator for `ValueT ==
 common_base_class *`

Fix llvm::concat_iterator for the case where `ValueT` is a pointer
to a common base class to which the result of dereferencing any
iterator in `IterTs` can be implicitly converted.
---
 llvm/include/llvm/ADT/STLExtras.h    | 11 ++++++-----
 llvm/unittests/ADT/STLExtrasTest.cpp | 16 ++++++++++++++++
 2 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index eea06cfb99ba2..951da522a8aa2 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -1030,8 +1030,17 @@ class concat_iterator
                                   std::forward_iterator_tag, ValueT> {
   using BaseT = typename concat_iterator::iterator_facade_base;
 
+  /// Elements are returned by value unless every iterator in `IterTs`
+  /// dereferences to an lvalue whose address converts to `ValueT *`. The
+  /// address check matters when e.g. `ValueT` is `Base *` and an iterator
+  /// yields `Derived *&`: `Derived **` does not convert to `Base **`, so the
+  /// element is converted to `Base *` and handed out as a value instead.
   static constexpr bool ReturnsByValue =
-      !(std::is_reference_v<decltype(*std::declval<IterTs>())> && ...);
+      !((std::is_reference_v<decltype(*std::declval<IterTs>())> &&
+         std::is_convertible_v<
+             std::remove_reference_t<decltype(*std::declval<IterTs>())> *,
+             ValueT *>) &&
+        ...);
 
   using reference_type =
       typename std::conditional_t<ReturnsByValue, ValueT, ValueT &>;
diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp
index 286cfa745fd14..0e6b040a08f4a 100644
--- a/llvm/unittests/ADT/STLExtrasTest.cpp
+++ b/llvm/unittests/ADT/STLExtrasTest.cpp
@@ -398,6 +398,8 @@ struct some_struct {
   std::string swap_val;
 };
 
+struct derives_from_some_struct : some_struct {};
+
 std::vector<int>::const_iterator begin(const some_struct &s) {
   return s.data.begin();
 }
@@ -532,6 +534,36 @@ TEST(STLExtrasTest, ConcatRangeADL) {
   EXPECT_THAT(concat<const int>(S0, S1), ElementsAre(1, 2, 3, 4));
 }
 
+TEST(STLExtrasTest, ConcatRangePtrToDerivedClass) {
+  auto S0 = std::make_unique<some_namespace::some_struct>();
+  auto S1 = std::make_unique<some_namespace::derives_from_some_struct>();
+  some_namespace::some_struct *S1AsBase = S1.get();
+  SmallVector<some_namespace::some_struct *> V0{S0.get()};
+  SmallVector<some_namespace::derives_from_some_struct *> V1(2, S1.get());
+
+  // Use concat over ranges of pointers to different (but related) types.
+  EXPECT_THAT(concat<some_namespace::some_struct *>(V0, V1),
+              ElementsAre(S0.get(), S1AsBase, S1AsBase));
+
+  // Same, but with a container whose iterators are usually class types
+  // rather than raw pointers, so the test does not rely on how
+  // `SmallVector` implements its iterators.
+  std::vector<some_namespace::derives_from_some_struct *> V2{S1.get()};
+  EXPECT_THAT(concat<some_namespace::some_struct *>(V0, V2),
+              ElementsAre(S0.get(), S1AsBase));
+}
+
+TEST(STLExtrasTest, ConcatRangeOfPointerIteratorsIsMutable) {
+  // Ranges whose iterators are raw pointers (as with `SmallVector`) must
+  // still be concatenated by reference when the element type is `ValueT`.
+  SmallVector<int> A{1, 2};
+  SmallVector<int> B{3};
+  for (int &I : concat<int>(A, B))
+    I *= 10;
+  EXPECT_THAT(A, ElementsAre(10, 20));
+  EXPECT_THAT(B, ElementsAre(30));
+}
+
 TEST(STLExtrasTest, MakeFirstSecondRangeADL) {
   // Make sure that we use the `begin`/`end` functions from `some_namespace`,
   // using ADL.



More information about the llvm-commits mailing list