[libcxx-commits] [libcxx] [libc++][safety] Enhance exception safety for vector assignments in reallocation scenarios (PR #117516)

Peng Liu via libcxx-commits libcxx-commits at lists.llvm.org
Mon Nov 25 08:44:27 PST 2024


https://github.com/winner245 updated https://github.com/llvm/llvm-project/pull/117516

>From 6c67b5a2e84c3b2ee44a380f7004843348abc196 Mon Sep 17 00:00:00 2001
From: Peng Liu <winner245 at hotmail.com>
Date: Sun, 24 Nov 2024 23:08:46 -0500
Subject: [PATCH] Enhance safety for vector assignment operations

---
 libcxx/include/__vector/vector.h              |  18 +-
 .../vector.cons/assign_exceptions.pass.cpp    | 207 ++++++++++++++++++
 libcxx/test/support/allocators.h              |  44 ++++
 libcxx/test/support/exception_test_helpers.h  | 114 ++++++++++
 libcxx/test/support/test_allocator.h          |  52 +++++
 5 files changed, 431 insertions(+), 4 deletions(-)
 create mode 100644 libcxx/test/std/containers/sequences/vector/vector.cons/assign_exceptions.pass.cpp
 create mode 100644 libcxx/test/support/exception_test_helpers.h

diff --git a/libcxx/include/__vector/vector.h b/libcxx/include/__vector/vector.h
index ae3ea1de61de01..ed3697990f3871 100644
--- a/libcxx/include/__vector/vector.h
+++ b/libcxx/include/__vector/vector.h
@@ -1025,9 +1025,14 @@ vector<_Tp, _Allocator>::__assign_with_size(_ForwardIterator __first, _Sentinel
       this->__destruct_at_end(__m);
     }
   } else {
+    __split_buffer<value_type, allocator_type&> __v(__recommend(__new_size), 0, __alloc_);
+    __v.__construct_at_end_with_size(__first, __new_size); // may throw; *this has not been modified yet
     __vdeallocate();
-    __vallocate(__recommend(__new_size));
-    __construct_at_end(__first, __last, __new_size);
+    this->__begin_ = __v.__begin_; // adopt the fully constructed buffer
+    this->__end_   = __v.__end_;
+    this->__cap_   = __v.__cap_;
+    __v.__first_ = __v.__begin_ = __v.__end_ = __v.__cap_ = nullptr; // __v no longer refers to the buffer
+    __annotate_new(__new_size);
   }
 }
 
@@ -1041,9 +1046,14 @@ _LIBCPP_CONSTEXPR_SINCE_CXX20 void vector<_Tp, _Allocator>::assign(size_type __n
     else
       this->__destruct_at_end(this->__begin_ + __n);
   } else {
+    __split_buffer<value_type, allocator_type&> __v(__recommend(__n), 0, __alloc_);
+    __v.__construct_at_end(__n, __u); // may throw; runs before __vdeallocate(), so __u may alias *this
     __vdeallocate();
-    __vallocate(__recommend(static_cast<size_type>(__n)));
-    __construct_at_end(__n, __u);
+    this->__begin_ = __v.__begin_; // adopt the fully constructed buffer
+    this->__end_   = __v.__end_;
+    this->__cap_   = __v.__cap_;
+    __v.__first_ = __v.__begin_ = __v.__end_ = __v.__cap_ = nullptr; // __v no longer refers to the buffer
+    __annotate_new(__n);
   }
 }
 
