[libcxx-commits] [libcxx] [libc++][safety] Enhance exception safety for vector assignments in reallocation scenarios (PR #117516)

Peng Liu via libcxx-commits libcxx-commits at lists.llvm.org
Tue Nov 26 05:31:14 PST 2024


https://github.com/winner245 updated https://github.com/llvm/llvm-project/pull/117516

>From ff86ef5bc0da6686bb279ec99530354cfedb52cb Mon Sep 17 00:00:00 2001
From: Peng Liu <winner245 at hotmail.com>
Date: Sun, 24 Nov 2024 23:08:46 -0500
Subject: [PATCH] Enhance safety for vector assignment operations

---
 libcxx/include/__vector/vector.h              |  18 +-
 .../vector.cons/assign_exceptions.pass.cpp    | 209 ++++++++++++++++++
 libcxx/test/support/allocators.h              |  44 ++++
 libcxx/test/support/exception_test_helpers.h  | 114 ++++++++++
 libcxx/test/support/test_allocator.h          | 136 ++++++++----
 5 files changed, 480 insertions(+), 41 deletions(-)
 create mode 100644 libcxx/test/std/containers/sequences/vector/vector.cons/assign_exceptions.pass.cpp
 create mode 100644 libcxx/test/support/exception_test_helpers.h

diff --git a/libcxx/include/__vector/vector.h b/libcxx/include/__vector/vector.h
index ae3ea1de61de01..ed3697990f3871 100644
--- a/libcxx/include/__vector/vector.h
+++ b/libcxx/include/__vector/vector.h
@@ -1025,9 +1025,14 @@ vector<_Tp, _Allocator>::__assign_with_size(_ForwardIterator __first, _Sentinel
       this->__destruct_at_end(__m);
     }
   } else {
+    __split_buffer<value_type, allocator_type&> __v(__recommend(__new_size), 0, __alloc_);
+    __v.__construct_at_end_with_size(__first, __new_size);
     __vdeallocate();
-    __vallocate(__recommend(__new_size));
-    __construct_at_end(__first, __last, __new_size);
+    this->__begin_ = __v.__begin_;
+    this->__end_   = __v.__end_;
+    this->__cap_   = __v.__cap_;
+    __v.__first_ = __v.__begin_ = __v.__end_ = __v.__cap_ = nullptr;
+    __annotate_new(__new_size);
   }
 }
 
@@ -1041,9 +1046,14 @@ _LIBCPP_CONSTEXPR_SINCE_CXX20 void vector<_Tp, _Allocator>::assign(size_type __n
     else
       this->__destruct_at_end(this->__begin_ + __n);
   } else {
+    __split_buffer<value_type, allocator_type&> __v(__recommend(__n), 0, __alloc_);
+    __v.__construct_at_end(__n, __u);
     __vdeallocate();
-    __vallocate(__recommend(static_cast<size_type>(__n)));
-    __construct_at_end(__n, __u);
+    this->__begin_ = __v.__begin_;
+    this->__end_   = __v.__end_;
+    this->__cap_   = __v.__cap_;
+    __v.__first_ = __v.__begin_ = __v.__end_ = __v.__cap_ = nullptr;
+    __annotate_new(__n);
   }
 }
 
