[llvm] 5abf76f - ADT: Add assertions to SmallVector::insert, etc., for reference invalidation

Duncan P. N. Exon Smith via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 18 17:36:39 PST 2020


Author: Duncan P. N. Exon Smith
Date: 2020-11-18T17:36:28-08:00
New Revision: 5abf76fbe37380874a88cc9aa02164800e4e10f3

URL: https://github.com/llvm/llvm-project/commit/5abf76fbe37380874a88cc9aa02164800e4e10f3
DIFF: https://github.com/llvm/llvm-project/commit/5abf76fbe37380874a88cc9aa02164800e4e10f3.diff

LOG: ADT: Add assertions to SmallVector::insert, etc., for reference invalidation

2c196bbc6bd897b3dcc1d87a3baac28e1e88df41 asserted that
`SmallVector::push_back` doesn't invalidate the parameter when it needs
to grow. Do the same for `resize`, `append`, `assign`, `insert`, and
`emplace_back`.

Differential Revision: https://reviews.llvm.org/D91744

Added: 
    

Modified: 
    llvm/include/llvm/ADT/SmallVector.h
    llvm/lib/MC/MCParser/MasmParser.cpp
    llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
    llvm/lib/Target/ARM/Disassembler/ARMDisassembler.cpp
    llvm/unittests/ADT/SmallVectorTest.cpp
    llvm/utils/TableGen/CodeGenSchedule.cpp

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/ADT/SmallVector.h b/llvm/include/llvm/ADT/SmallVector.h
index ccf36c91d9ac..08274ead873b 100644
--- a/llvm/include/llvm/ADT/SmallVector.h
+++ b/llvm/include/llvm/ADT/SmallVector.h
@@ -136,12 +136,56 @@ class SmallVectorTemplateCommon
     this->Size = this->Capacity = 0; // FIXME: Setting Capacity to 0 is suspect.
   }
 
-  void assertSafeToPush(const void *Elt) {
-    assert(
-        (Elt < begin() || Elt >= end() || this->size() < this->capacity()) &&
-        "Attempting to push_back to the vector an element of the vector without"
-        " enough space reserved");
-  }
+  /// Check whether Elt will be invalidated by resizing the vector to NewSize.
+  void assertSafeToReferenceAfterResize(const void *Elt, size_t NewSize) {
+    assert((Elt >= this->end() ||
+            (NewSize <= this->size()
+                 ? Elt < this->begin() + NewSize
+                 : (Elt < this->begin() || NewSize <= this->capacity()))) &&
+           "Attempting to reference an element of the vector in an operation "
+           "that invalidates it");
+  }
+
+  /// Check whether Elt will be invalidated by increasing the size of the
+  /// vector by N.
+  void assertSafeToAdd(const void *Elt, size_t N = 1) {
+    this->assertSafeToReferenceAfterResize(Elt, this->size() + N);
+  }
+
+  /// Check whether any part of the range will be invalidated by clearing.
+  void assertSafeToReferenceAfterClear(const T *From, const T *To) {
+    if (From == To)
+      return;
+    this->assertSafeToReferenceAfterResize(From, 0);
+    this->assertSafeToReferenceAfterResize(To - 1, 0);
+  }
+  template <
+      class ItTy,
+      std::enable_if_t<!std::is_same<std::remove_const_t<ItTy>, T *>::value,
+                       bool> = false>
+  void assertSafeToReferenceAfterClear(ItTy, ItTy) {}
+
+  /// Check whether any part of the range will be invalidated by growing.
+  void assertSafeToAddRange(const T *From, const T *To) {
+    if (From == To)
+      return;
+    this->assertSafeToAdd(From, To - From);
+    this->assertSafeToAdd(To - 1, To - From);
+  }
+  template <
+      class ItTy,
+      std::enable_if_t<!std::is_same<std::remove_const_t<ItTy>, T *>::value,
+                       bool> = false>
+  void assertSafeToAddRange(ItTy, ItTy) {}
+
+  /// Check whether any argument will be invalidated by growing for
+  /// emplace_back.
+  template <class ArgType1, class... ArgTypes>
+  void assertSafeToEmplace(ArgType1 &Arg1, ArgTypes &... Args) {
+    this->assertSafeToAdd(&Arg1);
+    this->assertSafeToEmplace(Args...);
+  }
+  void assertSafeToEmplace() {}
 
 public:
   using size_type = size_t;