diff --git a/libcxx/test/std/containers/sequences/vector/vector.cons/assign_exceptions.pass.cpp b/libcxx/test/std/containers/sequences/vector/vector.cons/assign_exceptions.pass.cpp
new file mode 100644
index 00000000000000..1aa9d6be21f200
--- /dev/null
+++ b/libcxx/test/std/containers/sequences/vector/vector.cons/assign_exceptions.pass.cpp
@@ -0,0 +1,213 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// <vector>
+
+// UNSUPPORTED: no-exceptions
+
+// Check that vector assignments leave the assigned-to (lhs) vector unchanged when an exception is thrown while
+// the assignment reallocates its storage.
+
+#include <cassert>
+#include <exception>
+#include <ranges>
+#include <utility>
+#include <vector>
+
+#include "allocators.h"
+#include "exception_test_helpers.h"
+#include "test_allocator.h"
+#include "test_macros.h"
+
+void test_allocation_exception() {
+#if TEST_STD_VER >= 14
+  {
+    limited_alloc_wrapper<int> alloc1 = limited_allocator<int, 100>();
+    limited_alloc_wrapper<int> alloc2 = limited_allocator<int, 200>();
+    std::vector<int, limited_alloc_wrapper<int> > v(100, alloc1);
+    std::vector<int, limited_alloc_wrapper<int> > in(200, alloc2);
+    try { // Throw in copy-assignment operator during allocation
+      v = in;
+    } catch (const std::exception&) {
+    }
+    assert(v.size() == 100);
+  }
+
+  {
+    limited_alloc_wrapper<int> alloc1 = limited_allocator<int, 100>();
+    limited_alloc_wrapper<int> alloc2 = limited_allocator<int, 200>();
+    std::vector<int, limited_alloc_wrapper<int> > v(100, alloc1);
+    std::vector<int, limited_alloc_wrapper<int> > in(200, alloc2);
+    try { // Throw in move-assignment operator during allocation
+      v = std::move(in);
+    } catch (const std::exception&) {
+    }
+    assert(v.size() == 100);
+  }
+#endif
+
+#if TEST_STD_VER >= 11
+  {
+    std::vector<int, limited_allocator<int, 5> > v(5);
+    std::initializer_list<int> in{1, 2, 3, 4, 5, 6};
+    try { // Throw in operator=(initializer_list<value_type>) during allocation
+      v = in;
+    } catch (const std::exception&) {
+    }
+    assert(v.size() == 5);
+  }
+
+  {
+    std::vector<int, limited_allocator<int, 5> > v(5);
+    std::initializer_list<int> in{1, 2, 3, 4, 5, 6};
+    try { // Throw in assign(initializer_list<value_type>) during allocation
+      v.assign(in);
+    } catch (const std::exception&) {
+    }
+    assert(v.size() == 5);
+  }
+#endif
+
+  {
+    std::vector<int, limited_allocator<int, 100> > v(100);
+    std::vector<int> in(101, 1);
+    try { // Throw in assign(_ForwardIterator, _ForwardIterator) during allocation
+      v.assign(in.begin(), in.end());
+    } catch (const std::exception&) {
+    }
+    assert(v.size() == 100);
+  }
+
+  {
+    std::vector<int, limited_allocator<int, 100> > v(100);
+    try { // Throw in assign(size_type, const_reference) during allocation
+      v.assign(101, 1);
+    } catch (const std::exception&) {
+    }
+    assert(v.size() == 100);
+  }
+
+#if TEST_STD_VER >= 23
+  {
+    std::vector<int, limited_allocator<int, 100> > v(100);
+    std::vector<int> in(101, 1);
+    try { // Throw in assign_range(_Range&&) during allocation
+      v.assign_range(in);
+    } catch (const std::exception&) {
+    }
+    assert(v.size() == 100);
+  }
+#endif
+}
+
+void test_construction_exception() {
+  {
+    int throw_after = 10;
+    throwing_t t    = throw_after;
+    std::vector<throwing_t> in(6, t);
+    std::vector<throwing_t> v(3, t);
+    try { // Throw in copy-assignment operator from element type during construction
+      v = in;
+    } catch (int) {
+    }
+    assert(v.size() == 3);
+  }
+
+#if TEST_STD_VER >= 11
+  {
+    int throw_after = 10;
+    throwing_t t    = throw_after;
+    NONPOCMAAllocator<throwing_t> alloc1(1);
+    NONPOCMAAllocator<throwing_t> alloc2(2);
+    std::vector<throwing_t, NONPOCMAAllocator<throwing_t> > in(6, t, alloc1);
+    std::vector<throwing_t, NONPOCMAAllocator<throwing_t> > v(3, t, alloc2);
+    try { // Throw in move-assignment operator from element type during construction
+      v = std::move(in);
+    } catch (int) {
+    }
+    assert(v.size() == 3);
+  }
+
+  {
+    int throw_after = 10;
+    throwing_t t    = throw_after;
+    std::initializer_list<throwing_t> in{t, t, t, t, t, t};
+    std::vector<throwing_t> v(3, t);
+    try { // Throw in operator=(initializer_list<value_type>) from element type during construction
+      v = in;
+    } catch (int) {
+    }
+    assert(v.size() == 3);
+  }
+
+  {
+    int throw_after = 10;
+    throwing_t t    = throw_after;
+    std::initializer_list<throwing_t> in{t, t, t, t, t, t};
+    std::vector<throwing_t> v(3, t);
+    try { // Throw in assign(initializer_list<value_type>) from element type during construction
+      v.assign(in);
+    } catch (int) {
+    }
+    assert(v.size() == 3);
+  }
+#endif
+
+  {
+    std::vector<int> v(3);
+    try { // Throw in assign(_ForwardIterator, _ForwardIterator) from forward iterator during construction
+      v.assign(throwing_iterator<int, std::forward_iterator_tag>(),
+               throwing_iterator<int, std::forward_iterator_tag>(6));
+    } catch (int) {
+    }
+    assert(v.size() == 3);
+  }
+
+  {
+    int throw_after = 10;
+    throwing_t t    = throw_after;
+    std::vector<throwing_t> in(6, t);
+    std::vector<throwing_t> v(3, t);
+    try { // Throw in assign(_ForwardIterator, _ForwardIterator) from element type during construction
+      v.assign(in.begin(), in.end());
+    } catch (int) {
+    }
+    assert(v.size() == 3);
+  }
+
+#if TEST_STD_VER >= 23
+  {
+    int throw_after = 10;
+    throwing_t t    = throw_after;
+    std::vector<throwing_t> in(6, t);
+    std::vector<throwing_t> v(3, t);
+    try { // Throw in assign_range(_Range&&) from element type during construction
+      v.assign_range(in);
+    } catch (int) {
+    }
+    assert(v.size() == 3);
+  }
+#endif
+
+  {
+    int throw_after = 4;
+    throwing_t t    = throw_after;
+    std::vector<throwing_t> v(3, t);
+    try { // Throw in assign(size_type, const_reference) from element type during construction
+      v.assign(6, t);
+    } catch (int) {
+    }
+    assert(v.size() == 3);
+  }
+}
+
+int main(int, char**) {
+  test_allocation_exception();
+  test_construction_exception();
+  return 0;
+}
diff --git a/libcxx/test/support/allocators.h b/libcxx/test/support/allocators.h
index 02436fd9c35ef1..44f0b9a9473625 100644
--- a/libcxx/test/support/allocators.h
+++ b/libcxx/test/support/allocators.h
@@ -251,6 +251,55 @@ using POCCAAllocator = MaybePOCCAAllocator<T, /*POCCAValue = */true>;
 template <class T>
 using NonPOCCAAllocator = MaybePOCCAAllocator<T, /*POCCAValue = */false>;

