[llvm] 205701f - [llvm][ADT] Allow using structured bindings with `llvm::enumerate`

Markus Böck via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 9 09:12:47 PDT 2022


Author: Markus Böck
Date: 2022-08-09T18:12:40+02:00
New Revision: 205701fd47535a7525789f3291bd686cef1a9773

URL: https://github.com/llvm/llvm-project/commit/205701fd47535a7525789f3291bd686cef1a9773
DIFF: https://github.com/llvm/llvm-project/commit/205701fd47535a7525789f3291bd686cef1a9773.diff

LOG: [llvm][ADT] Allow using structured bindings with `llvm::enumerate`

This patch adds the ability to destructure the `value_type` returned by `llvm::enumerate` into the index and the value of the wrapped range. The main use case is loop iteration, where `enumerate` is most commonly used. After this patch it is possible to write code such as:
```
for (auto [index, value] : enumerate(container)) {
   ...
}
```
where `index` is the current index and `value` is a reference to the corresponding element of the given container.

Differential Revision: https://reviews.llvm.org/D131486

Added: 
    

Modified: 
    llvm/include/llvm/ADT/STLExtras.h
    llvm/unittests/ADT/STLExtrasTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 50f2cdf666ace..dd8b2095ba7ee 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -1982,6 +1982,16 @@ template <typename R> struct result_pair {
   IterOfRange<R> Iter;
 };
 
+template <std::size_t i, typename R> // Used by structured bindings.
+decltype(auto) get(const result_pair<R> &Pair) {
+  static_assert(i < 2); // Only 0 (index) and 1 (value) exist.
+  if constexpr (i == 0) {
+    return Pair.index(); // Element 0: the index.
+  } else {
+    return Pair.value(); // Element 1: the value.
+  }
+}
+
 template <typename R>
 class enumerator_iter
     : public iterator_facade_base<enumerator_iter<R>, std::forward_iterator_tag,
@@ -2054,6 +2064,12 @@ template <typename R> class enumerator {
 ///   printf("Item %d - %c\n", X.index(), X.value());
 /// }
 ///
+/// or using structured bindings:
+///
+/// for (auto [Index, Value] : enumerate(Items)) {
+///   printf("Item %d - %c\n", Index, Value);
+/// }
+///
 /// Output:
 ///   Item 0 - A
 ///   Item 1 - B
@@ -2192,4 +2208,17 @@ template <class T> constexpr T *to_address(T *P) { return P; }
 
 } // end namespace llvm
 
+namespace std { // Tuple-like protocol for result_pair.
+template <typename R>
+struct tuple_size<llvm::detail::result_pair<R>> // (index, value)
+    : std::integral_constant<std::size_t, 2> {};
+
+template <std::size_t i, typename R>
+struct tuple_element<i, llvm::detail::result_pair<R>>
+    : std::conditional<i == 0, std::size_t, // Element 0: the index.
+                       typename llvm::detail::result_pair<R>::value_reference> {
+}; // Element 1: value_reference, matching get<1>.
+
+} // namespace std
+
 #endif // LLVM_ADT_STLEXTRAS_H

diff  --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp
index bb31b267926f9..8ae50202cbd6c 100644
--- a/llvm/unittests/ADT/STLExtrasTest.cpp
+++ b/llvm/unittests/ADT/STLExtrasTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/ADT/STLExtras.h"
+#include "gmock/gmock.h"
 #include "gtest/gtest.h"
 
 #include <climits>
@@ -15,6 +16,8 @@
 
 using namespace llvm;
 
