[llvm-branch-commits] [lldb] ac25e86 - [lldb] Deal gracefully with concurrency in the API instrumentation.

Jonas Devlieghere via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Dec 10 09:42:37 PST 2020


Author: Jonas Devlieghere
Date: 2020-12-10T09:37:49-08:00
New Revision: ac25e8628c443cddd841c6c91d1c9e23e88969e5

URL: https://github.com/llvm/llvm-project/commit/ac25e8628c443cddd841c6c91d1c9e23e88969e5
DIFF: https://github.com/llvm/llvm-project/commit/ac25e8628c443cddd841c6c91d1c9e23e88969e5.diff

LOG: [lldb] Deal gracefully with concurrency in the API instrumentation.

Prevent lldb from crashing when multiple threads are concurrently
accessing the SB API with reproducer capture enabled.

The API instrumentation records both the input arguments and the return
value, but it cannot block for the duration of the API call. Therefore
we introduce a sequence number that allows us to correlate the function
with its result, and add locking to ensure that each of those two parts
is emitted atomically.

Using the sequence number, we can detect situations where the return
value does not immediately follow the function call, in which case we
report a fatal error saying that concurrency is not (currently)
supported. In the future we might attempt to be smarter and read ahead
until we've found the return value matching the current call.

Differential revision: https://reviews.llvm.org/D92820

Added: 
    

Modified: 
    lldb/include/lldb/Utility/ReproducerInstrumentation.h
    lldb/source/Utility/ReproducerInstrumentation.cpp
    lldb/unittests/Utility/ReproducerInstrumentationTest.cpp

Removed: 
    


