[llvm] b70de61 - Add `all_of_zip` to STLExtras

Mehdi Amini via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 28 22:00:45 PDT 2021


Author: Mehdi Amini
Date: 2021-07-29T05:00:35Z
New Revision: b70de61f48062c7810b474bc944394ecbd56a262

URL: https://github.com/llvm/llvm-project/commit/b70de61f48062c7810b474bc944394ecbd56a262
DIFF: https://github.com/llvm/llvm-project/commit/b70de61f48062c7810b474bc944394ecbd56a262.diff

LOG: Add `all_of_zip` to STLExtras

This takes two or more ranges and invokes a predicate on each element-wise tuple
of the ranges. It returns true if every tuple satisfies the predicate and all the
ranges have the same size.
It is useful with containers that don't provide random-access iterators, where
the sizes can't be checked in O(1).

Differential Revision: https://reviews.llvm.org/D106605

Added: 
    

Modified: 
    llvm/include/llvm/ADT/STLExtras.h
    llvm/unittests/ADT/STLExtrasTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index eb001346b6093..1ecc678e37a1f 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -630,6 +630,14 @@ struct zip_common : public zip_traits<ZipType, Iters...> {
     return std::tuple<Iters...>(std::prev(std::get<Ns>(iterators))...);
   }
 
+  /// Return true if each iterator equals the one at the same index in `other`.
+  template <size_t... Ns>
+  bool test_all_equals(const zip_common &other,
+                       std::index_sequence<Ns...>) const {
+    std::initializer_list<bool> Matches = {
+        std::get<Ns>(this->iterators) == std::get<Ns>(other.iterators)...};
+    return all_of(Matches, identity<bool>{});
+  }
 public:
   zip_common(Iters &&... ts) : iterators(std::forward<Iters>(ts)...) {}
 
@@ -650,6 +658,11 @@ struct zip_common : public zip_traits<ZipType, Iters...> {
     iterators = tup_dec(std::index_sequence_for<Iters...>{});
     return *reinterpret_cast<ZipType *>(this);
   }
+
+  /// Return true if every iterator equals its counterpart in `other`.
+  bool all_equals(const zip_common &other) const {
+    return test_all_equals(other, std::index_sequence_for<Iters...>{});
+  }
 };
 
 template <typename... Iters>
@@ -1986,6 +1999,45 @@ decltype(auto) apply_tuple(F &&f, Tuple &&t) {
                                   Indices{});
 }
 
+namespace detail {
+/// Return true if \p P holds for every element-wise tuple of \p args and all
+/// the ranges in \p args have the same length.
+template <typename Predicate, typename... Args>
+bool all_of_zip_predicate_first(Predicate &&P, Args &&...args) {
+  auto z = zip(args...);
+  auto it = z.begin();
+  auto end = z.end();
+  // Bail out on the first tuple the predicate rejects.
+  for (; it != end; ++it)
+    if (!apply_tuple([&](auto &&...elts) { return P(elts...); }, *it))
+      return false;
+  // Sizes match only if every zipped iterator reached its end together.
+  return it.all_equals(end);
+}
+
+// Adaptor moving the predicate (the last tuple element) in front of the
+// zipped inputs, the argument order all_of_zip_predicate_first expects.
+template <typename... ArgsThenPredicate, size_t... InputIndexes>
+bool all_of_zip_predicate_last(
+    std::tuple<ArgsThenPredicate...> argsThenPredicate,
+    std::index_sequence<InputIndexes...>) {
+  constexpr auto PredicateIndex = sizeof...(ArgsThenPredicate) - 1;
+  return all_of_zip_predicate_first(
+      std::get<PredicateIndex>(argsThenPredicate),
+      std::get<InputIndexes>(argsThenPredicate)...);
+}
+} // end namespace detail
+
+/// Zip the given ranges and test each element-wise tuple with the predicate,
+/// which is passed as the last argument. Return true if every tuple satisfies
+/// it and all the ranges have the same length; return false otherwise.
+template <typename... ArgsAndPredicate>
+bool all_of_zip(ArgsAndPredicate &&...argsAndPredicate) {
+  return detail::all_of_zip_predicate_last(
+      std::forward_as_tuple(argsAndPredicate...),
+      std::make_index_sequence<sizeof...(argsAndPredicate) - 1>{});
+}
+
 /// Return true if the sequence [Begin, End) has exactly N items. Runs in O(N)
 /// time. Not meant for use with random-access iterators.
 /// Can optionally take a predicate to filter lazily some items.

diff  --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp
index 85208e4f4a2f8..de2c8968aac17 100644
--- a/llvm/unittests/ADT/STLExtrasTest.cpp
+++ b/llvm/unittests/ADT/STLExtrasTest.cpp
@@ -876,4 +876,26 @@ TEST(STLExtrasTest, MakeVisitorLifetimeSemanticsLValue) {
   EXPECT_EQ(2, Destructors);
 }
 
+TEST(STLExtrasTest, AllOfZip) {
+  std::vector<int> v1 = {0, 4, 2, 1};
+  std::vector<int> v2 = {1, 4, 3, 6};
+  EXPECT_TRUE(all_of_zip(v1, v2, [](int L, int R) { return L <= R; }));
+  EXPECT_FALSE(all_of_zip(v1, v2, [](int L, int R) { return L < R; }));
+
+  std::vector<int> v3 = {1, 6, 5, 7};
+  EXPECT_TRUE(all_of_zip(v1, v2, v3,
+                         [](int a, int b, int c) { return a <= b && b <= c; }));
+  EXPECT_FALSE(all_of_zip(v1, v2, v3,
+                          [](int a, int b, int c) { return a < b && b < c; }));
+
+  // A length mismatch fails even with an always-true predicate, wherever the
+  // shorter range appears; empty ranges match without calling the predicate.
+  std::vector<int> v_short = {1, 4};
+  EXPECT_FALSE(all_of_zip(v1, v_short, [](int, int) { return true; }));
+  EXPECT_FALSE(all_of_zip(v_short, v1, [](int, int) { return true; }));
+  EXPECT_FALSE(all_of_zip(v1, v2, v_short, [](int, int, int) { return true; }));
+  std::vector<int> empty;
+  EXPECT_TRUE(all_of_zip(empty, empty, [](int, int) { return false; }));
+}
+
 } // namespace


        


More information about the llvm-commits mailing list