[llvm] [ADT] Fix llvm::concat_iterator for `ValueT == common_base_class *` (PR #144744)

Javier Lopez-Gomez via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 23 08:45:52 PDT 2025


https://github.com/jalopezg-git updated https://github.com/llvm/llvm-project/pull/144744

>From fd8d8e997cb0588cc4d602fcd72de6cffcf9581d Mon Sep 17 00:00:00 2001
From: Javier Lopez-Gomez <javier.lopez.gomez at proton.me>
Date: Mon, 23 Jun 2025 17:45:38 +0200
Subject: [PATCH] [ADT] Fix llvm::concat_iterator for `ValueT ==
 common_base_class *`

Fix llvm::concat_iterator for the case where `ValueT` is a pointer to a
common base class to which the result of dereferencing any iterator in
`IterTs` can be converted.
---
 llvm/include/llvm/ADT/STLExtras.h    | 29 ++++++++++++++++++++--------
 llvm/unittests/ADT/STLExtrasTest.cpp | 29 ++++++++++++++++++++++++++++
 2 files changed, 50 insertions(+), 8 deletions(-)

diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index eea06cfb99ba2..d2010e663ebed 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -1032,13 +1032,22 @@ class concat_iterator
 
   static constexpr bool ReturnsByValue =
       !(std::is_reference_v<decltype(*std::declval<IterTs>())> && ...);
+  static constexpr bool ReturnsConvertiblePointer =
+      std::is_pointer_v<ValueT> &&
+      (std::is_convertible_v<decltype(*std::declval<IterTs>()), ValueT> && ...);
 
   using reference_type =
-      typename std::conditional_t<ReturnsByValue, ValueT, ValueT &>;
-
-  using handle_type =
-      typename std::conditional_t<ReturnsByValue, std::optional<ValueT>,
-                                  ValueT *>;
+      typename std::conditional_t<ReturnsByValue || ReturnsConvertiblePointer,
+                                  ValueT, ValueT &>;
+
+  using optional_value_type =
+      std::conditional_t<ReturnsByValue, std::optional<ValueT>, ValueT *>;
+  // handle_type is used to return an optional value from `getHelper()`. If
+  // the type resulting from dereferencing all IterTs is a pointer that can be
+  // converted to `ValueT`, use that pointer type instead to avoid implicit
+  // conversion issues.
+  using handle_type = typename std::conditional_t<ReturnsConvertiblePointer,
+                                                  ValueT, optional_value_type>;
 
   /// We store both the current and end iterators for each concatenated
   /// sequence in a tuple of pairs.
@@ -1088,7 +1097,7 @@ class concat_iterator
     if (Begin == End)
       return {};
 
-    if constexpr (ReturnsByValue)
+    if constexpr (ReturnsByValue || ReturnsConvertiblePointer)
       return *Begin;
     else
       return &*Begin;
@@ -1105,8 +1114,12 @@ class concat_iterator
 
     // Loop over them, and return the first result we find.
     for (auto &GetHelperFn : GetHelperFns)
-      if (auto P = (this->*GetHelperFn)())
-        return *P;
+      if (auto P = (this->*GetHelperFn)()) {
+        if constexpr (ReturnsConvertiblePointer)
+          return P;
+        else
+          return *P;
+      }
 
     llvm_unreachable("Attempted to get a pointer from an end concat iterator!");
   }
diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp
index 286cfa745fd14..6d893c2295819 100644
--- a/llvm/unittests/ADT/STLExtrasTest.cpp
+++ b/llvm/unittests/ADT/STLExtrasTest.cpp
@@ -398,6 +398,8 @@ struct some_struct {
   std::string swap_val;
 };
 
+struct derives_from_some_struct : some_struct {};
+
 std::vector<int>::const_iterator begin(const some_struct &s) {
   return s.data.begin();
 }
@@ -532,6 +534,33 @@ TEST(STLExtrasTest, ConcatRangeADL) {
   EXPECT_THAT(concat<const int>(S0, S1), ElementsAre(1, 2, 3, 4));
 }
 
+TEST(STLExtrasTest, ConcatRangeRef) {
+  SmallVector<some_namespace::some_struct> V12{{{1, 2}, "V12[0]"}};
+  SmallVector<some_namespace::some_struct> V3456{{{3, 4}, "V3456[0]"},
+                                                 {{5, 6}, "V3456[1]"}};
+
+  // Use concat over ranges whose iterators are `some_namespace::some_struct *`,
+  // i.e. dereferencing yields a reference to the original element.
+  std::vector<some_namespace::some_struct *> Expected = {&V12[0], &V3456[0],
+                                                         &V3456[1]};
+  std::vector<some_namespace::some_struct *> Test;
+  for (auto &i : concat<some_namespace::some_struct>(V12, V3456))
+    Test.push_back(&i);
+  EXPECT_EQ(Expected, Test);
+}
+
+TEST(STLExtrasTest, ConcatRangePtrToDerivedClass) {
+  some_namespace::some_struct S0{};
+  some_namespace::derives_from_some_struct S1{};
+  SmallVector<some_namespace::some_struct *> V0{&S0};
+  SmallVector<some_namespace::derives_from_some_struct *> V1{&S1, &S1};
+
+  // Use concat over ranges of pointers to different (but related) types.
+  EXPECT_THAT(concat<some_namespace::some_struct *>(V0, V1),
+              ElementsAre(&S0, static_cast<some_namespace::some_struct *>(&S1),
+                          static_cast<some_namespace::some_struct *>(&S1)));
+}
+
 TEST(STLExtrasTest, MakeFirstSecondRangeADL) {
   // Make sure that we use the `begin`/`end` functions from `some_namespace`,
   // using ADL.



More information about the llvm-commits mailing list