[llvm] bf738d2 - [ADT] Make mapped_iterator copy assignable
Kazu Hirata via llvm-commits
llvm-commits at lists.llvm.org
Sun Oct 30 14:37:14 PDT 2022
Author: James Player
Date: 2022-10-30T14:37:08-07:00
New Revision: bf738d2e77846826964402f2cccdd0681c71c038
URL: https://github.com/llvm/llvm-project/commit/bf738d2e77846826964402f2cccdd0681c71c038
DIFF: https://github.com/llvm/llvm-project/commit/bf738d2e77846826964402f2cccdd0681c71c038.diff
LOG: [ADT] Make mapped_iterator copy assignable
As mentioned in https://discourse.llvm.org/t/rfc-extend-ranges-infrastructure-to-better-match-c-20/65377
Lambda objects are not copy assignable, and therefore neither are
iterator types which hold a lambda. STL code requires iterators to be
copy assignable, so mapped_iterator cannot be used with std::deque,
for example: https://godbolt.org/z/4Px7odEEd
This blog post [1] explains the problem and solution. We define a
wrapper class to store callable objects, with two specializations.
1. Specialization for non-function types
- Use a std::optional as storage for non-function callables.
- Define operator=() implementation(s) which use
std::optional::emplace() instead of the assignment operator.
2. Specialization for function types
- Store as a pointer (even if template argument is a function reference).
- Default construct pointer to nullptr.
This Callable wrapper class is now default constructible (with invalid
state) and copy/move assignable.
With these new properties available on the callable object,
mapped_iterator can define a default constructor as well.
[1] https://www.fluentcpp.com/2019/04/16/an-alternative-design-to-iterators-and-ranges-using-stdoptional/
Reviewed By: kazu
Differential Revision: https://reviews.llvm.org/D134675
Added:
Modified:
llvm/include/llvm/ADT/STLExtras.h
llvm/unittests/ADT/MappedIteratorTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 1fe6609986f6a..3a89b3f30bbd9 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -36,6 +36,7 @@
#include <iterator>
#include <limits>
#include <memory>
+#include <optional>
#include <tuple>
#include <type_traits>
#include <utility>
@@ -208,6 +209,131 @@ constexpr auto addEnumValues(EnumTy1 LHS, EnumTy2 RHS) {
// Extra additions to <iterator>
//===----------------------------------------------------------------------===//
+namespace callable_detail {
+
+/// Templated storage wrapper for a callable.
+///
+/// This class is consistently default constructible, copy / move
+/// constructible / assignable, even when the wrapped callable itself (e.g. a
+/// lambda) is not assignable. A default constructed or reset() wrapper holds
+/// no callable; valid() reports whether one is present. Invoking or converting
+/// an empty wrapper is undefined behavior.
+///
+/// Supported callable types:
+/// - Function pointer
+/// - Function reference
+/// - Lambda
+/// - Function object
+template <typename T,
+          bool = std::is_function_v<std::remove_pointer_t<remove_cvref_t<T>>>>
+class Callable {
+  using value_type = std::remove_reference_t<T>;
+  using reference = value_type &;
+  using const_reference = value_type const &;
+
+  std::optional<value_type> Obj;
+
+  static_assert(!std::is_pointer_v<value_type>,
+                "Pointers to non-functions are not callable.");
+
+public:
+  Callable() = default;
+  Callable(T const &O) : Obj(std::in_place, O) {}
+
+  Callable(Callable const &Other) = default;
+  Callable(Callable &&Other) = default;
+
+  /// Copy assignment. The stored object is destroyed and re-constructed in
+  /// place instead of being assigned, because lambdas are not copy
+  /// assignable. If the copy constructor throws, *this is left empty.
+  Callable &operator=(Callable const &Other) {
+    // Self-assignment must be a no-op: resetting Obj first would otherwise
+    // also empty Other.Obj and silently drop the callable.
+    if (this == &Other)
+      return *this;
+    Obj.reset();
+    if (Other.Obj)
+      Obj.emplace(*Other.Obj);
+    return *this;
+  }
+
+  /// Move assignment, using the same destroy-and-re-construct strategy as
+  /// copy assignment.
+  Callable &operator=(Callable &&Other) noexcept(
+      std::is_nothrow_move_constructible_v<value_type>) {
+    if (this == &Other)
+      return *this;
+    Obj.reset();
+    if (Other.Obj)
+      Obj.emplace(std::move(*Other.Obj));
+    return *this;
+  }
+
+  /// Invoke the stored callable. Only participates in overload resolution
+  /// when T is invocable with the given arguments.
+  template <typename... Pn,
+            std::enable_if_t<std::is_invocable_v<T, Pn...>, int> = 0>
+  decltype(auto) operator()(Pn &&...Params) {
+    return (*Obj)(std::forward<Pn>(Params)...);
+  }
+
+  template <typename... Pn,
+            std::enable_if_t<std::is_invocable_v<T const, Pn...>, int> = 0>
+  decltype(auto) operator()(Pn &&...Params) const {
+    return (*Obj)(std::forward<Pn>(Params)...);
+  }
+
+  /// Returns true if a callable is currently stored.
+  bool valid() const { return Obj.has_value(); }
+  /// Destroys the stored callable, if any, leaving the wrapper empty.
+  void reset() { Obj.reset(); }
+
+  operator reference() { return *Obj; }
+  operator const_reference() const { return *Obj; }
+};
+
+// Function specialization. No need to waste extra space wrapping with a
+// std::optional: a null pointer already represents the empty state.
+template <typename T> class Callable<T, true> {
+  static constexpr bool IsPtr = std::is_pointer_v<remove_cvref_t<T>>;
+
+  // Function references are stored as pointers so that the wrapper is default
+  // constructible and assignable.
+  using StorageT = std::conditional_t<IsPtr, T, std::remove_reference_t<T> *>;
+
+  StorageT Func = nullptr;
+
+  template <typename In> static constexpr auto convertIn(In &&I) {
+    if constexpr (IsPtr) {
+      // Pointer... just echo it back.
+      return I;
+    } else {
+      // Must be a function reference. Return its address.
+      return &I;
+    }
+  }
+
+public:
+  Callable() = default;
+
+  // Construct from a function pointer or reference.
+  //
+  // Disable this constructor for references to 'Callable' so we don't violate
+  // the rule of 0.
+  template < // clang-format off
+    typename FnPtrOrRef,
+    std::enable_if_t<
+      !std::is_same_v<remove_cvref_t<FnPtrOrRef>, Callable>, int
+    > = 0
+  > // clang-format on
+  Callable(FnPtrOrRef &&F) : Func(convertIn(F)) {}
+
+  template <typename... Pn,
+            std::enable_if_t<std::is_invocable_v<T, Pn...>, int> = 0>
+  decltype(auto) operator()(Pn &&...Params) const {
+    return Func(std::forward<Pn>(Params)...);
+  }
+
+  /// Returns true if a function is currently stored.
+  bool valid() const { return Func != nullptr; }
+  /// Clears the stored function pointer, leaving the wrapper empty.
+  void reset() { Func = nullptr; }
+
+  operator T const &() const {
+    if constexpr (IsPtr) {
+      // T is a pointer... just echo it back.
+      return Func;
+    } else {
+      static_assert(std::is_reference_v<T>,
+                    "Expected a reference to a function.");
+      // T is a function reference... dereference the stored pointer.
+      return *Func;
+    }
+  }
+};
+
+} // namespace callable_detail
+
namespace adl_detail {
using std::begin;
@@ -291,6 +417,7 @@ class mapped_iterator
typename std::iterator_traits<ItTy>::
diff erence_type,
std::remove_reference_t<ReferenceTy> *, ReferenceTy> {
public:
+ mapped_iterator() = default;
mapped_iterator(ItTy U, FuncTy F)
: mapped_iterator::iterator_adaptor_base(std::move(U)), F(std::move(F)) {}
@@ -301,7 +428,7 @@ class mapped_iterator
ReferenceTy operator*() const { return F(*this->I); }
private:
- FuncTy F;
+ callable_detail::Callable<FuncTy> F{};
};
// map_iterator - Provide a convenient way to create mapped_iterators, just like
diff --git a/llvm/unittests/ADT/MappedIteratorTest.cpp b/llvm/unittests/ADT/MappedIteratorTest.cpp
index f94709805c2cd..ca54cb30f9560 100644
--- a/llvm/unittests/ADT/MappedIteratorTest.cpp
+++ b/llvm/unittests/ADT/MappedIteratorTest.cpp
@@ -13,10 +13,201 @@ using namespace llvm;
namespace {
-TEST(MappedIteratorTest, ApplyFunctionOnDereference) {
+// Typed test fixture; TypeParam is one of the callable factories listed in
+// FunctionTypes.
+template <typename T> struct MappedIteratorTestBasic : testing::Test {};
+
+// Factory for a capture-less lambda that adds one to its argument.
+struct Plus1Lambda {
+  auto operator()() const {
+    auto AddOne = [](int Value) { return Value + 1; };
+    return AddOne;
+  }
+};
+
+// Factory for a lambda that carries state in its capture. The member is
+// captured by value (init-capture), so the returned lambda stays valid after
+// this factory object is destroyed; an implicit `this` capture through [=]
+// would dangle and is deprecated in C++20.
+struct Plus1LambdaWithCapture {
+  const int One = 1;
+
+  auto operator()() const {
+    return [Inc = One](int X) { return X + Inc; };
+  }
+};
+
+// Factory returning a reference to a plain function.
+struct Plus1FunctionRef {
+  using FuncT = int (&)(int);
+
+  static int plus1(int X) { return 1 + X; }
+
+  FuncT operator()() const { return plus1; }
+};
+
+// Factory returning a pointer to a plain function.
+struct Plus1FunctionPtr {
+  using FuncT = int (*)(int);
+
+  static int plus1(int X) { return 1 + X; }
+
+  FuncT operator()() const { return &plus1; }
+};
+
+// Factory returning a default-constructible function object.
+struct Plus1Functor {
+  struct Plus1 {
+    int operator()(int X) const { return 1 + X; }
+  };
+
+  Plus1 operator()() const { return Plus1{}; }
+};
+
+// Factory returning a function object with no default constructor and, because
+// of its const data member, no copy assignment operator either.
+struct Plus1FunctorNotDefaultConstructible {
+  class PlusN {
+  public:
+    PlusN(int NArg) : N(NArg) {}
+
+    int operator()(int X) const { return N + X; }
+
+  private:
+    const int N;
+  };
+
+  PlusN operator()() const { return PlusN(1); }
+};
+
+// Callable flavours that every MappedIteratorTestBasic test is instantiated
+// with: lambdas with and without captures, function references and pointers,
+// and function objects with and without a default constructor.
+using FunctionTypes =
+    ::testing::Types<Plus1Lambda, Plus1LambdaWithCapture, Plus1FunctionRef,
+                     Plus1FunctionPtr, Plus1Functor,
+                     Plus1FunctorNotDefaultConstructible>;
+
+// The trailing empty argument leaves gtest's optional type-name generator
+// unset, so default names are used. Presumably it is spelled out to silence
+// pedantic warnings about an empty variadic macro argument.
+TYPED_TEST_SUITE(MappedIteratorTestBasic, FunctionTypes, );
+
+// Type of the callable produced by invoking the factory T with no arguments.
+template <typename T> using GetFuncT = decltype(std::declval<T>()());
+
+// Both the callable wrapper and the iterator must be default constructible,
+// whatever the callable type.
+TYPED_TEST(MappedIteratorTestBasic, DefaultConstruct) {
+  using FuncT = GetFuncT<TypeParam>;
+  using IterT = mapped_iterator<typename std::vector<int>::iterator, FuncT>;
+
+  constexpr bool DefaultConstruct =
+      std::is_default_constructible_v<callable_detail::Callable<FuncT>>;
+  EXPECT_TRUE(DefaultConstruct);
+  EXPECT_TRUE(std::is_default_constructible_v<IterT>);
+
+  // Guarded by if constexpr so that a missing trait is reported by the
+  // EXPECT_TRUE above instead of breaking the build.
+  if constexpr (std::is_default_constructible_v<IterT>) {
+    IterT I;
+    (void)I;
+  }
+}
+
+// A copy-constructed iterator must compare equal to its source.
+TYPED_TEST(MappedIteratorTestBasic, CopyConstruct) {
+  std::vector<int> V({0});
+
+  using FuncT = GetFuncT<TypeParam>;
+  using IterT = mapped_iterator<decltype(V)::iterator, FuncT>;
+
+  EXPECT_TRUE(std::is_copy_constructible_v<IterT>);
+
+  // Guarded by if constexpr so that a missing trait is reported by the
+  // EXPECT_TRUE above instead of breaking the build.
+  if constexpr (std::is_copy_constructible_v<IterT>) {
+    TypeParam GetCallable;
+
+    IterT I1(V.begin(), GetCallable());
+    IterT I2(I1);
+
+    EXPECT_EQ(I2, I1) << "copy constructed iterator is a different position";
+  }
+}
+
+// A move-constructed iterator must keep the position of its source.
+TYPED_TEST(MappedIteratorTestBasic, MoveConstruct) {
+  std::vector<int> V({0});
+
+  using FuncT = GetFuncT<TypeParam>;
+  using IterT = mapped_iterator<decltype(V)::iterator, FuncT>;
+
+  EXPECT_TRUE(std::is_move_constructible_v<IterT>);
+
+  // Guarded by if constexpr so that a missing trait is reported by the
+  // EXPECT_TRUE above instead of breaking the build.
+  if constexpr (std::is_move_constructible_v<IterT>) {
+    TypeParam GetCallable;
+
+    IterT I1(V.begin(), GetCallable());
+    IterT I2(V.begin(), GetCallable());
+    IterT I3(std::move(I2));
+
+    EXPECT_EQ(I3, I1) << "move constructed iterator is a different position";
+  }
+}
+
+TYPED_TEST(MappedIteratorTestBasic, CopyAssign) {
std::vector<int> V({0});
- auto I = map_iterator(V.begin(), [](int X) { return X + 1; });
+  using FuncT = GetFuncT<TypeParam>;
+  using IterT = mapped_iterator<decltype(V)::iterator, FuncT>;
+
+  // Copy assignment must compile for every callable type (including
+  // lambdas), and the target must take on the source's position.
+  EXPECT_TRUE(std::is_copy_assignable_v<IterT>);
+
+  if constexpr (std::is_copy_assignable_v<IterT>) {
+    TypeParam GetCallable;
+
+    IterT I1(V.begin(), GetCallable());
+    IterT I2(V.end(), GetCallable());
+
+    I2 = I1;
+
+    EXPECT_EQ(I2, I1) << "copy assigned iterator is a different position";
+  }
+}
+
+// Move assignment must compile for every callable type (including lambdas),
+// and the target must take on the source's position.
+TYPED_TEST(MappedIteratorTestBasic, MoveAssign) {
+  std::vector<int> V({0});
+
+  using FuncT = GetFuncT<TypeParam>;
+  using IterT = mapped_iterator<decltype(V)::iterator, FuncT>;
+
+  EXPECT_TRUE(std::is_move_assignable_v<IterT>);
+
+  if constexpr (std::is_move_assignable_v<IterT>) {
+    TypeParam GetCallable;
+
+    IterT I1(V.begin(), GetCallable());
+    IterT I2(V.begin(), GetCallable());
+    IterT I3(V.end(), GetCallable());
+
+    I3 = std::move(I2);
+
+    // Check the assignment target: the moved-from I2 is in an unspecified
+    // state and says nothing about whether the move assignment worked.
+    EXPECT_EQ(I3, I1) << "move assigned iterator is a different position";
+  }
+}
+
+// The callable handed to the constructor must be retrievable and still work.
+TYPED_TEST(MappedIteratorTestBasic, GetFunction) {
+  std::vector<int> Values({0});
+  TypeParam MakeFunc;
+
+  mapped_iterator<decltype(Values)::iterator, GetFuncT<TypeParam>> It(
+      Values.begin(), MakeFunc());
+
+  EXPECT_EQ(It.getFunction()(200), 201);
+}
+
+// getCurrent() exposes the wrapped iterator; advancing the adaptor by one over
+// a single-element vector must move the wrapped iterator to end().
+TYPED_TEST(MappedIteratorTestBasic, GetCurrent) {
+  std::vector<int> Values({0});
+  TypeParam MakeFunc;
+
+  using IterT =
+      mapped_iterator<decltype(Values)::iterator, GetFuncT<TypeParam>>;
+  IterT It(Values.begin(), MakeFunc());
+
+  EXPECT_EQ(It.getCurrent(), Values.begin());
+  EXPECT_EQ(std::next(It).getCurrent(), Values.end());
+}
+
+TYPED_TEST(MappedIteratorTestBasic, ApplyFunctionOnDereference) {
+  // Every factory yields a "plus one" callable, so dereferencing an iterator
+  // over {0} must produce 1.
+  TypeParam MakeFunc;
+  std::vector<int> V{0};
+  auto I = map_iterator(V.begin(), MakeFunc());
EXPECT_EQ(*I, 1) << "should have applied function in dereference";
}
@@ -28,9 +219,9 @@ TEST(MappedIteratorTest, ApplyFunctionOnArrow) {
std::vector<int> V({0});
S Y;
- S* P = &Y;
+ S *P = &Y;
- auto I = map_iterator(V.begin(), [&](int X) -> S& { return *(P + X); });
+ auto I = map_iterator(V.begin(), [&](int X) -> S & { return *(P + X); });
I->Z = 42;
@@ -39,9 +230,9 @@ TEST(MappedIteratorTest, ApplyFunctionOnArrow) {
TEST(MappedIteratorTest, FunctionPreservesReferences) {
std::vector<int> V({1});
- std::map<int, int> M({ {1, 1} });
+ std::map<int, int> M({{1, 1}});
- auto I = map_iterator(V.begin(), [&](int X) -> int& { return M[X]; });
+ auto I = map_iterator(V.begin(), [&](int X) -> int & { return M[X]; });
*I = 42;
EXPECT_EQ(M[1], 42) << "assignment should have modified M";
More information about the llvm-commits
mailing list