+using testing::ElementsAre;
+
 namespace {
 
 int f(rank<0>) { return 0; }
@@ -47,31 +50,29 @@ TEST(STLExtrasTest, EnumerateLValue) {
   typedef std::pair<std::size_t, char> CharPairType;
   std::vector<CharPairType> CharResults;
 
-  for (auto X : llvm::enumerate(foo)) {
-    CharResults.emplace_back(X.index(), X.value());
+  for (auto [index, value] : llvm::enumerate(foo)) {
+    CharResults.emplace_back(index, value);
   }
-  ASSERT_EQ(3u, CharResults.size());
-  EXPECT_EQ(CharPairType(0u, 'a'), CharResults[0]);
-  EXPECT_EQ(CharPairType(1u, 'b'), CharResults[1]);
-  EXPECT_EQ(CharPairType(2u, 'c'), CharResults[2]);
+
+  EXPECT_THAT(CharResults,
+              ElementsAre(CharPairType(0u, 'a'), CharPairType(1u, 'b'),
+                          CharPairType(2u, 'c')));
 
   // Test a const range of a different type.
   typedef std::pair<std::size_t, int> IntPairType;
   std::vector<IntPairType> IntResults;
   const std::vector<int> bar = {1, 2, 3};
-  for (auto X : llvm::enumerate(bar)) {
-    IntResults.emplace_back(X.index(), X.value());
+  for (auto [index, value] : llvm::enumerate(bar)) {
+    IntResults.emplace_back(index, value);
   }
-  ASSERT_EQ(3u, IntResults.size());
-  EXPECT_EQ(IntPairType(0u, 1), IntResults[0]);
-  EXPECT_EQ(IntPairType(1u, 2), IntResults[1]);
-  EXPECT_EQ(IntPairType(2u, 3), IntResults[2]);
+  EXPECT_THAT(IntResults, ElementsAre(IntPairType(0u, 1), IntPairType(1u, 2),
+                                      IntPairType(2u, 3)));
 
   // Test an empty range.
   IntResults.clear();
   const std::vector<int> baz{};
-  for (auto X : llvm::enumerate(baz)) {
-    IntResults.emplace_back(X.index(), X.value());
+  for (auto [index, value] : llvm::enumerate(baz)) {
+    IntResults.emplace_back(index, value);
   }
   EXPECT_TRUE(IntResults.empty());
 }
@@ -84,9 +85,15 @@ TEST(STLExtrasTest, EnumerateModifyLValue) {
   for (auto X : llvm::enumerate(foo)) {
     ++X.value();
   }
-  EXPECT_EQ('b', foo[0]);
-  EXPECT_EQ('c', foo[1]);
-  EXPECT_EQ('d', foo[2]);
+  EXPECT_THAT(foo, ElementsAre('b', 'c', 'd'));
+
+  // Also test if this works with structured bindings.
+  foo = {'a', 'b', 'c'};
+
+  for (auto [index, value] : llvm::enumerate(foo)) {
+    ++value;
+  }
+  EXPECT_THAT(foo, ElementsAre('b', 'c', 'd'));
 }
 
 TEST(STLExtrasTest, EnumerateRValueRef) {
@@ -100,10 +107,18 @@ TEST(STLExtrasTest, EnumerateRValueRef) {
     Results.emplace_back(X.index(), X.value());
   }
 
-  ASSERT_EQ(3u, Results.size());
-  EXPECT_EQ(PairType(0u, 1), Results[0]);
-  EXPECT_EQ(PairType(1u, 2), Results[1]);
-  EXPECT_EQ(PairType(2u, 3), Results[2]);
+  EXPECT_THAT(Results,
+              ElementsAre(PairType(0u, 1), PairType(1u, 2), PairType(2u, 3)));
+
+  // Also test if this works with structured bindings.
+  Results.clear();
+
+  for (auto [index, value] : llvm::enumerate(std::vector<int>{1, 2, 3})) {
+    Results.emplace_back(index, value);
+  }
+
+  EXPECT_THAT(Results,
+              ElementsAre(PairType(0u, 1), PairType(1u, 2), PairType(2u, 3)));
 }
 
 TEST(STLExtrasTest, EnumerateModifyRValue) {
@@ -118,10 +133,20 @@ TEST(STLExtrasTest, EnumerateModifyRValue) {
     Results.emplace_back(X.index(), X.value());
   }
 
-  ASSERT_EQ(3u, Results.size());
-  EXPECT_EQ(PairType(0u, '2'), Results[0]);
-  EXPECT_EQ(PairType(1u, '3'), Results[1]);
-  EXPECT_EQ(PairType(2u, '4'), Results[2]);
+  EXPECT_THAT(Results, ElementsAre(PairType(0u, '2'), PairType(1u, '3'),
+                                   PairType(2u, '4')));
+
+  // Also test if this works with structured bindings.
+  Results.clear();
+
+  for (auto [index, value] :
+       llvm::enumerate(std::vector<char>{'1', '2', '3'})) {
+    ++value;
+    Results.emplace_back(index, value);
+  }
+
+  EXPECT_THAT(Results, ElementsAre(PairType(0u, '2'), PairType(1u, '3'),
+                                   PairType(2u, '4')));
 }
 
 template <bool B> struct CanMove {};


        


More information about the llvm-commits mailing list