[llvm-branch-commits] [llvm] 9abac60 - ADT: Fix reference invalidation in SmallVector::push_back and single-element insert

Duncan P. N. Exon Smith via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Jan 13 19:03:40 PST 2021


Author: Duncan P. N. Exon Smith
Date: 2021-01-13T18:58:24-08:00
New Revision: 9abac60309006db00eca0af406c2e16bef26807c

URL: https://github.com/llvm/llvm-project/commit/9abac60309006db00eca0af406c2e16bef26807c
DIFF: https://github.com/llvm/llvm-project/commit/9abac60309006db00eca0af406c2e16bef26807c.diff

LOG: ADT: Fix reference invalidation in SmallVector::push_back and single-element insert

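For small enough, trivially copyable `T`, take the argument by value in
`SmallVector::push_back` and copy it when forwarding to
`SmallVector::insert_one_impl`. Otherwise, when the vector has to grow and
the argument refers to an element in the vector's own storage, recompute
the argument's address after growing so it points into the new buffer.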

Differential Revision: https://reviews.llvm.org/D93779

Added: 
    

Modified: 
    llvm/include/llvm/ADT/SmallVector.h
    llvm/unittests/ADT/SmallVectorTest.cpp

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/ADT/SmallVector.h b/llvm/include/llvm/ADT/SmallVector.h
index 803588143d81..f5293970aa9f 100644
--- a/llvm/include/llvm/ADT/SmallVector.h
+++ b/llvm/include/llvm/ADT/SmallVector.h
@@ -220,6 +220,23 @@ class SmallVectorTemplateCommon
   }
   void assertSafeToEmplace() {}
 
+  /// Reserve enough space to add one element, and return the updated element
+  /// pointer in case it was a reference to the storage.
+  template <class U>
+  static const T *reserveForAndGetAddressImpl(U *This, const T &Elt) {
+    if (LLVM_LIKELY(This->size() < This->capacity()))
+      return &Elt;
+
+    bool ReferencesStorage = false;
+    int64_t Index = -1;
+    if (LLVM_UNLIKELY(This->isReferenceToStorage(&Elt))) {
+      ReferencesStorage = true;
+      Index = &Elt - This->begin();
+    }
+    This->grow();
+    return ReferencesStorage ? This->begin() + Index : &Elt;
+  }
+
 public:
   using size_type = size_t;
   using difference_type = ptrdiff_t;
@@ -303,7 +320,12 @@ template <typename T, bool = (is_trivially_copy_constructible<T>::value) &&
                              (is_trivially_move_constructible<T>::value) &&
                              std::is_trivially_destructible<T>::value>
 class SmallVectorTemplateBase : public SmallVectorTemplateCommon<T> {
+  friend class SmallVectorTemplateCommon<T>;
+
 protected:
+  static constexpr bool TakesParamByValue = false;
+  using ValueParamT = const T &;
+
   SmallVectorTemplateBase(size_t Size) : SmallVectorTemplateCommon<T>(Size) {}
 
   static void destroy_range(T *S, T *E) {
@@ -333,20 +355,28 @@ class SmallVectorTemplateBase : public SmallVectorTemplateCommon<T> {
   /// element, or MinSize more elements if specified.
   void grow(size_t MinSize = 0);
 
+  /// Reserve enough space to add one element, and return the updated element
+  /// pointer in case it was a reference to the storage.
+  const T *reserveForAndGetAddress(const T &Elt) {
+    return this->reserveForAndGetAddressImpl(this, Elt);
+  }
+
+  /// Reserve enough space to add one element, and return the updated element
+  /// pointer in case it was a reference to the storage.
+  T *reserveForAndGetAddress(T &Elt) {
+    return const_cast<T *>(this->reserveForAndGetAddressImpl(this, Elt));
+  }
+
 public:
   void push_back(const T &Elt) {
-    this->assertSafeToAdd(&Elt);
-    if (LLVM_UNLIKELY(this->size() >= this->capacity()))
-      this->grow();
-    ::new ((void*) this->end()) T(Elt);
+    const T *EltPtr = reserveForAndGetAddress(Elt);
+    ::new ((void *)this->end()) T(*EltPtr);
     this->set_size(this->size() + 1);
   }
 
   void push_back(T &&Elt) {
-    this->assertSafeToAdd(&Elt);
-    if (LLVM_UNLIKELY(this->size() >= this->capacity()))
-      this->grow();
-    ::new ((void*) this->end()) T(::std::move(Elt));
+    T *EltPtr = reserveForAndGetAddress(Elt);
+    ::new ((void *)this->end()) T(::std::move(*EltPtr));
     this->set_size(this->size() + 1);
   }
 
@@ -396,7 +426,18 @@ void SmallVectorTemplateBase<T, TriviallyCopyable>::grow(size_t MinSize) {
 /// skipping destruction.
 template <typename T>
 class SmallVectorTemplateBase<T, true> : public SmallVectorTemplateCommon<T> {
+  friend class SmallVectorTemplateCommon<T>;
+
 protected:
+  /// True if it's cheap enough to take parameters by value. Doing so avoids
+  /// overhead related to mitigations for reference invalidation.
+  static constexpr bool TakesParamByValue = sizeof(T) <= 2 * sizeof(void *);
+
+  /// Either const T& or T, depending on whether it's cheap enough to take
+  /// parameters by value.
+  using ValueParamT =
+      typename std::conditional<TakesParamByValue, T, const T &>::type;
+
   SmallVectorTemplateBase(size_t Size) : SmallVectorTemplateCommon<T>(Size) {}
 
   // No need to do a destroy loop for POD's.
@@ -437,12 +478,22 @@ class SmallVectorTemplateBase<T, true> : public SmallVectorTemplateCommon<T> {
   /// least one more element or MinSize if specified.
   void grow(size_t MinSize = 0) { this->grow_pod(MinSize, sizeof(T)); }
 
+  /// Reserve enough space to add one element, and return the updated element
+  /// pointer in case it was a reference to the storage.
+  const T *reserveForAndGetAddress(const T &Elt) {
+    return this->reserveForAndGetAddressImpl(this, Elt);
+  }
+
+  /// Reserve enough space to add one element, and return the updated element
+  /// pointer in case it was a reference to the storage.
+  T *reserveForAndGetAddress(T &Elt) {
+    return const_cast<T *>(this->reserveForAndGetAddressImpl(this, Elt));
+  }
+
 public:
-  void push_back(const T &Elt) {
-    this->assertSafeToAdd(&Elt);
-    if (LLVM_UNLIKELY(this->size() >= this->capacity()))
-      this->grow();
-    memcpy(reinterpret_cast<void *>(this->end()), &Elt, sizeof(T));
+  void push_back(ValueParamT Elt) {
+    const T *EltPtr = reserveForAndGetAddress(Elt);
+    memcpy(reinterpret_cast<void *>(this->end()), EltPtr, sizeof(T));
     this->set_size(this->size() + 1);
   }
 
@@ -462,6 +513,9 @@ class SmallVectorImpl : public SmallVectorTemplateBase<T> {
   using size_type = typename SuperClass::size_type;
 
 protected:
+  using SmallVectorTemplateBase<T>::TakesParamByValue;
+  using ValueParamT = typename SuperClass::ValueParamT;
+
   // Default ctor - Initialize to empty.
   explicit SmallVectorImpl(unsigned N)
       : SmallVectorTemplateBase<T>(N) {}
@@ -628,6 +682,12 @@ class SmallVectorImpl : public SmallVectorTemplateBase<T> {
 
 private:
   template <class ArgType> iterator insert_one_impl(iterator I, ArgType &&Elt) {
+    // Callers ensure that ArgType is derived from T.
+    static_assert(
+        std::is_same<std::remove_const_t<std::remove_reference_t<ArgType>>,
+                     T>::value,
+        "ArgType must be derived from T!");
+
     if (I == this->end()) {  // Important special case for empty vector.
       this->push_back(::std::forward<ArgType>(Elt));
       return this->end()-1;
@@ -635,14 +695,11 @@ class SmallVectorImpl : public SmallVectorTemplateBase<T> {
 
     assert(this->isReferenceToStorage(I) && "Insertion iterator is out of bounds.");
 
-    // Check that adding an element won't invalidate Elt.
-    this->assertSafeToAdd(&Elt);
-
-    if (this->size() >= this->capacity()) {
-      size_t EltNo = I-this->begin();
-      this->grow();
-      I = this->begin()+EltNo;
-    }
+    // Grow if necessary.
+    size_t Index = I - this->begin();
+    std::remove_reference_t<ArgType> *EltPtr =
+        this->reserveForAndGetAddress(Elt);
+    I = this->begin() + Index;
 
     ::new ((void*) this->end()) T(::std::move(this->back()));
     // Push everything else over.
@@ -650,21 +707,48 @@ class SmallVectorImpl : public SmallVectorTemplateBase<T> {
     this->set_size(this->size() + 1);
 
     // If we just moved the element we're inserting, be sure to update
-    // the reference.
-    std::remove_reference_t<ArgType> *EltPtr = &Elt;
-    if (this->isReferenceToRange(EltPtr, I, this->end()))
+    // the reference (never happens if TakesParamByValue).
+    static_assert(!TakesParamByValue || std::is_same<ArgType, T>::value,
+                  "ArgType must be 'T' when taking by value!");
+    if (!TakesParamByValue && this->isReferenceToRange(EltPtr, I, this->end()))
       ++EltPtr;
 
     *I = ::std::forward<ArgType>(*EltPtr);
     return I;
   }
 
+  template <
+      class ArgType,
+      std::enable_if_t<
+          std::is_same<std::remove_const_t<std::remove_reference_t<ArgType>>,
+                       T>::value &&
+              !TakesParamByValue,
+          bool> = false>
+  iterator insert_one_maybe_copy(iterator I, ArgType &&Elt) {
+    return insert_one_impl(I, std::forward<ArgType>(Elt));
+  }
+
+  template <
+      class ArgType,
+      std::enable_if_t<
+          std::is_same<std::remove_const_t<std::remove_reference_t<ArgType>>,
+                       T>::value &&
+              TakesParamByValue,
+          bool> = false>
+  iterator insert_one_maybe_copy(iterator I, ArgType &&Elt) {
+    // Copy Elt in order to mitigate reference invalidation without needing to
+    // update the pointer values in insert_one_impl.
+    return insert_one_impl(I, T(Elt));
+  }
+
 public:
   iterator insert(iterator I, T &&Elt) {
-    return insert_one_impl(I, std::move(Elt));
+    return insert_one_maybe_copy(I, std::move(Elt));
   }
 
-  iterator insert(iterator I, const T &Elt) { return insert_one_impl(I, Elt); }
+  iterator insert(iterator I, const T &Elt) {
+    return insert_one_maybe_copy(I, Elt);
+  }
 
   iterator insert(iterator I, size_type NumToInsert, const T &Elt) {
     // Convert iterator to elt# to avoid invalidating iterator when we reserve()

diff --git a/llvm/unittests/ADT/SmallVectorTest.cpp b/llvm/unittests/ADT/SmallVectorTest.cpp
index d97ab577524f..c880a6b6c543 100644
--- a/llvm/unittests/ADT/SmallVectorTest.cpp
+++ b/llvm/unittests/ADT/SmallVectorTest.cpp
@@ -53,6 +53,7 @@ class Constructable {
 
   Constructable(Constructable && src) : constructed(true) {
     value = src.value;
+    src.value = 0;
     ++numConstructorCalls;
     ++numMoveConstructorCalls;
   }
@@ -74,6 +75,7 @@ class Constructable {
   Constructable & operator=(Constructable && src) {
     EXPECT_TRUE(constructed);
     value = src.value;
+    src.value = 0;
     ++numAssignmentCalls;
     ++numMoveAssignmentCalls;
     return *this;
@@ -1056,11 +1058,16 @@ class SmallVectorReferenceInvalidationTest : public SmallVectorTestBase {
     return N;
   }
 
+  template <class T> static bool isValueType() {
+    return std::is_same<T, typename VectorT::value_type>::value;
+  }
+
   void SetUp() override {
     SmallVectorTestBase::SetUp();
 
     // Fill up the small size so that insertions move the elements.
-    V.append({0, 0, 0});
+    for (int I = 0, E = NumBuiltinElts(V); I != E; ++I)
+      V.emplace_back(I + 1);
   }
 };
 
@@ -1074,19 +1081,54 @@ TYPED_TEST_CASE(SmallVectorReferenceInvalidationTest,
                 SmallVectorReferenceInvalidationTestTypes);
 
 TYPED_TEST(SmallVectorReferenceInvalidationTest, PushBack) {
+  // Note: setup adds [1, 2, ...] to V until it's at capacity in small mode.
   auto &V = this->V;
-  (void)V;
-#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST
-  EXPECT_DEATH(V.push_back(V.back()), this->AssertionMessage);
-#endif
+  int N = this->NumBuiltinElts(V);
+
+  // Push back a reference to last element when growing from small storage.
+  V.push_back(V.back());
+  EXPECT_EQ(N, V.back());
+
+  // Check that the old value is still there (not moved away).
+  EXPECT_EQ(N, V[V.size() - 2]);
+
+  // Fill storage again.
+  V.back() = V.size();
+  while (V.size() < V.capacity())
+    V.push_back(V.size() + 1);
+
+  // Push back a reference to last element when growing from large storage.
+  V.push_back(V.back());
+  EXPECT_EQ(int(V.size()) - 1, V.back());
 }
 
 TYPED_TEST(SmallVectorReferenceInvalidationTest, PushBackMoved) {
+  // Note: setup adds [1, 2, ...] to V until it's at capacity in small mode.
   auto &V = this->V;
-  (void)V;
-#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST
-  EXPECT_DEATH(V.push_back(std::move(V.back())), this->AssertionMessage);
-#endif
+  int N = this->NumBuiltinElts(V);
+
+  // Push back a reference to last element when growing from small storage.
+  V.push_back(std::move(V.back()));
+  EXPECT_EQ(N, V.back());
+  if (this->template isValueType<Constructable>()) {
+    // Check that the value was moved (not copied).
+    EXPECT_EQ(0, V[V.size() - 2]);
+  }
+
+  // Fill storage again.
+  V.back() = V.size();
+  while (V.size() < V.capacity())
+    V.push_back(V.size() + 1);
+
+  // Push back a reference to last element when growing from large storage.
+  V.push_back(std::move(V.back()));
+
+  // Check the values.
+  EXPECT_EQ(int(V.size()) - 1, V.back());
+  if (this->template isValueType<Constructable>()) {
+    // Check the value got moved out.
+    EXPECT_EQ(0, V[V.size() - 2]);
+  }
 }
 
 TYPED_TEST(SmallVectorReferenceInvalidationTest, Resize) {
@@ -1150,20 +1192,53 @@ TYPED_TEST(SmallVectorReferenceInvalidationTest, AssignRange) {
 }
 
 TYPED_TEST(SmallVectorReferenceInvalidationTest, Insert) {
+  // Note: setup adds [1, 2, ...] to V until it's at capacity in small mode.
   auto &V = this->V;
   (void)V;
-#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST
-  EXPECT_DEATH(V.insert(V.begin(), V.back()), this->AssertionMessage);
-#endif
+
+  // Insert a reference to the back (not at end() or else insert delegates to
+  // push_back()), growing out of small mode. Confirm the value was copied out
+  // (moving out Constructable sets it to 0).
+  V.insert(V.begin(), V.back());
+  EXPECT_EQ(int(V.size() - 1), V.front());
+  EXPECT_EQ(int(V.size() - 1), V.back());
+
+  // Fill up the vector again.
+  while (V.size() < V.capacity())
+    V.push_back(V.size() + 1);
+
+  // Grow again from large storage to large storage.
+  V.insert(V.begin(), V.back());
+  EXPECT_EQ(int(V.size() - 1), V.front());
+  EXPECT_EQ(int(V.size() - 1), V.back());
 }
 
 TYPED_TEST(SmallVectorReferenceInvalidationTest, InsertMoved) {
+  // Note: setup adds [1, 2, ...] to V until it's at capacity in small mode.
   auto &V = this->V;
   (void)V;
-#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST
-  EXPECT_DEATH(V.insert(V.begin(), std::move(V.back())),
-               this->AssertionMessage);
-#endif
+
+  // Insert a reference to the back (not at end() or else insert delegates to
+  // push_back()), growing out of small mode. Confirm the value was copied out
+  // (moving out Constructable sets it to 0).
+  V.insert(V.begin(), std::move(V.back()));
+  EXPECT_EQ(int(V.size() - 1), V.front());
+  if (this->template isValueType<Constructable>()) {
+    // Check the value got moved out.
+    EXPECT_EQ(0, V.back());
+  }
+
+  // Fill up the vector again.
+  while (V.size() < V.capacity())
+    V.push_back(V.size() + 1);
+
+  // Grow again from large storage to large storage.
+  V.insert(V.begin(), std::move(V.back()));
+  EXPECT_EQ(int(V.size() - 1), V.front());
+  if (this->template isValueType<Constructable>()) {
+    // Check the value got moved out.
+    EXPECT_EQ(0, V.back());
+  }
 }
 
 TYPED_TEST(SmallVectorReferenceInvalidationTest, InsertN) {


        


More information about the llvm-branch-commits mailing list