[llvm] bf738d2 - [ADT] Make mapped_iterator copy assignable

Kazu Hirata via llvm-commits llvm-commits at lists.llvm.org
Sun Oct 30 14:37:14 PDT 2022


Author: James Player
Date: 2022-10-30T14:37:08-07:00
New Revision: bf738d2e77846826964402f2cccdd0681c71c038

URL: https://github.com/llvm/llvm-project/commit/bf738d2e77846826964402f2cccdd0681c71c038
DIFF: https://github.com/llvm/llvm-project/commit/bf738d2e77846826964402f2cccdd0681c71c038.diff

LOG: [ADT] Make mapped_iterator copy assignable

As mentioned in https://discourse.llvm.org/t/rfc-extend-ranges-infrastructure-to-better-match-c-20/65377

Lambda objects are not copy assignable, and therefore neither are
iterator types which hold a lambda.  STL code requires iterators to be
copy assignable, so users cannot, for example, use mapped_iterator with
a std::deque: https://godbolt.org/z/4Px7odEEd

This blog post [1] explains the problem and solution.  We define a
wrapper class to store callable objects, with two specializations:

1. Specialization for non-function types
    - Use a std::optional as storage for non-function callable.
    - Define operator=() implementation(s) which use
      std::optional::emplace() instead of the assignment operator.
2. Specialization for function types
    - Store as a pointer (even if the template argument is a function reference).
    - Default construct pointer to nullptr.

This Callable wrapper class is now default constructible (in an empty
state that must not be invoked) and copy/move assignable.

With these new properties available on the callable object,
mapped_iterator can define a default constructor as well.

[1] https://www.fluentcpp.com/2019/04/16/an-alternative-design-to-iterators-and-ranges-using-stdoptional/

Reviewed By: kazu

Differential Revision: https://reviews.llvm.org/D134675

Added: 
    

Modified: 
    llvm/include/llvm/ADT/STLExtras.h
    llvm/unittests/ADT/MappedIteratorTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 1fe6609986f6a..3a89b3f30bbd9 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -36,6 +36,7 @@
 #include <iterator>
 #include <limits>
 #include <memory>
+#include <optional>
 #include <tuple>
 #include <type_traits>
 #include <utility>