+// Allocator whose propagate_on_container_move_assignment trait is POCMAValue. Instances carry an integer id and
+// compare equal iff their ids are equal; memory is obtained from std::allocator<T>.
+template <class T, bool POCMAValue>
+class MaybePOCMAAllocator {
+  template <class, bool>
+  friend class MaybePOCMAAllocator;
+
+public:
+  using propagate_on_container_move_assignment = std::integral_constant<bool, POCMAValue>;
+  using value_type                             = T;
+
+  template <class U>
+  struct rebind {
+    using other = MaybePOCMAAllocator<U, POCMAValue>;
+  };
+
+  TEST_CONSTEXPR MaybePOCMAAllocator(int id) : id_(id) {}
+
+  template <class U>
+  TEST_CONSTEXPR MaybePOCMAAllocator(const MaybePOCMAAllocator<U, POCMAValue>& other) : id_(other.id_) {}
+
+  TEST_CONSTEXPR_CXX20 T* allocate(std::size_t n) { return std::allocator<T>().allocate(n); }
+
+  TEST_CONSTEXPR_CXX20 void deallocate(T* p, std::size_t n) { std::allocator<T>().deallocate(p, n); }
+
+  TEST_CONSTEXPR int id() const { return id_; }
+
+  template <class U>
+  TEST_CONSTEXPR friend bool operator==(const MaybePOCMAAllocator& lhs, const MaybePOCMAAllocator<U, POCMAValue>& rhs) {
+    return lhs.id() == rhs.id();
+  }
+
+  template <class U>
+  TEST_CONSTEXPR friend bool operator!=(const MaybePOCMAAllocator& lhs, const MaybePOCMAAllocator<U, POCMAValue>& rhs) {
+    return !(lhs == rhs);
+  }
+
+private:
+  int id_;
+};
+
+template <class T>
+using POCMAAllocator = MaybePOCMAAllocator<T, /*POCMAValue = */true>;
+template <class T>
+using NonPOCMAAllocator = MaybePOCMAAllocator<T, /*POCMAValue = */false>;
+// Same as NonPOCMAAllocator; prefer NonPOCMAAllocator, which matches the spelling of NonPOCCAAllocator.
+template <class T>
+using NONPOCMAAllocator = NonPOCMAAllocator<T>;
+
 #endif // TEST_STD_VER >= 11
 
 #endif // ALLOCATORS_H
