[llvm] 981ce8f - [ADT] Fix const-correctness issues in `zippy`
Jakub Kuderski via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 28 13:29:04 PST 2023
Author: Jakub Kuderski
Date: 2023-02-28T16:24:51-05:00
New Revision: 981ce8fa15afa11d083033240edb1daff29081c7
URL: https://github.com/llvm/llvm-project/commit/981ce8fa15afa11d083033240edb1daff29081c7
DIFF: https://github.com/llvm/llvm-project/commit/981ce8fa15afa11d083033240edb1daff29081c7.diff
LOG: [ADT] Fix const-correctness issues in `zippy`
This defines the iterator tuple based on the storage type of `zippy`,
instead of its type arguments. This way, we can support temporaries that
get passed in and allow for them to be modified during iteration.
Because the iterator types of the tuple storage can differ depending on
whether the storage is const, this defines a separate const iterator type
along with non-const `begin`/`end` overloads. This way we avoid unintentional
casts, e.g., trying to cast `vector<bool>::reference` to
`vector<bool>::const_reference`, which may be unrelated types that are
not convertible.
This patch is a general and free-standing improvement, but my primary use
case is implementing a version of `enumerate` that accepts multiple ranges:
D144583.
Reviewed By: dblaikie, zero9178
Differential Revision: https://reviews.llvm.org/D144834
Added:
Modified:
llvm/include/llvm/ADT/STLExtras.h
llvm/unittests/ADT/IteratorTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index f1a6587ecc7f3..86d80354c9978 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -856,33 +856,70 @@ class zip_shortest : public zip_common<zip_shortest<Iters...>, Iters...> {
}
};
+/// Helper to obtain the iterator types for the tuple storage within `zippy`.
+template <template <typename...> class ItType, typename TupleStorageType,
+ typename IndexSequence>
+struct ZippyIteratorTuple;
+
+/// Partial specialization for non-const tuple storage.
+template <template <typename...> class ItType, typename... Args,
+ std::size_t... Ns>
+struct ZippyIteratorTuple<ItType, std::tuple<Args...>,
+ std::index_sequence<Ns...>> {
+ using type = ItType<decltype(adl_begin(
+ std::get<Ns>(declval<std::tuple<Args...> &>())))...>;
+};
+
+/// Partial specialization for const tuple storage.
+template <template <typename...> class ItType, typename... Args,
+ std::size_t... Ns>
+struct ZippyIteratorTuple<ItType, const std::tuple<Args...>,
+ std::index_sequence<Ns...>> {
+ using type = ItType<decltype(adl_begin(
+ std::get<Ns>(declval<const std::tuple<Args...> &>())))...>;
+};
+
template <template <typename...> class ItType, typename... Args> class zippy {
+private:
+ std::tuple<Args...> storage;
+ using IndexSequence = std::index_sequence_for<Args...>;
+
public:
- using iterator = ItType<decltype(std::begin(std::declval<Args>()))...>;
+ using iterator = typename ZippyIteratorTuple<ItType, decltype(storage),
+ IndexSequence>::type;
+ using const_iterator =
+ typename ZippyIteratorTuple<ItType, const decltype(storage),
+ IndexSequence>::type;
using iterator_category = typename iterator::iterator_category;
using value_type = typename iterator::value_type;
using difference_type = typename iterator::difference_type;
using pointer = typename iterator::pointer;
using reference = typename iterator::reference;
+ using const_reference = typename const_iterator::reference;
-private:
- std::tuple<Args...> ts;
+ zippy(Args &&...args) : storage(std::forward<Args>(args)...) {}
+ const_iterator begin() const { return begin_impl(IndexSequence{}); }
+ iterator begin() { return begin_impl(IndexSequence{}); }
+ const_iterator end() const { return end_impl(IndexSequence{}); }
+ iterator end() { return end_impl(IndexSequence{}); }
+
+private:
template <size_t... Ns>
- iterator begin_impl(std::index_sequence<Ns...>) const {
- return iterator(std::begin(std::get<Ns>(ts))...);
+ const_iterator begin_impl(std::index_sequence<Ns...>) const {
+ return const_iterator(adl_begin(std::get<Ns>(storage))...);
}
- template <size_t... Ns> iterator end_impl(std::index_sequence<Ns...>) const {
- return iterator(std::end(std::get<Ns>(ts))...);
+ template <size_t... Ns> iterator begin_impl(std::index_sequence<Ns...>) {
+ return iterator(adl_begin(std::get<Ns>(storage))...);
}
-public:
- zippy(Args &&... ts_) : ts(std::forward<Args>(ts_)...) {}
-
- iterator begin() const {
- return begin_impl(std::index_sequence_for<Args...>{});
+ template <size_t... Ns>
+ const_iterator end_impl(std::index_sequence<Ns...>) const {
+ return const_iterator(adl_end(std::get<Ns>(storage))...);
+ }
+ template <size_t... Ns> iterator end_impl(std::index_sequence<Ns...>) {
+ return iterator(adl_end(std::get<Ns>(storage))...);
}
- iterator end() const { return end_impl(std::index_sequence_for<Args...>{}); }
};
} // end namespace detail
diff --git a/llvm/unittests/ADT/IteratorTest.cpp b/llvm/unittests/ADT/IteratorTest.cpp
index 641ff4fcc7f54..b2a11c4c6bd7d 100644
--- a/llvm/unittests/ADT/IteratorTest.cpp
+++ b/llvm/unittests/ADT/IteratorTest.cpp
@@ -6,15 +6,19 @@
//
//===----------------------------------------------------------------------===//
-#include "llvm/ADT/ilist.h"
#include "llvm/ADT/iterator.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/ilist.h"
+#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <optional>
+#include <type_traits>
+#include <vector>
using namespace llvm;
+using testing::ElementsAre;
namespace {
@@ -430,6 +434,108 @@ TEST(ZipIteratorTest, ZipEqualBasic) {
EXPECT_EQ(iters, 6u);
}
+template <typename T>
+constexpr bool IsConstRef =
+ std::is_reference_v<T> && std::is_const_v<std::remove_reference_t<T>>;
+
+template <typename T>
+constexpr bool IsBoolConstRef =
+ std::is_same_v<llvm::remove_cvref_t<T>, std::vector<bool>::const_reference>;
+
+/// Returns a `const` copy of the passed value. The `const` on the returned
+/// value is intentional here so that `MakeConst` can be used in range-for
+/// loops.
+template <typename T> const T MakeConst(T &&value) {
+ return std::forward<T>(value);
+}
+
+TEST(ZipIteratorTest, ZipEqualConstCorrectness) {
+ const std::vector<unsigned> c_first = {3, 1, 4};
+ std::vector<unsigned> first = c_first;
+ const SmallVector<bool> c_second = {1, 1, 0};
+ SmallVector<bool> second = c_second;
+
+ for (auto [a, b, c, d] : zip_equal(c_first, first, c_second, second)) {
+ b = 0;
+ d = true;
+ static_assert(IsConstRef<decltype(a)>);
+ static_assert(!IsConstRef<decltype(b)>);
+ static_assert(IsConstRef<decltype(c)>);
+ static_assert(!IsConstRef<decltype(d)>);
+ }
+
+ EXPECT_THAT(first, ElementsAre(0, 0, 0));
+ EXPECT_THAT(second, ElementsAre(true, true, true));
+
+ std::vector<bool> nemesis = {true, false, true};
+ const std::vector<bool> c_nemesis = nemesis;
+
+ for (auto &&[a, b, c, d] : zip_equal(first, c_first, nemesis, c_nemesis)) {
+ a = 2;
+ c = true;
+ static_assert(!IsConstRef<decltype(a)>);
+ static_assert(IsConstRef<decltype(b)>);
+ static_assert(!IsBoolConstRef<decltype(c)>);
+ static_assert(IsBoolConstRef<decltype(d)>);
+ }
+
+ EXPECT_THAT(first, ElementsAre(2, 2, 2));
+ EXPECT_THAT(nemesis, ElementsAre(true, true, true));
+
+ unsigned iters = 0;
+ for (const auto &[a, b, c, d] :
+ zip_equal(first, c_first, nemesis, c_nemesis)) {
+ static_assert(!IsConstRef<decltype(a)>);
+ static_assert(IsConstRef<decltype(b)>);
+ static_assert(!IsBoolConstRef<decltype(c)>);
+ static_assert(IsBoolConstRef<decltype(d)>);
+ ++iters;
+ }
+ EXPECT_EQ(iters, 3u);
+ iters = 0;
+
+ for (const auto &[a, b, c, d] :
+ MakeConst(zip_equal(first, c_first, nemesis, c_nemesis))) {
+ static_assert(!IsConstRef<decltype(a)>);
+ static_assert(IsConstRef<decltype(b)>);
+ static_assert(!IsBoolConstRef<decltype(c)>);
+ static_assert(IsBoolConstRef<decltype(d)>);
+ ++iters;
+ }
+ EXPECT_EQ(iters, 3u);
+}
+
+TEST(ZipIteratorTest, ZipEqualTemporaries) {
+ unsigned iters = 0;
+
+ // These temporary ranges get moved into the `tuple<...> storage;` inside
+ // `zippy`. From then on, we can use references obtained from this storage to
+ // access them. This does not rely on any lifetime extensions on the
+ // temporaries passed to `zip_equal`.
+ for (auto [a, b, c] : zip_equal(SmallVector<int>{1, 2, 3}, std::string("abc"),
+ std::vector<bool>{true, false, true})) {
+ a = 3;
+ b = 'c';
+ c = false;
+ static_assert(!IsConstRef<decltype(a)>);
+ static_assert(!IsConstRef<decltype(b)>);
+ static_assert(!IsBoolConstRef<decltype(c)>);
+ ++iters;
+ }
+ EXPECT_EQ(iters, 3u);
+ iters = 0;
+
+ for (auto [a, b, c] :
+ MakeConst(zip_equal(SmallVector<int>{1, 2, 3}, std::string("abc"),
+ std::vector<bool>{true, false, true}))) {
+ static_assert(IsConstRef<decltype(a)>);
+ static_assert(IsConstRef<decltype(b)>);
+ static_assert(IsBoolConstRef<decltype(c)>);
+ ++iters;
+ }
+ EXPECT_EQ(iters, 3u);
+}
+
#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST
// Check that an assertion is triggered when ranges passed to `zip_equal` differ
// in length.
More information about the llvm-commits mailing list