diff --git a/libcxx/test/std/containers/sequences/vector/vector.cons/assign_exceptions.pass.cpp b/libcxx/test/std/containers/sequences/vector/vector.cons/assign_exceptions.pass.cpp
new file mode 100644
index 00000000000000..275c2df5852ede
--- /dev/null
+++ b/libcxx/test/std/containers/sequences/vector/vector.cons/assign_exceptions.pass.cpp
@@ -0,0 +1,209 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// UNSUPPORTED: no-exceptions
+
+// <vector>
+
+// Check that vector assignments do not alter the LHS (destination) vector when an exception is thrown during reallocations triggered by assignments.
+
+#include <cassert>
+#include <ranges>
+#include <vector>
+
+#include "allocators.h"
+#include "exception_test_helpers.h"
+#include "test_allocator.h"
+#include "test_macros.h"
+
+void test_allocation_exception() {
+#if TEST_STD_VER >= 14
+  {
+    limited_alloc_wrapper<int> alloc1 = limited_allocator<int, 100>();
+    limited_alloc_wrapper<int> alloc2 = limited_allocator<int, 200>();
+    std::vector<int, limited_alloc_wrapper<int> > v(100, alloc1);
+    std::vector<int, limited_alloc_wrapper<int> > in(200, alloc2);
+    try { // Throw in copy-assignment operator during allocation
+      v = in;
+    } catch (const std::exception&) {
+    }
+    assert(v.size() == 100);
+  }
+
+  {
+    limited_alloc_wrapper<int> alloc1 = limited_allocator<int, 100>();
+    limited_alloc_wrapper<int> alloc2 = limited_allocator<int, 200>();
+    std::vector<int, limited_alloc_wrapper<int> > v(100, alloc1);
+    std::vector<int, limited_alloc_wrapper<int> > in(200, alloc2);
+    try { // Throw in move-assignment operator during allocation
+      v = std::move(in);
+    } catch (const std::exception&) {
+    }
+    assert(v.size() == 100);
+  }
+#endif
+
+#if TEST_STD_VER >= 11
+  {
+    std::vector<int, limited_allocator<int, 5> > v(5);
+    std::initializer_list<int> in{1, 2, 3, 4, 5, 6};
+    try { // Throw in operator=(initializer_list<value_type>) during allocation
+      v = in;
+    } catch (const std::exception&) {
+    }
+    assert(v.size() == 5);
+  }
+
+  {
+    std::vector<int, limited_allocator<int, 5> > v(5);
+    std::initializer_list<int> in{1, 2, 3, 4, 5, 6};
+    try { // Throw in assign(initializer_list<value_type>) during allocation
+      v.assign(in);
+    } catch (const std::exception&) {
+    }
+    assert(v.size() == 5);
+  }
+#endif
+
+  {
+    std::vector<int, limited_allocator<int, 100> > v(100);
+    std::vector<int> in(101, 1);
+    try { // Throw in assign(_ForwardIterator, _ForwardIterator) during allocation
+      v.assign(in.begin(), in.end());
+    } catch (const std::exception&) {
+    }
+    assert(v.size() == 100);
+  }
+
+  {
+    std::vector<int, limited_allocator<int, 100> > v(100);
+    try { // Throw in assign(size_type, const_reference) during allocation
+      v.assign(101, 1);
+    } catch (const std::exception&) {
+    }
+    assert(v.size() == 100);
+  }
+
+#if TEST_STD_VER >= 23
+  {
+    std::vector<int, limited_allocator<int, 100> > v(100);
+    std::vector<int> in(101, 1);
+    try { // Throw in assign_range(_Range&&) during allocation
+      v.assign_range(in);
+    } catch (const std::exception&) {
+    }
+    assert(v.size() == 100);
+  }
+#endif
+}
+
+void test_construction_exception() {
+  {
+    int throw_after = 10;
+    throwing_t t    = throw_after;
+    std::vector<throwing_t> in(6, t);
+    std::vector<throwing_t> v(3, t);
+    try { // Throw in copy-assignment operator from element type during construction
+      v = in;
+    } catch (int) {
+    }
+    assert(v.size() == 3);
+  }
+
+#if TEST_STD_VER >= 11
+  {
+    int throw_after = 10;
+    throwing_t t    = throw_after;
+    NONPOCMAAllocator<throwing_t> alloc1(1);
+    NONPOCMAAllocator<throwing_t> alloc2(2);
+    std::vector<throwing_t, NONPOCMAAllocator<throwing_t> > in(6, t, alloc1);
+    std::vector<throwing_t, NONPOCMAAllocator<throwing_t> > v(3, t, alloc2);
+    try { // Throw in move-assignment operator from element type during construction
+      v = std::move(in);
+    } catch (int) {
+    }
+    assert(v.size() == 3);
+  }
+
+  {
+    int throw_after = 10;
+    throwing_t t    = throw_after;
+    std::initializer_list<throwing_t> in{t, t, t, t, t, t};
+    std::vector<throwing_t> v(3, t);
+    try { // Throw in operator=(initializer_list<value_type>) from element type during construction
+      v = in;
+    } catch (int) {
+    }
+    assert(v.size() == 3);
+  }
+
+  {
+    int throw_after = 10;
+    throwing_t t    = throw_after;
+    std::initializer_list<throwing_t> in{t, t, t, t, t, t};
+    std::vector<throwing_t> v(3, t);
+    try { // Throw in assign(initializer_list<value_type>) from element type during construction
+      v.assign(in);
+    } catch (int) {
+    }
+    assert(v.size() == 3);
+  }
+#endif
+
+  {
+    std::vector<int> v(3);
+    try { // Throw in assign(_ForwardIterator, _ForwardIterator) from forward iterator during construction
+      v.assign(throwing_iterator<int, std::forward_iterator_tag>(),
+               throwing_iterator<int, std::forward_iterator_tag>(6));
+    } catch (int) {
+    }
+    assert(v.size() == 3);
+  }
+
+  {
+    int throw_after = 10;
+    throwing_t t    = throw_after;
+    std::vector<throwing_t> in(6, t);
+    std::vector<throwing_t> v(3, t);
+    try { // Throw in assign(_ForwardIterator, _ForwardIterator) from element type during construction
+      v.assign(in.begin(), in.end());
+    } catch (int) {
+    }
+    assert(v.size() == 3);
+  }
+
+#if TEST_STD_VER >= 23
+  {
+    int throw_after = 10;
+    throwing_t t    = throw_after;
+    std::vector<throwing_t> in(6, t);
+    std::vector<throwing_t> v(3, t);
+    try { // Throw in assign_range(_Range&&) from element type during construction
+      v.assign_range(in);
+    } catch (int) {
+    }
+    assert(v.size() == 3);
+  }
+#endif
+
+  {
+    int throw_after = 4;
+    throwing_t t    = throw_after;
+    std::vector<throwing_t> v(3, t);
+    try { // Throw in assign(size_type, const_reference) from element type during construction
+      v.assign(6, t);
+    } catch (int) {
+    }
+    assert(v.size() == 3);
+  }
+}
+
+int main() {
+  test_allocation_exception();
+  test_construction_exception();
+}
diff --git a/libcxx/test/support/allocators.h b/libcxx/test/support/allocators.h
index 02436fd9c35ef1..44f0b9a9473625 100644
--- a/libcxx/test/support/allocators.h
+++ b/libcxx/test/support/allocators.h
@@ -251,6 +251,50 @@ using POCCAAllocator = MaybePOCCAAllocator<T, /*POCCAValue = */true>;
 template <class T>
 using NonPOCCAAllocator = MaybePOCCAAllocator<T, /*POCCAValue = */false>;
 
