[llvm-branch-commits] [llvm] 3043e5a - ADT: Fix reference invalidation in N-element SmallVector::append and insert

Duncan P. N. Exon Smith via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Jan 13 20:25:13 PST 2021


Author: Duncan P. N. Exon Smith
Date: 2021-01-13T20:00:44-08:00
New Revision: 3043e5a5c33c4c871f4a1dfd621a8839f9a1f0b3

URL: https://github.com/llvm/llvm-project/commit/3043e5a5c33c4c871f4a1dfd621a8839f9a1f0b3
DIFF: https://github.com/llvm/llvm-project/commit/3043e5a5c33c4c871f4a1dfd621a8839f9a1f0b3.diff

LOG: ADT: Fix reference invalidation in N-element SmallVector::append and insert

For small enough, trivially copyable `T`, take the parameter by-value in
`SmallVector::append` and `SmallVector::insert`.  Otherwise, when
growing (or when `insert` shifts existing elements), update the
argument's address if it refers to an element in the vector's own storage.

Differential Revision: https://reviews.llvm.org/D93780

Added: 
    

Modified: 
    llvm/include/llvm/ADT/SmallVector.h
    llvm/unittests/ADT/SmallVectorTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/SmallVector.h b/llvm/include/llvm/ADT/SmallVector.h
index c91075677b3f..fea8a763d48f 100644
--- a/llvm/include/llvm/ADT/SmallVector.h
+++ b/llvm/include/llvm/ADT/SmallVector.h
@@ -223,8 +223,9 @@ class SmallVectorTemplateCommon
   /// Reserve enough space to add one element, and return the updated element
   /// pointer in case it was a reference to the storage.
   template <class U>
-  static const T *reserveForAndGetAddressImpl(U *This, const T &Elt) {
-    if (LLVM_LIKELY(This->size() < This->capacity()))
+  static const T *reserveForAndGetAddressImpl(U *This, const T &Elt, size_t N) {
+    size_t NewSize = This->size() + N;
+    if (LLVM_LIKELY(NewSize <= This->capacity()))
       return &Elt;
 
     bool ReferencesStorage = false;
@@ -233,7 +234,7 @@ class SmallVectorTemplateCommon
       ReferencesStorage = true;
       Index = &Elt - This->begin();
     }
-    This->grow();
+    This->grow(NewSize);
     return ReferencesStorage ? This->begin() + Index : &Elt;
   }
 