@@ -258,7 +302,7 @@ class SmallVectorTemplateBase : public SmallVectorTemplateCommon<T> {
 
 public:
   void push_back(const T &Elt) {
-    this->assertSafeToPush(&Elt);
+    this->assertSafeToAdd(&Elt);
     if (LLVM_UNLIKELY(this->size() >= this->capacity()))
       this->grow();
     ::new ((void*) this->end()) T(Elt);
@@ -266,7 +310,7 @@ class SmallVectorTemplateBase : public SmallVectorTemplateCommon<T> {
   }
 
   void push_back(T &&Elt) {
-    this->assertSafeToPush(&Elt);
+    this->assertSafeToAdd(&Elt);
     if (LLVM_UNLIKELY(this->size() >= this->capacity()))
       this->grow();
     ::new ((void*) this->end()) T(::std::move(Elt));
@@ -362,7 +406,7 @@ class SmallVectorTemplateBase<T, true> : public SmallVectorTemplateCommon<T> {
 
 public:
   void push_back(const T &Elt) {
-    this->assertSafeToPush(&Elt);
+    this->assertSafeToAdd(&Elt);
     if (LLVM_UNLIKELY(this->size() >= this->capacity()))
       this->grow();
     memcpy(reinterpret_cast<void *>(this->end()), &Elt, sizeof(T));
@@ -418,6 +462,7 @@ class SmallVectorImpl : public SmallVectorTemplateBase<T> {
   }
 
   void resize(size_type N, const T &NV) {
+    this->assertSafeToReferenceAfterResize(&NV, N);
     if (N < this->size()) {
       this->destroy_range(this->begin()+N, this->end());
       this->set_size(N);
@@ -454,6 +499,7 @@ class SmallVectorImpl : public SmallVectorTemplateBase<T> {
                 typename std::iterator_traits<in_iter>::iterator_category,
                 std::input_iterator_tag>::value>>
   void append(in_iter in_start, in_iter in_end) {
+    this->assertSafeToAddRange(in_start, in_end);
     size_type NumInputs = std::distance(in_start, in_end);
     if (NumInputs > this->capacity() - this->size())
       this->grow(this->size()+NumInputs);
@@ -464,6 +510,7 @@ class SmallVectorImpl : public SmallVectorTemplateBase<T> {
 
   /// Append \p NumInputs copies of \p Elt to the end.
   void append(size_type NumInputs, const T &Elt) {
+    this->assertSafeToAdd(&Elt, NumInputs);
     if (NumInputs > this->capacity() - this->size())
       this->grow(this->size()+NumInputs);
 
@@ -479,6 +526,7 @@ class SmallVectorImpl : public SmallVectorTemplateBase<T> {
   // re-initializing them - for all assign(...) variants.
 
   void assign(size_type NumElts, const T &Elt) {
+    this->assertSafeToReferenceAfterResize(&Elt, 0);
     clear();
     if (this->capacity() < NumElts)
       this->grow(NumElts);
@@ -491,6 +539,7 @@ class SmallVectorImpl : public SmallVectorTemplateBase<T> {
                 typename std::iterator_traits<in_iter>::iterator_category,
                 std::input_iterator_tag>::value>>
   void assign(in_iter in_start, in_iter in_end) {
+    this->assertSafeToReferenceAfterClear(in_start, in_end);
     clear();
     append(in_start, in_end);
   }
@@ -543,6 +592,9 @@ class SmallVectorImpl : public SmallVectorTemplateBase<T> {
     assert(I >= this->begin() && "Insertion iterator is out of bounds.");
     assert(I <= this->end() && "Inserting past the end of the vector.");
 
+    // Check that adding an element won't invalidate Elt.
+    this->assertSafeToAdd(&Elt);
+
     if (this->size() >= this->capacity()) {
       size_t EltNo = I-this->begin();
       this->grow();
@@ -583,6 +635,9 @@ class SmallVectorImpl : public SmallVectorTemplateBase<T> {
     assert(I >= this->begin() && "Insertion iterator is out of bounds.");
     assert(I <= this->end() && "Inserting past the end of the vector.");
 
+    // Check that adding NumToInsert elements won't invalidate Elt.
+    this->assertSafeToAdd(&Elt, NumToInsert);
+
     // Ensure there is enough space.
     reserve(this->size() + NumToInsert);
 
@@ -638,6 +693,9 @@ class SmallVectorImpl : public SmallVectorTemplateBase<T> {
     assert(I >= this->begin() && "Insertion iterator is out of bounds.");
     assert(I <= this->end() && "Inserting past the end of the vector.");
 
+    // Check that the reserve that follows doesn't invalidate the iterators.
+    this->assertSafeToAddRange(From, To);
+
     size_t NumToInsert = std::distance(From, To);
 
     // Ensure there is enough space.
@@ -687,6 +745,7 @@ class SmallVectorImpl : public SmallVectorTemplateBase<T> {
   }
 
   template <typename... ArgTypes> reference emplace_back(ArgTypes &&... Args) {
+    this->assertSafeToEmplace(Args...);
     if (LLVM_UNLIKELY(this->size() >= this->capacity()))
       this->grow();
     ::new ((void *)this->end()) T(std::forward<ArgTypes>(Args)...);

diff --git a/llvm/lib/MC/MCParser/MasmParser.cpp b/llvm/lib/MC/MCParser/MasmParser.cpp
index 33e550773545..9cdd2eb2cc93 100644
--- a/llvm/lib/MC/MCParser/MasmParser.cpp
+++ b/llvm/lib/MC/MCParser/MasmParser.cpp
@@ -4145,6 +4145,9 @@ bool MasmParser::parseDirectiveNestedStruct(StringRef Directive,
   if (parseToken(AsmToken::EndOfStatement))
     return addErrorSuffix(" in '" + Twine(Directive) + "' directive");
 
+  // Reserve space to ensure Alignment doesn't get invalidated when
+  // StructInProgress grows.
+  StructInProgress.reserve(StructInProgress.size() + 1);
   StructInProgress.emplace_back(Name, DirKind == DK_UNION,
                                 StructInProgress.back().Alignment);
   return false;

diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
index b7746478d1c7..650c155c3536 100644
--- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
+++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
@@ -7031,7 +7031,8 @@ void AMDGPUAsmParser::cvtVOP3(MCInst &Inst, const OperandVector &Operands,
     std::advance(it, AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::src2_modifiers));
     it = Inst.insert(it, MCOperand::createImm(0)); // no modifiers for src2
     ++it;
-    Inst.insert(it, Inst.getOperand(0)); // src2 = dst
+    // Copy the operand to ensure it's not invalidated when Inst grows.
+    Inst.insert(it, MCOperand(Inst.getOperand(0))); // src2 = dst
   }
 }
 

diff --git a/llvm/lib/Target/ARM/Disassembler/ARMDisassembler.cpp b/llvm/lib/Target/ARM/Disassembler/ARMDisassembler.cpp
index ccef1ba3c830..8ea323a9ced5 100644
--- a/llvm/lib/Target/ARM/Disassembler/ARMDisassembler.cpp
+++ b/llvm/lib/Target/ARM/Disassembler/ARMDisassembler.cpp
@@ -860,7 +860,8 @@ ARMDisassembler::AddThumbPredicate(MCInst &MI) const {
         VCCPos + 2, MCOI::TIED_TO);
       assert(TiedOp >= 0 &&
              "Inactive register in vpred_r is not tied to an output!");
-      MI.insert(VCCI, MI.getOperand(TiedOp));
+      // Copy the operand to ensure it's not invalidated when MI grows.
+      MI.insert(VCCI, MCOperand(MI.getOperand(TiedOp)));
     }
   } else if (VCC != ARMVCC::None) {
     Check(S, SoftFail);

diff --git a/llvm/unittests/ADT/SmallVectorTest.cpp b/llvm/unittests/ADT/SmallVectorTest.cpp
index 162716abe9ff..1a61d2304d51 100644
--- a/llvm/unittests/ADT/SmallVectorTest.cpp
+++ b/llvm/unittests/ADT/SmallVectorTest.cpp
@@ -252,7 +252,9 @@ TYPED_TEST(SmallVectorTest, PushPopTest) {
   this->theVector.push_back(Constructable(2));
   this->assertValuesInOrder(this->theVector, 2u, 1, 2);
 
-  // Insert at beginning
+  // Insert at beginning. Reserve space to avoid reference invalidation from
+  // this->theVector[1].
+  this->theVector.reserve(this->theVector.size() + 1);
   this->theVector.insert(this->theVector.begin(), this->theVector[1]);
   this->assertValuesInOrder(this->theVector, 3u, 2, 1, 2);
 
@@ -999,4 +1001,198 @@ TEST(SmallVectorTest, InitializerList) {
   EXPECT_TRUE(makeArrayRef(V2).equals({4, 5, 3, 2}));
 }
 
+template <class VectorT>
+class SmallVectorReferenceInvalidationTest : public SmallVectorTestBase {
+protected:
+  const char *AssertionMessage =
+      "Attempting to reference an element of the vector in an operation \" "
+      "\"that invalidates it";
+
+  VectorT V;
+
+  template <typename T, unsigned N>
+  static unsigned NumBuiltinElts(const SmallVector<T, N> &) {
+    return N;
+  }
+
+  void SetUp() override {
+    SmallVectorTestBase::SetUp();
+
+    // Fill up the small size so that insertions move the elements.
+    V.append({0, 0, 0});
+  }
+};
+
+// Test one type that's trivially copyable (int) and one that isn't
+// (Constructable) since reference invalidation may be fixed differently for
+// each.
+using SmallVectorReferenceInvalidationTestTypes =
+    ::testing::Types<SmallVector<int, 3>, SmallVector<Constructable, 3>>;
+
+TYPED_TEST_CASE(SmallVectorReferenceInvalidationTest,
+                SmallVectorReferenceInvalidationTestTypes);
+
+TYPED_TEST(SmallVectorReferenceInvalidationTest, PushBack) {
+  auto &V = this->V;
+  (void)V;
+#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST
+  EXPECT_DEATH(V.push_back(V.back()), this->AssertionMessage);
+#endif
+}
+
+TYPED_TEST(SmallVectorReferenceInvalidationTest, PushBackMoved) {
+  auto &V = this->V;
+  (void)V;
+#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST
+  EXPECT_DEATH(V.push_back(std::move(V.back())), this->AssertionMessage);
+#endif
+}
+
+TYPED_TEST(SmallVectorReferenceInvalidationTest, Resize) {
+  auto &V = this->V;
+  (void)V;
+#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST
+  EXPECT_DEATH(V.resize(2, V.back()), this->AssertionMessage);
+#endif
+}
+
+TYPED_TEST(SmallVectorReferenceInvalidationTest, Append) {
+  auto &V = this->V;
+  (void)V;
+#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST
+  EXPECT_DEATH(V.append(1, V.back()), this->AssertionMessage);
+#endif
+}
+
+TYPED_TEST(SmallVectorReferenceInvalidationTest, AppendRange) {
+  auto &V = this->V;
+  (void)V;
+#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST
+  EXPECT_DEATH(V.append(V.begin(), V.begin() + 1), this->AssertionMessage);
+
+  ASSERT_EQ(3u, this->NumBuiltinElts(V));
+  ASSERT_EQ(3u, V.size());
+  V.pop_back();
+  ASSERT_EQ(2u, V.size());
+
+  // Confirm this checks for growth when there's more than one element
+  // appended.
+  EXPECT_DEATH(V.append(V.begin(), V.end()), this->AssertionMessage);
+#endif
+}
+
+TYPED_TEST(SmallVectorReferenceInvalidationTest, Assign) {
+  auto &V = this->V;
+  (void)V;
+#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST
+  // Regardless of capacity, assign should never reference an internal element.
+  EXPECT_DEATH(V.assign(1, V.back()), this->AssertionMessage);
+  EXPECT_DEATH(V.assign(this->NumBuiltinElts(V), V.back()),
+               this->AssertionMessage);
+  EXPECT_DEATH(V.assign(this->NumBuiltinElts(V) + 1, V.back()),
+               this->AssertionMessage);
+#endif
+}
+
+TYPED_TEST(SmallVectorReferenceInvalidationTest, AssignRange) {
+  auto &V = this->V;
+#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST
+  EXPECT_DEATH(V.assign(V.begin(), V.end()), this->AssertionMessage);
+  EXPECT_DEATH(V.assign(V.begin(), V.end() - 1), this->AssertionMessage);
+#endif
+  V.assign(V.begin(), V.begin());
+  EXPECT_TRUE(V.empty());
+}
+
+TYPED_TEST(SmallVectorReferenceInvalidationTest, Insert) {
+  auto &V = this->V;
+  (void)V;
+#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST
+  EXPECT_DEATH(V.insert(V.begin(), V.back()), this->AssertionMessage);
+#endif
+}
+
+TYPED_TEST(SmallVectorReferenceInvalidationTest, InsertMoved) {
+  auto &V = this->V;
+  (void)V;
+#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST
+  EXPECT_DEATH(V.insert(V.begin(), std::move(V.back())),
+               this->AssertionMessage);
+#endif
+}
+
+TYPED_TEST(SmallVectorReferenceInvalidationTest, InsertN) {
+  auto &V = this->V;
+  (void)V;
+#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST
+  EXPECT_DEATH(V.insert(V.begin(), 2, V.back()), this->AssertionMessage);
+#endif
+}
+
+TYPED_TEST(SmallVectorReferenceInvalidationTest, InsertRange) {
+  auto &V = this->V;
+  (void)V;
+#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST
+  EXPECT_DEATH(V.insert(V.begin(), V.begin(), V.begin() + 1),
+               this->AssertionMessage);
+
+  ASSERT_EQ(3u, this->NumBuiltinElts(V));
+  ASSERT_EQ(3u, V.size());
+  V.pop_back();
+  ASSERT_EQ(2u, V.size());
+
+  // Confirm this checks for growth when there's more than one element
+  // inserted.
+  EXPECT_DEATH(V.insert(V.begin(), V.begin(), V.end()), this->AssertionMessage);
+#endif
+}
+
+TYPED_TEST(SmallVectorReferenceInvalidationTest, EmplaceBack) {
+  auto &V = this->V;
+  (void)V;
+#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST
+  EXPECT_DEATH(V.emplace_back(V.back()), this->AssertionMessage);
+#endif
+}
+
+template <class VectorT>
+class SmallVectorInternalReferenceInvalidationTest
+    : public SmallVectorTestBase {
+protected:
+  const char *AssertionMessage =
+      "Attempting to reference an element of the vector in an operation \" "
+      "\"that invalidates it";
+
+  VectorT V;
+
+  template <typename T, unsigned N>
+  static unsigned NumBuiltinElts(const SmallVector<T, N> &) {
+    return N;
+  }
+
+  void SetUp() override {
+    SmallVectorTestBase::SetUp();
+
+    // Fill up the small size so that insertions move the elements.
+    V.push_back(std::make_pair(0, 0));
+  }
+};
+
+// Test pairs of the same types from SmallVectorReferenceInvalidationTestTypes.
+using SmallVectorInternalReferenceInvalidationTestTypes =
+    ::testing::Types<SmallVector<std::pair<int, int>, 1>,
+                     SmallVector<std::pair<Constructable, Constructable>, 1>>;
+
+TYPED_TEST_CASE(SmallVectorInternalReferenceInvalidationTest,
+                SmallVectorInternalReferenceInvalidationTestTypes);
+
+TYPED_TEST(SmallVectorInternalReferenceInvalidationTest, EmplaceBack) {
+  auto &V = this->V;
+  (void)V;
+#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST
+  EXPECT_DEATH(V.emplace_back(V.back().first, 0), this->AssertionMessage);
+  EXPECT_DEATH(V.emplace_back(0, V.back().second), this->AssertionMessage);
+#endif
+}
+
 } // end namespace

diff --git a/llvm/utils/TableGen/CodeGenSchedule.cpp b/llvm/utils/TableGen/CodeGenSchedule.cpp
index 71fd4ec1e07d..6fe106e7a04e 100644
--- a/llvm/utils/TableGen/CodeGenSchedule.cpp
+++ b/llvm/utils/TableGen/CodeGenSchedule.cpp
@@ -1540,6 +1540,7 @@ pushVariant(const TransVariant &VInfo, bool IsRead) {
   if (SchedRW.IsVariadic) {
     unsigned OperIdx = RWSequences.size()-1;
     // Make N-1 copies of this transition's last sequence.
+    RWSequences.reserve(RWSequences.size() + SelectedRWs.size() - 1);
     RWSequences.insert(RWSequences.end(), SelectedRWs.size() - 1,
                        RWSequences[OperIdx]);
     // Push each of the N elements of the SelectedRWs onto a copy of the last
@@ -1625,6 +1626,7 @@ void PredTransitions::substituteVariantOperand(
       //    any transition with compatible CPU ID.
       // In such case we create new empty transition with zero (AnyCPU)
       // index.
+      TransVec.reserve(TransVec.size() + 1);
       TransVec.emplace_back(TransVec[StartIdx].PredTerm);
       TransVec.back().ReadSequences.emplace_back();
       CollectAndAddVariants(TransVec.size() - 1, SchedRW);


        


More information about the llvm-commits mailing list