+template <class T, bool POCMAValue>
+class MaybePOCMAAllocator {
+  template <class, bool>
+  friend class MaybePOCMAAllocator;
+
+public:
+  using propagate_on_container_move_assignment = std::integral_constant<bool, POCMAValue>;
+  using value_type                             = T;
+
+  template <class U>
+  struct rebind {
+    using other = MaybePOCMAAllocator<U, POCMAValue>;
+  };
+
+  TEST_CONSTEXPR MaybePOCMAAllocator(int id) : id_(id) {}
+
+  template <class U>
+  TEST_CONSTEXPR MaybePOCMAAllocator(const MaybePOCMAAllocator<U, POCMAValue>& other) : id_(other.id_) {}
+
+  TEST_CONSTEXPR_CXX20 T* allocate(std::size_t n) { return std::allocator<T>().allocate(n); }
+
+  TEST_CONSTEXPR_CXX20 void deallocate(T* p, std::size_t n) { std::allocator<T>().deallocate(p, n); }
+
+  TEST_CONSTEXPR int id() const { return id_; }
+
+  template <class U>
+  TEST_CONSTEXPR friend bool operator==(const MaybePOCMAAllocator& lhs, const MaybePOCMAAllocator<U, POCMAValue>& rhs) {
+    return lhs.id() == rhs.id();
+  }
+
+  template <class U>
+  TEST_CONSTEXPR friend bool operator!=(const MaybePOCMAAllocator& lhs, const MaybePOCMAAllocator<U, POCMAValue>& rhs) {
+    return !(lhs == rhs);
+  }
+
+private:
+  int id_;
+};
+
+template <class T>
+using POCMAAllocator = MaybePOCMAAllocator<T, /*POCMAValue = */ true>;
+template <class T>
+using NONPOCMAAllocator = MaybePOCMAAllocator<T, /*POCMAValue = */ false>;
+
 #endif // TEST_STD_VER >= 11
 
 #endif // ALLOCATORS_H