@@ -357,14 +358,14 @@ class SmallVectorTemplateBase : public SmallVectorTemplateCommon<T> {
 
   /// Reserve enough space to add one element, and return the updated element
   /// pointer in case it was a reference to the storage.
-  const T *reserveForAndGetAddress(const T &Elt) {
-    return this->reserveForAndGetAddressImpl(this, Elt);
+  const T *reserveForAndGetAddress(const T &Elt, size_t N = 1) {
+    return this->reserveForAndGetAddressImpl(this, Elt, N);
   }
 
   /// Reserve enough space to add one element, and return the updated element
   /// pointer in case it was a reference to the storage.
-  T *reserveForAndGetAddress(T &Elt) {
-    return const_cast<T *>(this->reserveForAndGetAddressImpl(this, Elt));
+  T *reserveForAndGetAddress(T &Elt, size_t N = 1) {
+    return const_cast<T *>(this->reserveForAndGetAddressImpl(this, Elt, N));
   }
 
   static T &&forward_value_param(T &&V) { return std::move(V); }
@@ -483,14 +484,14 @@ class SmallVectorTemplateBase<T, true> : public SmallVectorTemplateCommon<T> {
 
   /// Reserve enough space to add one element, and return the updated element
   /// pointer in case it was a reference to the storage.
-  const T *reserveForAndGetAddress(const T &Elt) {
-    return this->reserveForAndGetAddressImpl(this, Elt);
+  const T *reserveForAndGetAddress(const T &Elt, size_t N = 1) {
+    return this->reserveForAndGetAddressImpl(this, Elt, N);
   }
 
   /// Reserve enough space to add one element, and return the updated element
   /// pointer in case it was a reference to the storage.
-  T *reserveForAndGetAddress(T &Elt) {
-    return const_cast<T *>(this->reserveForAndGetAddressImpl(this, Elt));
+  T *reserveForAndGetAddress(T &Elt, size_t N = 1) {
+    return const_cast<T *>(this->reserveForAndGetAddressImpl(this, Elt, N));
   }
 
   /// Copy \p V or return a reference, depending on \a ValueParamT.
@@ -616,12 +617,9 @@ class SmallVectorImpl : public SmallVectorTemplateBase<T> {
   }
 
   /// Append \p NumInputs copies of \p Elt to the end.
-  void append(size_type NumInputs, const T &Elt) {
-    this->assertSafeToAdd(&Elt, NumInputs);
-    if (NumInputs > this->capacity() - this->size())
-      this->grow(this->size()+NumInputs);
-
-    std::uninitialized_fill_n(this->end(), NumInputs, Elt);
+  void append(size_type NumInputs, ValueParamT Elt) {
+    const T *EltPtr = this->reserveForAndGetAddress(Elt, NumInputs);
+    std::uninitialized_fill_n(this->end(), NumInputs, *EltPtr);
     this->set_size(this->size() + NumInputs);
   }
 
@@ -732,7 +730,7 @@ class SmallVectorImpl : public SmallVectorTemplateBase<T> {
     return insert_one_impl(I, this->forward_value_param(Elt));
   }
 
-  iterator insert(iterator I, size_type NumToInsert, const T &Elt) {
+  iterator insert(iterator I, size_type NumToInsert, ValueParamT Elt) {
     // Convert iterator to elt# to avoid invalidating iterator when we reserve()
     size_t InsertElt = I - this->begin();
 
@@ -743,11 +741,9 @@ class SmallVectorImpl : public SmallVectorTemplateBase<T> {
 
     assert(this->isReferenceToStorage(I) && "Insertion iterator is out of bounds.");
 
-    // Check that adding NumToInsert elements won't invalidate Elt.
-    this->assertSafeToAdd(&Elt, NumToInsert);
-
-    // Ensure there is enough space.
-    reserve(this->size() + NumToInsert);
+    // Ensure there is enough space, and get the (maybe updated) address of
+    // Elt.
+    const T *EltPtr = this->reserveForAndGetAddress(Elt, NumToInsert);
 
     // Uninvalidate the iterator.
     I = this->begin()+InsertElt;
@@ -764,7 +760,12 @@ class SmallVectorImpl : public SmallVectorTemplateBase<T> {
       // Copy the existing elements that get replaced.
       std::move_backward(I, OldEnd-NumToInsert, OldEnd);
 
-      std::fill_n(I, NumToInsert, Elt);
+      // If we just moved the element we're inserting, be sure to update
+      // the reference (never happens if TakesParamByValue).
+      if (!TakesParamByValue && I <= EltPtr && EltPtr < this->end())
+        EltPtr += NumToInsert;
+
+      std::fill_n(I, NumToInsert, *EltPtr);
       return I;
     }
 
@@ -777,11 +778,16 @@ class SmallVectorImpl : public SmallVectorTemplateBase<T> {
     size_t NumOverwritten = OldEnd-I;
     this->uninitialized_move(I, OldEnd, this->end()-NumOverwritten);
 
+    // If we just moved the element we're inserting, be sure to update
+    // the reference (never happens if TakesParamByValue).
+    if (!TakesParamByValue && I <= EltPtr && EltPtr < this->end())
+      EltPtr += NumToInsert;
+
     // Replace the overwritten part.
-    std::fill_n(I, NumOverwritten, Elt);
+    std::fill_n(I, NumOverwritten, *EltPtr);
 
     // Insert the non-overwritten middle part.
-    std::uninitialized_fill_n(OldEnd, NumToInsert-NumOverwritten, Elt);
+    std::uninitialized_fill_n(OldEnd, NumToInsert - NumOverwritten, *EltPtr);
     return I;
   }
 

diff  --git a/llvm/unittests/ADT/SmallVectorTest.cpp b/llvm/unittests/ADT/SmallVectorTest.cpp
index c880a6b6c543..c236a68636d0 100644
--- a/llvm/unittests/ADT/SmallVectorTest.cpp
+++ b/llvm/unittests/ADT/SmallVectorTest.cpp
@@ -1146,9 +1146,17 @@ TYPED_TEST(SmallVectorReferenceInvalidationTest, Resize) {
 TYPED_TEST(SmallVectorReferenceInvalidationTest, Append) {
   auto &V = this->V;
   (void)V;
-#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST
-  EXPECT_DEATH(V.append(1, V.back()), this->AssertionMessage);
-#endif
+  V.append(1, V.back());
+  int N = this->NumBuiltinElts(V);
+  EXPECT_EQ(N, V[N - 1]);
+
+  // Append enough more elements that V will grow again. This tests growing
+  // when already in large mode.
+  //
+  // If reference invalidation breaks in the future, sanitizers should be able
+  // to catch a use-after-free here.
+  V.append(V.capacity() - V.size() + 1, V.front());
+  EXPECT_EQ(1, V.back());
 }
 
 TYPED_TEST(SmallVectorReferenceInvalidationTest, AppendRange) {
@@ -1244,9 +1252,20 @@ TYPED_TEST(SmallVectorReferenceInvalidationTest, InsertMoved) {
 TYPED_TEST(SmallVectorReferenceInvalidationTest, InsertN) {
   auto &V = this->V;
   (void)V;
-#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST
-  EXPECT_DEATH(V.insert(V.begin(), 2, V.back()), this->AssertionMessage);
-#endif
+
+  // Cover NumToInsert <= this->end() - I.
+  V.insert(V.begin() + 1, 1, V.back());
+  int N = this->NumBuiltinElts(V);
+  EXPECT_EQ(N, V[1]);
+
+  // Cover NumToInsert > this->end() - I, inserting enough elements that V will
+  // also grow again; V.capacity() will be more elements than necessary but
+  // it's a simple way to cover both conditions.
+  //
+  // If reference invalidation breaks in the future, sanitizers should be able
+  // to catch a use-after-free here.
+  V.insert(V.begin(), V.capacity(), V.front());
+  EXPECT_EQ(1, V.front());
 }
 
 TYPED_TEST(SmallVectorReferenceInvalidationTest, InsertRange) {


        


More information about the llvm-branch-commits mailing list