diff --git a/libcxx/test/support/exception_test_helpers.h b/libcxx/test/support/exception_test_helpers.h
new file mode 100644
index 00000000000000..a61f690c7fc6f6
--- /dev/null
+++ b/libcxx/test/support/exception_test_helpers.h
@@ -0,0 +1,114 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef EXCEPTION_TEST_HELPERS_H
+#define EXCEPTION_TEST_HELPERS_H
+
+#include "count_new.h"
+
+struct throwing_t {
+  int* throw_after_n_ = nullptr;
+  throwing_t() { throw 0; }
+
+  throwing_t(int& throw_after_n) : throw_after_n_(&throw_after_n) {
+    if (throw_after_n == 0)
+      throw 0;
+    --throw_after_n;
+  }
+
+  throwing_t(const throwing_t& rhs) : throw_after_n_(rhs.throw_after_n_) {
+    if (throw_after_n_ == nullptr || *throw_after_n_ == 0)
+      throw 1;
+    --*throw_after_n_;
+  }
+
+  throwing_t& operator=(const throwing_t& rhs) {
+    throw_after_n_ = rhs.throw_after_n_;
+    if (throw_after_n_ == nullptr || *throw_after_n_ == 0)
+      throw 1;
+    --*throw_after_n_;
+    return *this;
+  }
+
+  friend bool operator==(const throwing_t& lhs, const throwing_t& rhs) {
+    return lhs.throw_after_n_ == rhs.throw_after_n_;
+  }
+  friend bool operator!=(const throwing_t& lhs, const throwing_t& rhs) {
+    return lhs.throw_after_n_ != rhs.throw_after_n_;
+  }
+};
+
+template <class T>
+struct throwing_allocator {
+  using value_type      = T;
+  using is_always_equal = std::false_type;
+
+  bool throw_on_copy_ = false;
+
+  throwing_allocator(bool throw_on_ctor = true, bool throw_on_copy = false) : throw_on_copy_(throw_on_copy) {
+    if (throw_on_ctor)
+      throw 0;
+  }
+
+  template <class U>
+  throwing_allocator(const throwing_allocator<U>& rhs) : throw_on_copy_(rhs.throw_on_copy_) {
+    if (throw_on_copy_)
+      throw 0;
+  }
+
+  T* allocate(std::size_t n) { return std::allocator<T>().allocate(n); }
+  void deallocate(T* ptr, std::size_t n) { std::allocator<T>().deallocate(ptr, n); }
+
+  template <class U>
+  friend bool operator==(const throwing_allocator&, const throwing_allocator<U>&) {
+    return true;
+  }
+};
+
+template <class T, class IterCat>
+struct throwing_iterator {
+  using iterator_category = IterCat;
+  using difference_type   = std::ptrdiff_t;
+  using value_type        = T;
+  using reference         = T&;
+  using pointer           = T*;
+
+  int i_;
+  T v_;
+
+  throwing_iterator(int i = 0, const T& v = T()) : i_(i), v_(v) {}
+
+  reference operator*() {
+    if (i_ == 1)
+      throw 1;
+    return v_;
+  }
+
+  friend bool operator==(const throwing_iterator& lhs, const throwing_iterator& rhs) { return lhs.i_ == rhs.i_; }
+  friend bool operator!=(const throwing_iterator& lhs, const throwing_iterator& rhs) { return lhs.i_ != rhs.i_; }
+
+  throwing_iterator& operator++() {
+    ++i_;
+    return *this;
+  }
+
+  throwing_iterator operator++(int) {
+    auto tmp = *this;
+    ++i_;
+    return tmp;
+  }
+};
+
+inline void check_new_delete_called() {
+  assert(globalMemCounter.new_called == globalMemCounter.delete_called);
+  assert(globalMemCounter.new_array_called == globalMemCounter.delete_array_called);
+  assert(globalMemCounter.aligned_new_called == globalMemCounter.aligned_delete_called);
+  assert(globalMemCounter.aligned_new_array_called == globalMemCounter.aligned_delete_array_called);
+}
+
+#endif // EXCEPTION_TEST_HELPERS_H
\ No newline at end of file
diff --git a/libcxx/test/support/test_allocator.h b/libcxx/test/support/test_allocator.h
index dcd15332ca304f..b3cbce97325e50 100644
--- a/libcxx/test/support/test_allocator.h
+++ b/libcxx/test/support/test_allocator.h
@@ -480,6 +480,64 @@ TEST_CONSTEXPR inline bool operator!=(limited_allocator<T, N> const& LHS, limite
   return !(LHS == RHS);
 }