@@ -208,6 +209,138 @@ constexpr auto addEnumValues(EnumTy1 LHS, EnumTy2 RHS) {
 //     Extra additions to <iterator>
 //===----------------------------------------------------------------------===//
 
+namespace callable_detail {
+
+/// Templated storage wrapper for a callable.
+///
+/// This class is consistently default constructible, copy / move
+/// constructible / assignable.
+///
+/// Supported callable types:
+///  - Function pointer
+///  - Function reference
+///  - Lambda
+///  - Function object
+template <typename T,
+          bool = std::is_function_v<std::remove_pointer_t<remove_cvref_t<T>>>>
+class Callable {
+  using value_type = std::remove_reference_t<T>;
+  using reference = value_type &;
+  using const_reference = value_type const &;
+
+  std::optional<value_type> Obj;
+
+  static_assert(!std::is_pointer_v<value_type>,
+                "Pointers to non-functions are not callable.");
+
+public:
+  Callable() = default;
+  Callable(T const &O) : Obj(std::in_place, O) {}
+
+  Callable(Callable const &Other) = default;
+  Callable(Callable &&Other) = default;
+
+  // Lambdas are not assignable, so assignment destroys the held object and
+  // re-constructs it in place with std::optional::emplace().  Self-assignment
+  // must be a no-op, since resetting would destroy the object being copied.
+  Callable &operator=(Callable const &Other) {
+    if (this == &Other)
+      return *this;
+    Obj = std::nullopt;
+    if (Other.Obj)
+      Obj.emplace(*Other.Obj);
+    return *this;
+  }
+
+  Callable &operator=(Callable &&Other) {
+    if (this == &Other)
+      return *this;
+    Obj = std::nullopt;
+    if (Other.Obj)
+      Obj.emplace(std::move(*Other.Obj));
+    return *this;
+  }
+
+  template <typename... Pn,
+            std::enable_if_t<std::is_invocable_v<T, Pn...>, int> = 0>
+  decltype(auto) operator()(Pn &&...Params) {
+    return (*Obj)(std::forward<Pn>(Params)...);
+  }
+
+  template <typename... Pn,
+            std::enable_if_t<std::is_invocable_v<T const, Pn...>, int> = 0>
+  decltype(auto) operator()(Pn &&...Params) const {
+    return (*Obj)(std::forward<Pn>(Params)...);
+  }
+
+  bool valid() const { return Obj != std::nullopt; }
+  void reset() { Obj = std::nullopt; }
+
+  operator reference() { return *Obj; }
+  operator const_reference() const { return *Obj; }
+};
+
+// Function specialization.  No need to waste extra space wrapping with a
+// std::optional.
+template <typename T> class Callable<T, true> {
+  static constexpr bool IsPtr = std::is_pointer_v<remove_cvref_t<T>>;
+
+  using StorageT = std::conditional_t<IsPtr, T, std::remove_reference_t<T> *>;
+  using CastT = std::conditional_t<IsPtr, T, T &>;
+
+private:
+  StorageT Func = nullptr;
+
+private:
+  template <typename In> static constexpr auto convertIn(In &&I) {
+    if constexpr (IsPtr) {
+      // Pointer... just echo it back.
+      return I;
+    } else {
+      // Must be a function reference.  Return its address.
+      return &I;
+    }
+  }
+
+public:
+  Callable() = default;
+
+  // Construct from a function pointer or reference.
+  //
+  // Disable this constructor for references to 'Callable' so we don't violate
+  // the rule of 0.
+  template < // clang-format off
+    typename FnPtrOrRef,
+    std::enable_if_t<
+      !std::is_same_v<remove_cvref_t<FnPtrOrRef>, Callable>, int
+    > = 0
+  > // clang-format on
+  Callable(FnPtrOrRef &&F) : Func(convertIn(F)) {}
+
+  template <typename... Pn,
+            std::enable_if_t<std::is_invocable_v<T, Pn...>, int> = 0>
+  decltype(auto) operator()(Pn &&...Params) const {
+    return Func(std::forward<Pn>(Params)...);
+  }
+
+  bool valid() const { return Func != nullptr; }
+  void reset() { Func = nullptr; }
+
+  operator T const &() const {
+    if constexpr (IsPtr) {
+      // T is a pointer... just echo it back.
+      return Func;
+    } else {
+      static_assert(std::is_reference_v<T>,
+                    "Expected a reference to a function.");
+      // T is a function reference... dereference the stored pointer.
+      return *Func;
+    }
+  }
+};
+
+} // namespace callable_detail
+
 namespace adl_detail {
 
 using std::begin;
@@ -291,6 +424,7 @@ class mapped_iterator
          typename std::iterator_traits<ItTy>::difference_type,
          std::remove_reference_t<ReferenceTy> *, ReferenceTy> {
 public:
+  mapped_iterator() = default;
   mapped_iterator(ItTy U, FuncTy F)
     : mapped_iterator::iterator_adaptor_base(std::move(U)), F(std::move(F)) {}
 
@@ -301,7 +435,7 @@ class mapped_iterator
   ReferenceTy operator*() const { return F(*this->I); }
 
 private:
-  FuncTy F;
+  callable_detail::Callable<FuncTy> F{};
 };
 
 // map_iterator - Provide a convenient way to create mapped_iterators, just like

diff  --git a/llvm/unittests/ADT/MappedIteratorTest.cpp b/llvm/unittests/ADT/MappedIteratorTest.cpp
index f94709805c2cd..ca54cb30f9560 100644
--- a/llvm/unittests/ADT/MappedIteratorTest.cpp
+++ b/llvm/unittests/ADT/MappedIteratorTest.cpp
@@ -13,10 +13,201 @@ using namespace llvm;
 
 namespace {
 
-TEST(MappedIteratorTest, ApplyFunctionOnDereference) {
+template <typename T> class MappedIteratorTestBasic : public testing::Test {};
+
+struct Plus1Lambda {
+  auto operator()() const {
+    return [](int X) { return X + 1; };
+  }
+};
+
+struct Plus1LambdaWithCapture {
+  const int One = 1;
+
+  auto operator()() const {
+    return [=](int X) { return X + One; };
+  }
+};
+
+struct Plus1FunctionRef {
+  static int plus1(int X) { return X + 1; }
+
+  using FuncT = int (&)(int);
+
+  FuncT operator()() const { return *plus1; }
+};
+
+struct Plus1FunctionPtr {
+  static int plus1(int X) { return X + 1; }
+
+  using FuncT = int (*)(int);
+
+  FuncT operator()() const { return plus1; }
+};
+
+struct Plus1Functor {
+  struct Plus1 {
+    int operator()(int X) const { return X + 1; }
+  };
+
+  auto operator()() const { return Plus1(); }
+};
+
+struct Plus1FunctorNotDefaultConstructible {
+  class PlusN {
+    const int N;
+
+  public:
+    PlusN(int NArg) : N(NArg) {}
+
+    int operator()(int X) const { return X + N; }
+  };
+
+  auto operator()() const { return PlusN(1); }
+};
+
+// clang-format off
+using FunctionTypes =
+  ::testing::Types<
+    Plus1Lambda,
+    Plus1LambdaWithCapture,
+    Plus1FunctionRef,
+    Plus1FunctionPtr,
+    Plus1Functor,
+    Plus1FunctorNotDefaultConstructible
+  >;
+// clang-format on
+
+TYPED_TEST_SUITE(MappedIteratorTestBasic, FunctionTypes, );
+
+template <typename T> using GetFuncT = decltype(std::declval<T>().operator()());
+
+TYPED_TEST(MappedIteratorTestBasic, DefaultConstruct) {
+  using FuncT = GetFuncT<TypeParam>;
+  using IterT = mapped_iterator<typename std::vector<int>::iterator, FuncT>;
+  TypeParam GetCallable;
+
+  auto Func = GetCallable();
+  (void)Func;
+  constexpr bool DefaultConstruct =
+      std::is_default_constructible_v<callable_detail::Callable<FuncT>>;
+  EXPECT_TRUE(DefaultConstruct);
+  EXPECT_TRUE(std::is_default_constructible_v<IterT>);
+
+  if constexpr (std::is_default_constructible_v<IterT>) {
+    IterT I;
+    (void)I;
+  }
+}
+
+TYPED_TEST(MappedIteratorTestBasic, CopyConstruct) {
+  std::vector<int> V({0});
+
+  using FuncT = GetFuncT<TypeParam>;
+  using IterT = mapped_iterator<decltype(V)::iterator, FuncT>;
+
+  EXPECT_TRUE(std::is_copy_constructible_v<IterT>);
+
+  if constexpr (std::is_copy_constructible_v<IterT>) {
+    TypeParam GetCallable;
+
+    IterT I1(V.begin(), GetCallable());
+    IterT I2(I1);
+
+    EXPECT_EQ(I2, I1) << "copy constructed iterator is a different position";
+  }
+}
+
+TYPED_TEST(MappedIteratorTestBasic, MoveConstruct) {
+  std::vector<int> V({0});
+
+  using FuncT = GetFuncT<TypeParam>;
+  using IterT = mapped_iterator<decltype(V)::iterator, FuncT>;
+
+  EXPECT_TRUE(std::is_move_constructible_v<IterT>);
+
+  if constexpr (std::is_move_constructible_v<IterT>) {
+    TypeParam GetCallable;
+
+    IterT I1(V.begin(), GetCallable());
+    IterT I2(V.begin(), GetCallable());
+    IterT I3(std::move(I2));
+
+    EXPECT_EQ(I3, I1) << "move constructed iterator is a different position";
+  }
+}
+
+TYPED_TEST(MappedIteratorTestBasic, CopyAssign) {
   std::vector<int> V({0});
 
-  auto I = map_iterator(V.begin(), [](int X) { return X + 1; });
+  using FuncT = GetFuncT<TypeParam>;
+  using IterT = mapped_iterator<decltype(V)::iterator, FuncT>;
+
+  EXPECT_TRUE(std::is_copy_assignable_v<IterT>);
+
+  if constexpr (std::is_copy_assignable_v<IterT>) {
+    TypeParam GetCallable;
+
+    IterT I1(V.begin(), GetCallable());
+    IterT I2(V.end(), GetCallable());
+
+    I2 = I1;
+
+    EXPECT_EQ(I2, I1) << "copy assigned iterator is a different position";
+  }
+}
+
+TYPED_TEST(MappedIteratorTestBasic, MoveAssign) {
+  std::vector<int> V({0});
+
+  using FuncT = GetFuncT<TypeParam>;
+  using IterT = mapped_iterator<decltype(V)::iterator, FuncT>;
+
+  EXPECT_TRUE(std::is_move_assignable_v<IterT>);
+
+  if constexpr (std::is_move_assignable_v<IterT>) {
+    TypeParam GetCallable;
+
+    IterT I1(V.begin(), GetCallable());
+    IterT I2(V.begin(), GetCallable());
+    IterT I3(V.end(), GetCallable());
+
+    I3 = std::move(I2);
+
+    EXPECT_EQ(I3, I1) << "move assigned iterator is a different position";
+  }
+}
+
+TYPED_TEST(MappedIteratorTestBasic, GetFunction) {
+  std::vector<int> V({0});
+
+  using FuncT = GetFuncT<TypeParam>;
+  using IterT = mapped_iterator<decltype(V)::iterator, FuncT>;
+
+  TypeParam GetCallable;
+  IterT I(V.begin(), GetCallable());
+
+  EXPECT_EQ(I.getFunction()(200), 201);
+}
+
+TYPED_TEST(MappedIteratorTestBasic, GetCurrent) {
+  std::vector<int> V({0});
+
+  using FuncT = GetFuncT<TypeParam>;
+  using IterT = mapped_iterator<decltype(V)::iterator, FuncT>;
+
+  TypeParam GetCallable;
+  IterT I(V.begin(), GetCallable());
+
+  EXPECT_EQ(I.getCurrent(), V.begin());
+  EXPECT_EQ(std::next(I).getCurrent(), V.end());
+}
+
+TYPED_TEST(MappedIteratorTestBasic, ApplyFunctionOnDereference) {
+  std::vector<int> V({0});
+  TypeParam GetCallable;
+
+  auto I = map_iterator(V.begin(), GetCallable());
 
   EXPECT_EQ(*I, 1) << "should have applied function in dereference";
 }
@@ -28,9 +219,9 @@ TEST(MappedIteratorTest, ApplyFunctionOnArrow) {
 
   std::vector<int> V({0});
   S Y;
-  S* P = &Y;
+  S *P = &Y;
 
-  auto I = map_iterator(V.begin(), [&](int X) -> S& { return *(P + X); });
+  auto I = map_iterator(V.begin(), [&](int X) -> S & { return *(P + X); });
 
   I->Z = 42;
 
@@ -39,9 +230,9 @@ TEST(MappedIteratorTest, ApplyFunctionOnArrow) {
 
 TEST(MappedIteratorTest, FunctionPreservesReferences) {
   std::vector<int> V({1});
-  std::map<int, int> M({ {1, 1} });
+  std::map<int, int> M({{1, 1}});
 
-  auto I = map_iterator(V.begin(), [&](int X) -> int& { return M[X]; });
+  auto I = map_iterator(V.begin(), [&](int X) -> int & { return M[X]; });
   *I = 42;
 
   EXPECT_EQ(M[1], 42) << "assignment should have modified M";


        


More information about the llvm-commits mailing list