################################################################################
diff  --git a/lldb/include/lldb/Utility/ReproducerInstrumentation.h b/lldb/include/lldb/Utility/ReproducerInstrumentation.h
index e4c31522c4fc..c8a98adf85c7 100644
--- a/lldb/include/lldb/Utility/ReproducerInstrumentation.h
+++ b/lldb/include/lldb/Utility/ReproducerInstrumentation.h
@@ -333,6 +333,7 @@ class Deserializer {
   }
 
   template <typename T> const T &HandleReplayResult(const T &t) {
+    CheckSequence(Deserialize<unsigned>());
     unsigned result = Deserialize<unsigned>();
     if (is_trivially_serializable<T>::value)
       return t;
@@ -342,6 +343,7 @@ class Deserializer {
 
   /// Store the returned value in the index-to-object mapping.
   template <typename T> T &HandleReplayResult(T &t) {
+    CheckSequence(Deserialize<unsigned>());
     unsigned result = Deserialize<unsigned>();
     if (is_trivially_serializable<T>::value)
       return t;
@@ -351,6 +353,7 @@ class Deserializer {
 
   /// Store the returned value in the index-to-object mapping.
   template <typename T> T *HandleReplayResult(T *t) {
+    CheckSequence(Deserialize<unsigned>());
     unsigned result = Deserialize<unsigned>();
     if (is_trivially_serializable<T>::value)
       return t;
@@ -360,6 +363,7 @@ class Deserializer {
   /// All returned types are recorded, even when the function returns a void.
   /// The latter requires special handling.
   void HandleReplayResultVoid() {
+    CheckSequence(Deserialize<unsigned>());
     unsigned result = Deserialize<unsigned>();
     assert(result == 0);
     (void)result;
@@ -369,6 +373,10 @@ class Deserializer {
     return m_index_to_object.GetAllObjects();
   }
 
+  void SetExpectedSequence(unsigned sequence) {
+    m_expected_sequence = sequence;
+  }
+
 private:
   template <typename T> T Read(ValueTag) {
     assert(HasData(sizeof(T)));
@@ -410,11 +418,17 @@ class Deserializer {
     return *(new UnderlyingT(Deserialize<UnderlyingT>()));
   }
 
+  /// Verify that the given sequence number matches what we expect.
+  void CheckSequence(unsigned sequence);
+
   /// Mapping of indices to objects.
   IndexToObject m_index_to_object;
 
   /// Buffer containing the serialized data.
   llvm::StringRef m_buffer;
+
+  /// The result's expected sequence number.
+  llvm::Optional<unsigned> m_expected_sequence;
 };
 
 /// Partial specialization for C-style strings. We read the string value
@@ -745,12 +759,15 @@ class Recorder {
     if (!ShouldCapture())
       return;
 
+    std::lock_guard<std::mutex> lock(g_mutex);
+    unsigned sequence = GetSequenceNumber();
     unsigned id = registry.GetID(uintptr_t(f));
 
 #ifdef LLDB_REPRO_INSTR_TRACE
     Log(id);
 #endif
 
+    serializer.SerializeAll(sequence);
     serializer.SerializeAll(id);
     serializer.SerializeAll(args...);
 
@@ -758,6 +775,7 @@ class Recorder {
             typename std::remove_reference<Result>::type>::type>::value) {
       m_result_recorded = false;
     } else {
+      serializer.SerializeAll(sequence);
       serializer.SerializeAll(0);
       m_result_recorded = true;
     }
@@ -771,16 +789,20 @@ class Recorder {
     if (!ShouldCapture())
       return;
 
+    std::lock_guard<std::mutex> lock(g_mutex);
+    unsigned sequence = GetSequenceNumber();
     unsigned id = registry.GetID(uintptr_t(f));
 
 #ifdef LLDB_REPRO_INSTR_TRACE
     Log(id);
 #endif
 
+    serializer.SerializeAll(sequence);
     serializer.SerializeAll(id);
     serializer.SerializeAll(args...);
 
     // Record result.
+    serializer.SerializeAll(sequence);
     serializer.SerializeAll(0);
     m_result_recorded = true;
   }
@@ -806,7 +828,9 @@ class Recorder {
     if (update_boundary)
       UpdateBoundary();
     if (m_serializer && ShouldCapture()) {
+      std::lock_guard<std::mutex> lock(g_mutex);
       assert(!m_result_recorded);
+      m_serializer->SerializeAll(GetSequenceNumber());
       m_serializer->SerializeAll(r);
       m_result_recorded = true;
     }
@@ -816,6 +840,7 @@ class Recorder {
   template <typename Result, typename T>
   Result Replay(Deserializer &deserializer, Registry &registry, uintptr_t addr,
                 bool update_boundary) {
+    deserializer.SetExpectedSequence(deserializer.Deserialize<unsigned>());
     unsigned actual_id = registry.GetID(addr);
     unsigned id = deserializer.Deserialize<unsigned>();
     registry.CheckID(id, actual_id);
@@ -826,6 +851,7 @@ class Recorder {
   }
 
   void Replay(Deserializer &deserializer, Registry &registry, uintptr_t addr) {
+    deserializer.SetExpectedSequence(deserializer.Deserialize<unsigned>());
     unsigned actual_id = registry.GetID(addr);
     unsigned id = deserializer.Deserialize<unsigned>();
     registry.CheckID(id, actual_id);
@@ -846,6 +872,9 @@ class Recorder {
   static void PrivateThread() { g_global_boundary = true; }
 
 private:
+  static unsigned GetNextSequenceNumber() { return g_sequence++; }
+  unsigned GetSequenceNumber() const;
+
   template <typename T> friend struct replay;
   void UpdateBoundary() {
     if (m_local_boundary)
@@ -871,8 +900,17 @@ class Recorder {
   /// Whether the return value was recorded explicitly.
   bool m_result_recorded;
 
+  /// The sequence number for this pair of function and result.
+  unsigned m_sequence;
+
   /// Whether we're currently across the API boundary.
   static thread_local bool g_global_boundary;
+
+  /// Global mutex to protect concurrent access.
+  static std::mutex g_mutex;
+
+  /// Unique, monotonically increasing sequence number.
+  static std::atomic<unsigned> g_sequence;
 };
 
 /// To be used as the "Runtime ID" of a constructor. It also invokes the
@@ -1014,6 +1052,7 @@ struct invoke_char_ptr<Result (Class::*)(Args...) const> {
 
     static Result replay(Recorder &recorder, Deserializer &deserializer,
                          Registry &registry, char *str) {
+      deserializer.SetExpectedSequence(deserializer.Deserialize<unsigned>());
       deserializer.Deserialize<unsigned>();
       Class *c = deserializer.Deserialize<Class *>();
       deserializer.Deserialize<const char *>();
@@ -1035,6 +1074,7 @@ struct invoke_char_ptr<Result (Class::*)(Args...)> {
 
     static Result replay(Recorder &recorder, Deserializer &deserializer,
                          Registry &registry, char *str) {
+      deserializer.SetExpectedSequence(deserializer.Deserialize<unsigned>());
       deserializer.Deserialize<unsigned>();
       Class *c = deserializer.Deserialize<Class *>();
       deserializer.Deserialize<const char *>();
@@ -1055,6 +1095,7 @@ struct invoke_char_ptr<Result (*)(Args...)> {
 
     static Result replay(Recorder &recorder, Deserializer &deserializer,
                          Registry &registry, char *str) {
+      deserializer.SetExpectedSequence(deserializer.Deserialize<unsigned>());
       deserializer.Deserialize<unsigned>();
       deserializer.Deserialize<const char *>();
       size_t l = deserializer.Deserialize<size_t>();

diff  --git a/lldb/source/Utility/ReproducerInstrumentation.cpp b/lldb/source/Utility/ReproducerInstrumentation.cpp
index 626120c9d71a..b274a10c98fd 100644
--- a/lldb/source/Utility/ReproducerInstrumentation.cpp
+++ b/lldb/source/Utility/ReproducerInstrumentation.cpp
@@ -8,6 +8,7 @@
 
 #include "lldb/Utility/ReproducerInstrumentation.h"
 #include "lldb/Utility/Reproducer.h"
+#include <limits>
 #include <stdio.h>
 #include <stdlib.h>
 #include <thread>
@@ -84,6 +85,16 @@ template <> const char **Deserializer::Deserialize<const char **>() {
   return r;
 }
 
+void Deserializer::CheckSequence(unsigned sequence) {
+  if (m_expected_sequence && *m_expected_sequence != sequence)
+    llvm::report_fatal_error(
+        "The result does not match the preceding "
+        "function. This is probably the result of concurrent "
+        "use of the SB API during capture, which is currently not "
+        "supported.");
+  m_expected_sequence.reset();
+}
+
 bool Registry::Replay(const FileSpec &file) {
   auto error_or_file = llvm::MemoryBuffer::getFile(file.GetPath());
   if (auto err = error_or_file.getError())
@@ -107,6 +118,7 @@ bool Registry::Replay(Deserializer &deserializer) {
   setvbuf(stdout, nullptr, _IONBF, 0);
 
   while (deserializer.HasData(1)) {
+    unsigned sequence = deserializer.Deserialize<unsigned>();
     unsigned id = deserializer.Deserialize<unsigned>();
 
 #ifndef LLDB_REPRO_INSTR_TRACE
@@ -115,6 +127,7 @@ bool Registry::Replay(Deserializer &deserializer) {
     llvm::errs() << "Replaying " << id << ": " << GetSignature(id) << "\n";
 #endif
 
+    deserializer.SetExpectedSequence(sequence);
     GetReplayer(id)->operator()(deserializer);
   }
 
@@ -181,21 +194,24 @@ unsigned ObjectToIndex::GetIndexForObjectImpl(const void *object) {
 
 Recorder::Recorder()
     : m_serializer(nullptr), m_pretty_func(), m_pretty_args(),
-      m_local_boundary(false), m_result_recorded(true) {
+      m_local_boundary(false), m_result_recorded(true),
+      m_sequence(std::numeric_limits<unsigned>::max()) {
   if (!g_global_boundary) {
     g_global_boundary = true;
     m_local_boundary = true;
+    m_sequence = GetNextSequenceNumber();
   }
 }
 
 Recorder::Recorder(llvm::StringRef pretty_func, std::string &&pretty_args)
     : m_serializer(nullptr), m_pretty_func(pretty_func),
       m_pretty_args(pretty_args), m_local_boundary(false),
-      m_result_recorded(true) {
+      m_result_recorded(true),
+      m_sequence(std::numeric_limits<unsigned>::max()) {
   if (!g_global_boundary) {
     g_global_boundary = true;
     m_local_boundary = true;
-
+    m_sequence = GetNextSequenceNumber();
     LLDB_LOG(GetLogIfAllCategoriesSet(LIBLLDB_LOG_API), "{0} ({1})",
              m_pretty_func, m_pretty_args);
   }
@@ -206,6 +222,11 @@ Recorder::~Recorder() {
   UpdateBoundary();
 }
 
+unsigned Recorder::GetSequenceNumber() const {
+  assert(m_sequence != std::numeric_limits<unsigned>::max());
+  return m_sequence;
+}
+
 void InstrumentationData::Initialize(Serializer &serializer,
                                      Registry &registry) {
   InstanceImpl().emplace(serializer, registry);
@@ -228,3 +249,5 @@ llvm::Optional<InstrumentationData> &InstrumentationData::InstanceImpl() {
 }
 
 thread_local bool lldb_private::repro::Recorder::g_global_boundary = false;
+std::atomic<unsigned> lldb_private::repro::Recorder::g_sequence;
+std::mutex lldb_private::repro::Recorder::g_mutex;

diff  --git a/lldb/unittests/Utility/ReproducerInstrumentationTest.cpp b/lldb/unittests/Utility/ReproducerInstrumentationTest.cpp
index 1ed00a77249f..e9f6fcf34e17 100644
--- a/lldb/unittests/Utility/ReproducerInstrumentationTest.cpp
+++ b/lldb/unittests/Utility/ReproducerInstrumentationTest.cpp
@@ -576,8 +576,11 @@ TEST(SerializationRountripTest, SerializeDeserializeObjectPointer) {
   std::string str;
   llvm::raw_string_ostream os(str);
 
+  unsigned sequence = 123;
+
   Serializer serializer(os);
-  serializer.SerializeAll(static_cast<unsigned>(1), static_cast<unsigned>(2));
+  serializer.SerializeAll(sequence, static_cast<unsigned>(1));
+  serializer.SerializeAll(sequence, static_cast<unsigned>(2));
   serializer.SerializeAll(&foo, &bar);
 
   llvm::StringRef buffer(os.str());
@@ -597,8 +600,11 @@ TEST(SerializationRountripTest, SerializeDeserializeObjectReference) {
   std::string str;
   llvm::raw_string_ostream os(str);
 
+  unsigned sequence = 123;
+
   Serializer serializer(os);
-  serializer.SerializeAll(static_cast<unsigned>(1), static_cast<unsigned>(2));
+  serializer.SerializeAll(sequence, static_cast<unsigned>(1));
+  serializer.SerializeAll(sequence, static_cast<unsigned>(2));
   serializer.SerializeAll(foo, bar);
 
   llvm::StringRef buffer(os.str());
@@ -1114,3 +1120,48 @@ TEST(PassiveReplayTest, InstrumentedBarPtr) {
     bar.Validate();
   }
 }
+
+TEST(RecordReplayTest, ValidSequence) {
+  std::string str;
+  llvm::raw_string_ostream os(str);
+
+  {
+    auto data = TestInstrumentationDataRAII::GetRecordingData(os);
+
+    unsigned sequence = 1;
+    int (*f)() = &lldb_private::repro::invoke<int (*)()>::method<
+        InstrumentedFoo::F>::record;
+    unsigned id = g_registry->GetID(uintptr_t(f));
+    g_serializer->SerializeAll(sequence, id);
+
+    unsigned result = 0;
+    g_serializer->SerializeAll(sequence, result);
+  }
+
+  TestingRegistry registry;
+  Deserializer deserializer(os.str());
+  registry.Replay(deserializer);
+}
+
+TEST(RecordReplayTest, InvalidSequence) {
+  std::string str;
+  llvm::raw_string_ostream os(str);
+
+  {
+    auto data = TestInstrumentationDataRAII::GetRecordingData(os);
+
+    unsigned sequence = 1;
+    int (*f)() = &lldb_private::repro::invoke<int (*)()>::method<
+        InstrumentedFoo::F>::record;
+    unsigned id = g_registry->GetID(uintptr_t(f));
+    g_serializer->SerializeAll(sequence, id);
+
+    unsigned result = 0;
+    unsigned invalid_sequence = 2;
+    g_serializer->SerializeAll(invalid_sequence, result);
+  }
+
+  TestingRegistry registry;
+  Deserializer deserializer(os.str());
+  EXPECT_DEATH(registry.Replay(deserializer), "");
+}


        


More information about the llvm-branch-commits mailing list