+#if TEST_STD_VER >= 11
+// Type-erasing wrapper around limited_allocator<T, N>: wrappers built from allocators with different limits N
+// have the same type. Copies share the wrapped allocator, and two wrappers compare equal iff they share it.
+template <typename T>
+class limited_alloc_wrapper {
+public:
+  typedef T value_type;
+  typedef value_type* pointer;
+  typedef const value_type* const_pointer;
+  typedef value_type& reference;
+  typedef const value_type& const_reference;
+  typedef std::size_t size_type;
+  typedef std::ptrdiff_t difference_type;
+
+  // `Alloc` is taken by const reference so it is always deduced as a non-reference type and the wrapper stores
+  // its own copy. For `Alloc == limited_alloc_wrapper` the non-template copy constructor is preferred, so copies
+  // (including ones made from rvalues or non-const lvalues) share `pimpl` instead of wrapping the wrapper.
+  template <typename Alloc>
+  limited_alloc_wrapper(const Alloc& a) : pimpl(std::make_shared<impl_type<Alloc> >(a)) {}
+
+  limited_alloc_wrapper(const limited_alloc_wrapper& other) = default;
+
+  pointer allocate(std::size_t n) { return pimpl->allocate(n); }
+  void deallocate(pointer p, std::size_t n) { pimpl->deallocate(p, n); }
+  size_type max_size() const { return pimpl->max_size(); }
+
+  // Non-template friends: a friend template whose signature does not depend on T would be redefined by every
+  // specialization of limited_alloc_wrapper that gets instantiated.
+  friend bool operator==(const limited_alloc_wrapper& lhs, const limited_alloc_wrapper& rhs) {
+    return lhs.pimpl == rhs.pimpl;
+  }
+
+  friend bool operator!=(const limited_alloc_wrapper& lhs, const limited_alloc_wrapper& rhs) {
+    return !(lhs == rhs);
+  }
+
+private:
+  struct impl_base {
+    virtual pointer allocate(std::size_t n)           = 0;
+    virtual void deallocate(pointer p, std::size_t n) = 0;
+    virtual size_type max_size() const                = 0;
+    virtual ~impl_base() {}
+  };
+
+  template <typename Alloc>
+  struct impl_type : impl_base {
+    Alloc a;
+    impl_type(const Alloc& a_) : a(a_) {}
+
+    pointer allocate(std::size_t n) override { return a.allocate(n); }
+    void deallocate(pointer p, std::size_t n) override { a.deallocate(p, n); }
+    size_type max_size() const override { return a.max_size(); }
+  };
+
+  std::shared_ptr<impl_base> pimpl;
+};
+#endif // TEST_STD_VER >= 11
+
 // Track the "provenance" of this allocator instance: how many times was
 // select_on_container_copy_construction called in order to produce it?
 //



More information about the libcxx-commits mailing list