diff --git a/libcxx/test/support/exception_test_helpers.h b/libcxx/test/support/exception_test_helpers.h
new file mode 100644
index 00000000000000..a61f690c7fc6f6
--- /dev/null
+++ b/libcxx/test/support/exception_test_helpers.h
@@ -0,0 +1,114 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef EXCEPTION_TEST_HELPERS_H
+#define EXCEPTION_TEST_HELPERS_H
+
+#include "count_new.h"
+
+struct throwing_t {
+  int* throw_after_n_ = nullptr;
+  throwing_t() { throw 0; }
+
+  throwing_t(int& throw_after_n) : throw_after_n_(&throw_after_n) {
+    if (throw_after_n == 0)
+      throw 0;
+    --throw_after_n;
+  }
+
+  throwing_t(const throwing_t& rhs) : throw_after_n_(rhs.throw_after_n_) {
+    if (throw_after_n_ == nullptr || *throw_after_n_ == 0)
+      throw 1;
+    --*throw_after_n_;
+  }
+
+  throwing_t& operator=(const throwing_t& rhs) {
+    throw_after_n_ = rhs.throw_after_n_;
+    if (throw_after_n_ == nullptr || *throw_after_n_ == 0)
+      throw 1;
+    --*throw_after_n_;
+    return *this;
+  }
+
+  friend bool operator==(const throwing_t& lhs, const throwing_t& rhs) {
+    return lhs.throw_after_n_ == rhs.throw_after_n_;
+  }
+  friend bool operator!=(const throwing_t& lhs, const throwing_t& rhs) {
+    return lhs.throw_after_n_ != rhs.throw_after_n_;
+  }
+};
+
+template <class T>
+struct throwing_allocator {
+  using value_type      = T;
+  using is_always_equal = std::false_type;
+
+  bool throw_on_copy_ = false;
+
+  throwing_allocator(bool throw_on_ctor = true, bool throw_on_copy = false) : throw_on_copy_(throw_on_copy) {
+    if (throw_on_ctor)
+      throw 0;
+  }
+
+  template <class U>
+  throwing_allocator(const throwing_allocator<U>& rhs) : throw_on_copy_(rhs.throw_on_copy_) {
+    if (throw_on_copy_)
+      throw 0;
+  }
+
+  T* allocate(std::size_t n) { return std::allocator<T>().allocate(n); }
+  void deallocate(T* ptr, std::size_t n) { std::allocator<T>().deallocate(ptr, n); }
+
+  template <class U>
+  friend bool operator==(const throwing_allocator&, const throwing_allocator<U>&) {
+    return true;
+  }
+};
+
+template <class T, class IterCat>
+struct throwing_iterator {
+  using iterator_category = IterCat;
+  using difference_type   = std::ptrdiff_t;
+  using value_type        = T;
+  using reference         = T&;
+  using pointer           = T*;
+
+  int i_;
+  T v_;
+
+  throwing_iterator(int i = 0, const T& v = T()) : i_(i), v_(v) {}
+
+  reference operator*() {
+    if (i_ == 1)
+      throw 1;
+    return v_;
+  }
+
+  friend bool operator==(const throwing_iterator& lhs, const throwing_iterator& rhs) { return lhs.i_ == rhs.i_; }
+  friend bool operator!=(const throwing_iterator& lhs, const throwing_iterator& rhs) { return lhs.i_ != rhs.i_; }
+
+  throwing_iterator& operator++() {
+    ++i_;
+    return *this;
+  }
+
+  throwing_iterator operator++(int) {
+    auto tmp = *this;
+    ++i_;
+    return tmp;
+  }
+};
+
+inline void check_new_delete_called() {
+  assert(globalMemCounter.new_called == globalMemCounter.delete_called);
+  assert(globalMemCounter.new_array_called == globalMemCounter.delete_array_called);
+  assert(globalMemCounter.aligned_new_called == globalMemCounter.aligned_delete_called);
+  assert(globalMemCounter.aligned_new_array_called == globalMemCounter.aligned_delete_array_called);
+}
+
+#endif // EXCEPTION_TEST_HELPERS_H
\ No newline at end of file
diff --git a/libcxx/test/support/test_allocator.h b/libcxx/test/support/test_allocator.h
index dcd15332ca304f..07e7bc26804c8b 100644
--- a/libcxx/test/support/test_allocator.h
+++ b/libcxx/test/support/test_allocator.h
@@ -27,45 +27,45 @@ TEST_CONSTEXPR_CXX20 inline typename std::allocator_traits<Alloc>::size_type all
 }
 
 struct test_allocator_statistics {
-  int time_to_throw = 0;
-  int throw_after = INT_MAX;
+  int time_to_throw   = 0;
+  int throw_after     = INT_MAX;
   int count           = 0; // the number of active instances
   int alloc_count     = 0; // the number of allocations not deallocating
   int allocated_size  = 0; // the size of allocated elements
   int construct_count = 0; // the number of times that ::construct was called
-  int destroy_count = 0; // the number of times that ::destroy was called
-  int copied = 0;
-  int moved = 0;
-  int converted = 0;
+  int destroy_count   = 0; // the number of times that ::destroy was called
+  int copied          = 0;
+  int moved           = 0;
+  int converted       = 0;
 
   TEST_CONSTEXPR_CXX14 void clear() {
     assert(count == 0 && "clearing leaking allocator data?");
-    count = 0;
-    time_to_throw = 0;
-    alloc_count = 0;
+    count           = 0;
+    time_to_throw   = 0;
+    alloc_count     = 0;
     allocated_size  = 0;
     construct_count = 0;
-    destroy_count = 0;
-    throw_after = INT_MAX;
+    destroy_count   = 0;
+    throw_after     = INT_MAX;
     clear_ctor_counters();
   }
 
   TEST_CONSTEXPR_CXX14 void clear_ctor_counters() {
-    copied = 0;
-    moved = 0;
+    copied    = 0;
+    moved     = 0;
     converted = 0;
   }
 };
 
 struct test_alloc_base {
   TEST_CONSTEXPR static const int destructed_value = -1;
-  TEST_CONSTEXPR static const int moved_value = INT_MAX;
+  TEST_CONSTEXPR static const int moved_value      = INT_MAX;
 };
 
 template <class T>
 class test_allocator {
-  int data_ = 0; // participates in equality
-  int id_ = 0;   // unique identifier, doesn't participate in equality
+  int data_                         = 0; // participates in equality
+  int id_                           = 0; // unique identifier, doesn't participate in equality
   test_allocator_statistics* stats_ = nullptr;
 
   template <class U>
@@ -95,7 +95,8 @@ class test_allocator {
   TEST_CONSTEXPR explicit test_allocator(int data) TEST_NOEXCEPT : data_(data) {}
 
   TEST_CONSTEXPR_CXX14 explicit test_allocator(int data, test_allocator_statistics* stats) TEST_NOEXCEPT
-      : data_(data), stats_(stats) {
+      : data_(data),
+        stats_(stats) {
     if (stats != nullptr)
       ++stats_->count;
   }
@@ -103,13 +104,17 @@ class test_allocator {
   TEST_CONSTEXPR explicit test_allocator(int data, int id) TEST_NOEXCEPT : data_(data), id_(id) {}
 
   TEST_CONSTEXPR_CXX14 explicit test_allocator(int data, int id, test_allocator_statistics* stats) TEST_NOEXCEPT
-      : data_(data), id_(id), stats_(stats) {
+      : data_(data),
+        id_(id),
+        stats_(stats) {
     if (stats_ != nullptr)
       ++stats_->count;
   }
 
   TEST_CONSTEXPR_CXX14 test_allocator(const test_allocator& a) TEST_NOEXCEPT
-    : data_(a.data_), id_(a.id_), stats_(a.stats_) {
+      : data_(a.data_),
+        id_(a.id_),
+        stats_(a.stats_) {
     assert(a.data_ != test_alloc_base::destructed_value && a.id_ != test_alloc_base::destructed_value &&
            "copying from destroyed allocator");
     if (stats_ != nullptr) {
@@ -130,7 +135,9 @@ class test_allocator {
 
   template <class U>
   TEST_CONSTEXPR_CXX14 test_allocator(const test_allocator<U>& a) TEST_NOEXCEPT
-      : data_(a.data_), id_(a.id_), stats_(a.stats_) {
+      : data_(a.data_),
+        id_(a.id_),
+        stats_(a.stats_) {
     if (stats_ != nullptr) {
       ++stats_->count;
       ++stats_->converted;
@@ -143,7 +150,7 @@ class test_allocator {
     if (stats_ != nullptr)
       --stats_->count;
     data_ = test_alloc_base::destructed_value;
-    id_ = test_alloc_base::destructed_value;
+    id_   = test_alloc_base::destructed_value;
   }
 
   TEST_CONSTEXPR pointer address(reference x) const { return &x; }
@@ -197,8 +204,8 @@ class test_allocator {
 
 template <>
 class test_allocator<void> {
-  int data_ = 0;
-  int id_ = 0;
+  int data_                         = 0;
+  int id_                           = 0;
   test_allocator_statistics* stats_ = nullptr;
 
   template <class U>
@@ -223,27 +230,30 @@ class test_allocator<void> {
   TEST_CONSTEXPR explicit test_allocator(int data) TEST_NOEXCEPT : data_(data) {}
 
   TEST_CONSTEXPR explicit test_allocator(int data, test_allocator_statistics* stats) TEST_NOEXCEPT
-      : data_(data), stats_(stats)
-  {}
+      : data_(data),
+        stats_(stats) {}
 
   TEST_CONSTEXPR explicit test_allocator(int data, int id) : data_(data), id_(id) {}
 
   TEST_CONSTEXPR_CXX14 explicit test_allocator(int data, int id, test_allocator_statistics* stats) TEST_NOEXCEPT
-      : data_(data), id_(id), stats_(stats)
-  {}
+      : data_(data),
+        id_(id),
+        stats_(stats) {}
 
   TEST_CONSTEXPR_CXX14 explicit test_allocator(const test_allocator& a) TEST_NOEXCEPT
-      : data_(a.data_), id_(a.id_), stats_(a.stats_)
-  {}
+      : data_(a.data_),
+        id_(a.id_),
+        stats_(a.stats_) {}
 
   template <class U>
   TEST_CONSTEXPR_CXX14 test_allocator(const test_allocator<U>& a) TEST_NOEXCEPT
-      : data_(a.data_), id_(a.id_), stats_(a.stats_)
-  {}
+      : data_(a.data_),
+        id_(a.id_),
+        stats_(a.stats_) {}
 
   TEST_CONSTEXPR_CXX20 ~test_allocator() TEST_NOEXCEPT {
     data_ = test_alloc_base::destructed_value;
-    id_ = test_alloc_base::destructed_value;
+    id_   = test_alloc_base::destructed_value;
   }
 
   TEST_CONSTEXPR int get_id() const { return id_; }
@@ -310,9 +320,9 @@ struct Tag_X {
   TEST_CONSTEXPR Tag_X(Ctor_Tag, Args&&...) {}
 
   // not DefaultConstructible, CopyConstructible or MoveConstructible.
-  Tag_X() = delete;
+  Tag_X()             = delete;
   Tag_X(const Tag_X&) = delete;
-  Tag_X(Tag_X&&) = delete;
+  Tag_X(Tag_X&&)      = delete;
 
   // CopyAssignable.
   TEST_CONSTEXPR_CXX14 Tag_X& operator=(const Tag_X&) { return *this; };
@@ -329,7 +339,7 @@ struct Tag_X {
 template <typename T>
 class TaggingAllocator {
 public:
-  using value_type = T;
+  using value_type   = T;
   TaggingAllocator() = default;
 
   template <typename U>
@@ -356,13 +366,13 @@ class TaggingAllocator {
 template <std::size_t MaxAllocs>
 struct limited_alloc_handle {
   std::size_t outstanding_ = 0;
-  void* last_alloc_ = nullptr;
+  void* last_alloc_        = nullptr;
 
   template <class T>
   TEST_CONSTEXPR_CXX20 T* allocate(std::size_t N) {
     if (N + outstanding_ > MaxAllocs)
       TEST_THROW(std::bad_alloc());
-    auto alloc = std::allocator<T>().allocate(N);
+    auto alloc  = std::allocator<T>().allocate(N);
     last_alloc_ = alloc;
     outstanding_ += N;
     return alloc;
@@ -480,6 +490,58 @@ TEST_CONSTEXPR inline bool operator!=(limited_allocator<T, N> const& LHS, limite
   return !(LHS == RHS);
 }
 
+// type erasure wrapper for limited_allocator<T, N>
+template <typename T>
+class limited_alloc_wrapper {
+public:
+  typedef T value_type;
+  typedef value_type* pointer;
+  typedef const value_type* const_pointer;
+  typedef value_type& reference;
+  typedef const value_type& const_reference;
+  typedef std::size_t size_type;
+  typedef std::ptrdiff_t difference_type;
+
+  template <typename Alloc>
+  limited_alloc_wrapper(Alloc&& a) : pimpl(std::make_shared<impl_type<Alloc> >(std::forward<Alloc>(a))) {}
+
+  limited_alloc_wrapper(const limited_alloc_wrapper& other) = default;
+
+  pointer allocate(std::size_t n) { return pimpl->allocate(n); }
+  void deallocate(pointer p, size_t n) { pimpl->deallocate(p, n); }
+  size_type max_size() const { return pimpl->max_size(); }
+
+private:
+  struct impl_base {
+    virtual pointer allocate(std::size_t n)      = 0;
+    virtual void deallocate(pointer p, size_t n) = 0;
+    virtual size_type max_size() const           = 0;
+    virtual ~impl_base() {}
+  };
+
+  std::shared_ptr<impl_base> pimpl;
+
+  template <typename Alloc>
+  struct impl_type : impl_base {
+    Alloc a;
+    impl_type(const Alloc& a_) : a(a_) {}
+
+    pointer allocate(std::size_t n) override { return a.allocate(n); }
+    void deallocate(pointer p, size_t n) override { a.deallocate(p, n); }
+    size_type max_size() const override { return a.max_size(); }
+  };
+
+  template <class S, class U>
+  friend bool operator==(const limited_alloc_wrapper<S>& lhs, const limited_alloc_wrapper<U>& rhs) {
+    return lhs.pimpl == rhs.pimpl;
+  }
+
+  template <class S, class U>
+  friend bool operator!=(const limited_alloc_wrapper<S>& lhs, const limited_alloc_wrapper<U>& rhs) {
+    return !(lhs == rhs);
+  }
+};
+
 // Track the "provenance" of this allocator instance: how many times was
 // select_on_container_copy_construction called in order to produce it?
 //



More information about the libcxx-commits mailing list