[Lldb-commits] [lldb] [lldb] Adding a new Binding helper for JSONTransport. (PR #159160)
John Harrison via lldb-commits
lldb-commits at lists.llvm.org
Tue Sep 16 17:52:27 PDT 2025
https://github.com/ashgti updated https://github.com/llvm/llvm-project/pull/159160
>From b91cad6dc7ea4c40c61a0cfe9bb45be03de32254 Mon Sep 17 00:00:00 2001
From: John Harrison <harjohn at google.com>
Date: Wed, 10 Sep 2025 10:42:56 -0700
Subject: [PATCH] [lldb] Adding a new Binding helper for JSONTransport.
This adds a new Binding helper class to allow mapping of incoming and outgoing requests / events to specific handlers.
This should make it easier to create new protocol implementations and allow us to create a relay in the lldb-mcp binary.
---
lldb/include/lldb/Host/JSONTransport.h | 387 ++++++++++++++++--
lldb/include/lldb/Protocol/MCP/Protocol.h | 8 +
lldb/include/lldb/Protocol/MCP/Server.h | 73 ++--
lldb/include/lldb/Protocol/MCP/Transport.h | 77 +++-
lldb/source/Host/common/JSONTransport.cpp | 10 +
.../Protocol/MCP/ProtocolServerMCP.cpp | 42 +-
.../Plugins/Protocol/MCP/ProtocolServerMCP.h | 20 +-
lldb/source/Protocol/MCP/Server.cpp | 211 +++-------
lldb/tools/lldb-dap/DAP.h | 6 +-
lldb/tools/lldb-dap/Protocol/ProtocolBase.h | 6 +-
lldb/tools/lldb-dap/Transport.h | 6 +-
lldb/unittests/DAP/DAPTest.cpp | 20 +-
lldb/unittests/DAP/Handler/DisconnectTest.cpp | 4 +-
lldb/unittests/DAP/TestBase.cpp | 42 +-
lldb/unittests/DAP/TestBase.h | 122 +++---
lldb/unittests/Host/JSONTransportTest.cpp | 338 +++++++++++----
.../Protocol/ProtocolMCPServerTest.cpp | 280 +++++++------
.../Host/JSONTransportTestUtilities.h | 96 ++++-
18 files changed, 1170 insertions(+), 578 deletions(-)
diff --git a/lldb/include/lldb/Host/JSONTransport.h b/lldb/include/lldb/Host/JSONTransport.h
index 210f33edace6e..da1ae43118538 100644
--- a/lldb/include/lldb/Host/JSONTransport.h
+++ b/lldb/include/lldb/Host/JSONTransport.h
@@ -18,6 +18,7 @@
#include "lldb/Utility/IOObject.h"
#include "lldb/Utility/Status.h"
#include "lldb/lldb-forward.h"
+#include "llvm/ADT/FunctionExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Error.h"
@@ -25,8 +26,13 @@
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/raw_ostream.h"
+#include <functional>
+#include <mutex>
+#include <optional>
#include <string>
#include <system_error>
+#include <type_traits>
+#include <utility>
#include <variant>
#include <vector>
@@ -50,17 +56,70 @@ class TransportUnhandledContentsError
std::string m_unhandled_contents;
};
+class InvalidParams : public llvm::ErrorInfo<InvalidParams> {
+public:
+ static char ID;
+
+ explicit InvalidParams(std::string method, std::string context)
+ : m_method(std::move(method)), m_context(std::move(context)) {}
+
+ void log(llvm::raw_ostream &OS) const override;
+ std::error_code convertToErrorCode() const override;
+
+private:
+ std::string m_method;
+ std::string m_context;
+};
+
+// Value for tracking functions that have a void param or result.
+using VoidT = std::monostate;
+
+template <typename T> using Callback = llvm::unique_function<T>;
+
+template <typename T>
+using Reply = std::conditional_t<
+ std::is_same_v<T, VoidT>, llvm::unique_function<void(llvm::Error)>,
+ llvm::unique_function<void(llvm::Expected<T>)>>;
+
+template <typename Result, typename Params>
+using OutgoingRequest = std::conditional_t<
+ std::is_same_v<Params, VoidT>,
+ llvm::unique_function<void(Reply<Result>)>,
+ llvm::unique_function<void(const Params &, Reply<Result>)>>;
+
+template <typename Params>
+using OutgoingEvent = std::conditional_t<
+ std::is_same_v<Params, VoidT>, llvm::unique_function<void()>,
+ llvm::unique_function<void(const Params &)>>;
+
+template <typename Id, typename Req>
+Req make_request(Id id, llvm::StringRef method,
+ std::optional<llvm::json::Value> params = std::nullopt);
+template <typename Req, typename Resp>
+Resp make_response(const Req &req, llvm::Error error);
+template <typename Req, typename Resp>
+Resp make_response(const Req &req, llvm::json::Value result);
+template <typename Evt>
+Evt make_event(llvm::StringRef method,
+ std::optional<llvm::json::Value> params = std::nullopt);
+template <typename Resp>
+llvm::Expected<llvm::json::Value> get_result(const Resp &resp);
+template <typename Id, typename T> Id get_id(const T &);
+template <typename T> llvm::StringRef get_method(const T &);
+template <typename T> llvm::json::Value get_params(const T &);
+
/// A transport is responsible for maintaining the connection to a client
/// application, and reading/writing structured messages to it.
///
/// Transports have limited thread safety requirements:
/// - Messages will not be sent concurrently.
/// - Messages MAY be sent while Run() is reading, or its callback is active.
-template <typename Req, typename Resp, typename Evt> class Transport {
+template <typename Id, typename Req, typename Resp, typename Evt>
+class JSONTransport {
public:
using Message = std::variant<Req, Resp, Evt>;
- virtual ~Transport() = default;
+ virtual ~JSONTransport() = default;
/// Sends an event, a message that does not require a response.
virtual llvm::Error Send(const Evt &) = 0;
@@ -90,8 +149,6 @@ template <typename Req, typename Resp, typename Evt> class Transport {
virtual void OnClosed() = 0;
};
- using MessageHandlerSP = std::shared_ptr<MessageHandler>;
-
/// RegisterMessageHandler registers the Transport with the given MainLoop and
/// handles any incoming messages using the given MessageHandler.
///
@@ -100,22 +157,302 @@ template <typename Req, typename Resp, typename Evt> class Transport {
virtual llvm::Expected<MainLoop::ReadHandleUP>
RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) = 0;
- // FIXME: Refactor mcp::Server to not directly access log on the transport.
- // protected:
+protected:
template <typename... Ts> inline auto Logv(const char *Fmt, Ts &&...Vals) {
Log(llvm::formatv(Fmt, std::forward<Ts>(Vals)...).str());
}
virtual void Log(llvm::StringRef message) = 0;
+
+ /// Function object to reply to a call.
+ /// Each instance must be called exactly once, otherwise:
+ /// - the bug is logged, and (in debug mode) an assert will fire
+ /// - if there was no reply, an error reply is sent
+ /// - if there were multiple replies, only the first is sent
+ class ReplyOnce {
+ std::atomic<bool> replied = {false};
+ const Req req;
+ JSONTransport *transport; // Null when moved-from.
+ JSONTransport::MessageHandler *handler; // Null when moved-from.
+
+ public:
+ ReplyOnce(const Req req, JSONTransport *transport,
+ JSONTransport::MessageHandler *handler)
+ : req(req), transport(transport), handler(handler) {
+ assert(handler);
+ }
+ ReplyOnce(ReplyOnce &&other)
+ : replied(other.replied.load()), req(other.req),
+ transport(other.transport), handler(other.handler) {
+ other.transport = nullptr;
+ other.handler = nullptr;
+ }
+ ReplyOnce &operator=(ReplyOnce &&) = delete;
+ ReplyOnce(const ReplyOnce &) = delete;
+ ReplyOnce &operator=(const ReplyOnce &) = delete;
+
+ ~ReplyOnce() {
+ if (transport && handler && !replied) {
+ assert(false && "must reply to all calls!");
+ (*this)(make_response<Req, Resp>(
+ req, llvm::createStringError("failed to reply")));
+ }
+ }
+
+ void operator()(const Resp &resp) {
+ assert(transport && handler && "moved-from!");
+ if (replied.exchange(true)) {
+ assert(false && "must reply to each call only once!");
+ return;
+ }
+
+ if (llvm::Error error = transport->Send(resp))
+ handler->OnError(std::move(error));
+ }
+ };
+
+public:
+ class Binder;
+ using BinderUP = std::unique_ptr<Binder>;
+
+ /// Binder collects a table of functions that handle calls.
+ ///
+ /// The wrapper takes care of parsing/serializing responses.
+ class Binder : public JSONTransport::MessageHandler {
+ public:
+ explicit Binder(JSONTransport &transport)
+ : m_transport(transport), m_seq(0) {}
+
+ Binder(const Binder &) = delete;
+ Binder &operator=(const Binder &) = delete;
+
+ /// Bind a handler on transport disconnect; \p fn and \p args are stored by value.
+ template <typename Fn, typename... Args>
+ void disconnected(Fn &&fn, Args &&...args) {
+ m_disconnect_handler = [fn = std::forward<Fn>(fn), args...]() mutable {
+ std::invoke(fn, std::forward<Args>(args)...);
+ };
+ }
+
+ /// Bind a handler on transport errors; \p fn and \p args are stored by value.
+ template <typename Fn, typename... Args>
+ void error(Fn &&fn, Args &&...args) {
+ m_error_handler = [fn = std::forward<Fn>(fn), args...](llvm::Error error) mutable {
+ std::invoke(fn, std::forward<Args>(args)...,
+ std::move(error));
+ };
+ }
+
+ template <typename T>
+ static llvm::Expected<T> parse(const llvm::json::Value &raw,
+ llvm::StringRef method) {
+ T result;
+ llvm::json::Path::Root root;
+ if (!fromJSON(raw, result, root)) {
+ // Dump the relevant parts of the broken message.
+ std::string context;
+ llvm::raw_string_ostream OS(context);
+ root.printErrorContext(raw, OS);
+ return llvm::make_error<InvalidParams>(method.str(), context);
+ }
+ return std::move(result);
+ }
+
+ /// Bind a handler for a request.
+ /// e.g. `bind<PeekResult, PeekParams>("peek", &ThisModule::peek, this);`.
+ /// Handler should be e.g. `Expected<PeekResult> peek(const PeekParams&);`
+ /// PeekParams must be JSON parsable and PeekResult must be serializable.
+ template <typename Result, typename Params, typename Fn, typename... Args>
+ void bind(llvm::StringLiteral method, Fn &&fn, Args &&...args) {
+ assert(m_request_handlers.find(method) == m_request_handlers.end() &&
+ "request already bound");
+ if constexpr (std::is_void_v<Params> || std::is_same_v<VoidT, Params>) {
+ m_request_handlers[method] =
+ [fn,
+ args...](const Req &req,
+ llvm::unique_function<void(const Resp &)> reply) mutable {
+ llvm::Expected<Result> result = std::invoke(
+ std::forward<Fn>(fn), std::forward<Args>(args)...);
+ if (!result)
+ return reply(make_response<Req, Resp>(req, result.takeError()));
+ reply(make_response<Req, Resp>(req, toJSON(*result)));
+ };
+ } else {
+ m_request_handlers[method] =
+ [method, fn,
+ args...](const Req &req,
+ llvm::unique_function<void(const Resp &)> reply) mutable {
+ llvm::Expected<Params> params =
+ parse<Params>(get_params<Req>(req), method);
+ if (!params)
+ return reply(make_response<Req, Resp>(req, params.takeError()));
+
+ llvm::Expected<Result> result = std::invoke(
+ std::forward<Fn>(fn), std::forward<Args>(args)..., *params);
+ if (!result)
+ return reply(make_response<Req, Resp>(req, result.takeError()));
+
+ reply(make_response<Req, Resp>(req, toJSON(*result)));
+ };
+ }
+ }
+
+ /// Bind a handler for an event.
+ /// e.g. `bind<PeekParams>("peek", &ThisModule::peek, this);`
+ /// Handler should be e.g. `void peek(const PeekParams&);`
+ /// PeekParams must be JSON parsable.
+ template <typename Params, typename Fn, typename... Args>
+ void bind(llvm::StringLiteral method, Fn &&fn, Args &&...args) {
+ assert(m_event_handlers.find(method) == m_event_handlers.end() &&
+ "event already bound");
+ if constexpr (std::is_void_v<Params> || std::is_same_v<VoidT, Params>) {
+ m_event_handlers[method] = [fn, args...](const Evt &) mutable {
+ std::invoke(std::forward<Fn>(fn), std::forward<Args>(args)...);
+ };
+ } else {
+ m_event_handlers[method] = [this, method, fn,
+ args...](const Evt &evt) mutable {
+ llvm::Expected<Params> params =
+ parse<Params>(get_params<Evt>(evt), method);
+ if (!params)
+ return OnError(params.takeError());
+ std::invoke(std::forward<Fn>(fn), std::forward<Args>(args)...,
+ *params);
+ };
+ }
+ }
+
+ /// Bind a function object to be used for outgoing requests.
+ /// e.g. `OutgoingRequest<Result, Params> Edit = bind<Result, Params>("edit");`
+ /// Params must be JSON-serializable, Result must be parsable.
+ template <typename Result, typename Params>
+ OutgoingRequest<Result, Params> bind(llvm::StringLiteral method) {
+ if constexpr (std::is_void_v<Params> || std::is_same_v<VoidT, Params>) {
+ return [this, method](Reply<Result> fn) {
+ std::scoped_lock<std::recursive_mutex> guard(m_mutex);
+ Id id = ++m_seq;
+ Req req = make_request<Id, Req>(id, method, std::nullopt);
+ m_pending_responses[id] = [fn = std::move(fn),
+ method](const Resp &resp) mutable {
+ llvm::Expected<llvm::json::Value> result = get_result<Resp>(resp);
+ if (!result)
+ return fn(result.takeError());
+ fn(parse<Result>(*result, method));
+ };
+ if (llvm::Error error = m_transport.Send(req))
+ OnError(std::move(error));
+ };
+ } else {
+ return [this, method](const Params &params, Reply<Result> fn) {
+ std::scoped_lock<std::recursive_mutex> guard(m_mutex);
+ Id id = ++m_seq;
+ Req req =
+ make_request<Id, Req>(id, method, llvm::json::Value(params));
+ m_pending_responses[id] = [fn = std::move(fn),
+ method](const Resp &resp) mutable {
+ llvm::Expected<llvm::json::Value> result = get_result<Resp>(resp);
+ if (llvm::Error err = result.takeError())
+ return fn(std::move(err));
+ fn(parse<Result>(*result, method));
+ };
+ if (llvm::Error error = m_transport.Send(req))
+ OnError(std::move(error));
+ };
+ }
+ }
+
+ /// Bind a function object to be used for outgoing events.
+ /// e.g. `OutgoingEvent<LogParams> Log = bind<LogParams>("log");`
+ /// LogParams must be JSON-serializable.
+ template <typename Params>
+ OutgoingEvent<Params> bind(llvm::StringLiteral method) {
+ if constexpr (std::is_void_v<Params> || std::is_same_v<VoidT, Params>) {
+ return [this, method]() {
+ if (llvm::Error error =
+ m_transport.Send(make_event<Evt>(method, std::nullopt)))
+ OnError(std::move(error));
+ };
+ } else {
+ return [this, method](const Params &params) {
+ if (llvm::Error error =
+ m_transport.Send(make_event<Evt>(method, toJSON(params))))
+ OnError(std::move(error));
+ };
+ }
+ }
+
+ void Received(const Evt &evt) override {
+ std::scoped_lock<std::recursive_mutex> guard(m_mutex);
+ auto it = m_event_handlers.find(get_method<Evt>(evt));
+ if (it == m_event_handlers.end()) {
+ OnError(llvm::createStringError(
+ llvm::formatv("no handler for event {0}", toJSON(evt))));
+ return;
+ }
+ it->second(evt);
+ }
+
+ void Received(const Req &req) override {
+ ReplyOnce reply(req, &m_transport, this);
+
+ std::scoped_lock<std::recursive_mutex> guard(m_mutex);
+ auto it = m_request_handlers.find(get_method<Req>(req));
+ if (it == m_request_handlers.end()) {
+ reply(make_response<Req, Resp>(
+ req, llvm::createStringError("method not found")));
+ return;
+ }
+
+ it->second(req, std::move(reply));
+ }
+
+ void Received(const Resp &resp) override {
+ std::scoped_lock<std::recursive_mutex> guard(m_mutex);
+ auto it = m_pending_responses.find(get_id<Id, Resp>(resp));
+ if (it == m_pending_responses.end()) {
+ OnError(llvm::createStringError(
+ llvm::formatv("no pending request for {0}", toJSON(resp))));
+ return;
+ }
+
+ it->second(resp);
+ m_pending_responses.erase(it);
+ }
+
+ void OnError(llvm::Error err) override {
+ std::scoped_lock<std::recursive_mutex> guard(m_mutex);
+ if (!m_error_handler) // An unhandled llvm::Error must still be consumed.
+ return llvm::consumeError(std::move(err));
+ m_error_handler(std::move(err));
+ }
+ void OnClosed() override {
+ std::scoped_lock<std::recursive_mutex> guard(m_mutex);
+ if (m_disconnect_handler)
+ m_disconnect_handler();
+ }
+
+ private:
+ std::recursive_mutex m_mutex;
+ JSONTransport &m_transport;
+ Id m_seq;
+ std::map<Id, Callback<void(const Resp &)>> m_pending_responses;
+ llvm::StringMap<Callback<void(const Req &, Callback<void(const Resp &)>)>>
+ m_request_handlers;
+ llvm::StringMap<Callback<void(const Evt &)>> m_event_handlers;
+ Callback<void()> m_disconnect_handler;
+ Callback<void(llvm::Error)> m_error_handler;
+ };
};
-/// A JSONTransport will encode and decode messages using JSON.
-template <typename Req, typename Resp, typename Evt>
-class JSONTransport : public Transport<Req, Resp, Evt> {
+/// A IOTransport will encode and decode messages using an IOObject like a
+/// file or a socket.
+template <typename Id, typename Req, typename Resp, typename Evt>
+class IOTransport : public JSONTransport<Id, Req, Resp, Evt> {
public:
- using Transport<Req, Resp, Evt>::Transport;
- using MessageHandler = typename Transport<Req, Resp, Evt>::MessageHandler;
+ using Message = typename JSONTransport<Id, Req, Resp, Evt>::Message;
+ using MessageHandler =
+ typename JSONTransport<Id, Req, Resp, Evt>::MessageHandler;
- JSONTransport(lldb::IOObjectSP in, lldb::IOObjectSP out)
+ IOTransport(lldb::IOObjectSP in, lldb::IOObjectSP out)
: m_in(in), m_out(out) {}
llvm::Error Send(const Evt &evt) override { return Write(evt); }
@@ -127,7 +464,7 @@ class JSONTransport : public Transport<Req, Resp, Evt> {
Status status;
MainLoop::ReadHandleUP read_handle = loop.RegisterReadObject(
m_in,
- std::bind(&JSONTransport::OnRead, this, std::placeholders::_1,
+ std::bind(&IOTransport::OnRead, this, std::placeholders::_1,
std::ref(handler)),
status);
if (status.Fail()) {
@@ -140,7 +477,7 @@ class JSONTransport : public Transport<Req, Resp, Evt> {
/// detail.
static constexpr size_t kReadBufferSize = 1024;
- // FIXME: Write should be protected.
+protected:
llvm::Error Write(const llvm::json::Value &message) {
this->Logv("<-- {0}", message);
std::string output = Encode(message);
@@ -148,7 +485,6 @@ class JSONTransport : public Transport<Req, Resp, Evt> {
return m_out->Write(output.data(), bytes_written).takeError();
}
-protected:
virtual llvm::Expected<std::vector<std::string>> Parse() = 0;
virtual std::string Encode(const llvm::json::Value &message) = 0;
@@ -175,9 +511,8 @@ class JSONTransport : public Transport<Req, Resp, Evt> {
}
for (const std::string &raw_message : *raw_messages) {
- llvm::Expected<typename Transport<Req, Resp, Evt>::Message> message =
- llvm::json::parse<typename Transport<Req, Resp, Evt>::Message>(
- raw_message);
+ llvm::Expected<Message> message =
+ llvm::json::parse<Message>(raw_message);
if (!message) {
handler.OnError(message.takeError());
return;
@@ -202,10 +537,10 @@ class JSONTransport : public Transport<Req, Resp, Evt> {
};
/// A transport class for JSON with a HTTP header.
-template <typename Req, typename Resp, typename Evt>
-class HTTPDelimitedJSONTransport : public JSONTransport<Req, Resp, Evt> {
+template <typename Id, typename Req, typename Resp, typename Evt>
+class HTTPDelimitedJSONTransport : public IOTransport<Id, Req, Resp, Evt> {
public:
- using JSONTransport<Req, Resp, Evt>::JSONTransport;
+ using IOTransport<Id, Req, Resp, Evt>::IOTransport;
protected:
/// Encodes messages based on
@@ -231,8 +566,8 @@ class HTTPDelimitedJSONTransport : public JSONTransport<Req, Resp, Evt> {
for (const llvm::StringRef &header :
llvm::split(headers, kHeaderSeparator)) {
auto [key, value] = header.split(kHeaderFieldSeparator);
- // 'Content-Length' is the only meaningful key at the moment. Others are
- // ignored.
+ // 'Content-Length' is the only meaningful key at the moment. Others
+ // are ignored.
if (!key.equals_insensitive(kHeaderContentLength))
continue;
@@ -269,10 +604,10 @@ class HTTPDelimitedJSONTransport : public JSONTransport<Req, Resp, Evt> {
};
/// A transport class for JSON RPC.
-template <typename Req, typename Resp, typename Evt>
-class JSONRPCTransport : public JSONTransport<Req, Resp, Evt> {
+template <typename Id, typename Req, typename Resp, typename Evt>
+class JSONRPCTransport : public IOTransport<Id, Req, Resp, Evt> {
public:
- using JSONTransport<Req, Resp, Evt>::JSONTransport;
+ using IOTransport<Id, Req, Resp, Evt>::IOTransport;
protected:
std::string Encode(const llvm::json::Value &message) override {
diff --git a/lldb/include/lldb/Protocol/MCP/Protocol.h b/lldb/include/lldb/Protocol/MCP/Protocol.h
index 6e1ffcbe1f3e3..1e0816110b80a 100644
--- a/lldb/include/lldb/Protocol/MCP/Protocol.h
+++ b/lldb/include/lldb/Protocol/MCP/Protocol.h
@@ -14,6 +14,7 @@
#ifndef LLDB_PROTOCOL_MCP_PROTOCOL_H
#define LLDB_PROTOCOL_MCP_PROTOCOL_H
+#include "llvm/ADT/StringRef.h"
#include "llvm/Support/JSON.h"
#include <optional>
#include <string>
@@ -324,4 +325,11 @@ bool fromJSON(const llvm::json::Value &, CallToolResult &, llvm::json::Path);
} // namespace lldb_protocol::mcp
+namespace llvm::json {
+inline Value toJSON(const lldb_protocol::mcp::Void &) { return Object(); }
+inline bool fromJSON(const Value &, lldb_protocol::mcp::Void &, Path) {
+ return true;
+}
+} // namespace llvm::json
+
#endif
diff --git a/lldb/include/lldb/Protocol/MCP/Server.h b/lldb/include/lldb/Protocol/MCP/Server.h
index 1f916ae525b5c..df2a4810ce620 100644
--- a/lldb/include/lldb/Protocol/MCP/Server.h
+++ b/lldb/include/lldb/Protocol/MCP/Server.h
@@ -9,7 +9,6 @@
#ifndef LLDB_PROTOCOL_MCP_SERVER_H
#define LLDB_PROTOCOL_MCP_SERVER_H
-#include "lldb/Host/JSONTransport.h"
#include "lldb/Host/MainLoop.h"
#include "lldb/Protocol/MCP/Protocol.h"
#include "lldb/Protocol/MCP/Resource.h"
@@ -19,74 +18,66 @@
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Error.h"
+#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/Signals.h"
-#include <functional>
#include <memory>
#include <string>
#include <vector>
namespace lldb_protocol::mcp {
-class Server : public MCPTransport::MessageHandler {
+class Server {
+
+ using MCPTransportUP = std::unique_ptr<lldb_protocol::mcp::MCPTransport>;
+
+ using ReadHandleUP = lldb_private::MainLoop::ReadHandleUP;
+
public:
- Server(std::string name, std::string version,
- std::unique_ptr<MCPTransport> transport_up,
- lldb_private::MainLoop &loop);
+ Server(std::string name, std::string version, LogCallback log_callback = {});
~Server() = default;
- using NotificationHandler = std::function<void(const Notification &)>;
-
void AddTool(std::unique_ptr<Tool> tool);
void AddResourceProvider(std::unique_ptr<ResourceProvider> resource_provider);
- void AddNotificationHandler(llvm::StringRef method,
- NotificationHandler handler);
- llvm::Error Run();
+ llvm::Error Accept(lldb_private::MainLoop &, MCPTransportUP);
protected:
- ServerCapabilities GetCapabilities();
-
- using RequestHandler =
- std::function<llvm::Expected<Response>(const Request &)>;
+ MCPTransport::BinderUP Bind(MCPTransport &);
- void AddRequestHandlers();
-
- void AddRequestHandler(llvm::StringRef method, RequestHandler handler);
-
- llvm::Expected<std::optional<Message>> HandleData(llvm::StringRef data);
-
- llvm::Expected<Response> Handle(const Request &request);
- void Handle(const Notification &notification);
-
- llvm::Expected<Response> InitializeHandler(const Request &);
+ ServerCapabilities GetCapabilities();
- llvm::Expected<Response> ToolsListHandler(const Request &);
- llvm::Expected<Response> ToolsCallHandler(const Request &);
+ llvm::Expected<InitializeResult> InitializeHandler(const InitializeParams &);
- llvm::Expected<Response> ResourcesListHandler(const Request &);
- llvm::Expected<Response> ResourcesReadHandler(const Request &);
+ llvm::Expected<ListToolsResult> ToolsListHandler();
+ llvm::Expected<CallToolResult> ToolsCallHandler(const CallToolParams &);
- void Received(const Request &) override;
- void Received(const Response &) override;
- void Received(const Notification &) override;
- void OnError(llvm::Error) override;
- void OnClosed() override;
+ llvm::Expected<ListResourcesResult> ResourcesListHandler();
+ llvm::Expected<ReadResourceResult>
+ ResourcesReadHandler(const ReadResourceParams &);
- void TerminateLoop();
+ template <typename... Ts> inline auto Logv(const char *Fmt, Ts &&...Vals) {
+ Log(llvm::formatv(Fmt, std::forward<Ts>(Vals)...).str());
+ }
+ void Log(llvm::StringRef message) {
+ if (m_log_callback)
+ m_log_callback(message);
+ }
private:
const std::string m_name;
const std::string m_version;
- std::unique_ptr<MCPTransport> m_transport_up;
- lldb_private::MainLoop &m_loop;
+ LogCallback m_log_callback;
+ struct Client {
+ ReadHandleUP handle;
+ MCPTransportUP transport;
+ MCPTransport::BinderUP binder;
+ };
+ std::map<MCPTransport *, Client> m_instances;
llvm::StringMap<std::unique_ptr<Tool>> m_tools;
std::vector<std::unique_ptr<ResourceProvider>> m_resource_providers;
-
- llvm::StringMap<RequestHandler> m_request_handlers;
- llvm::StringMap<NotificationHandler> m_notification_handlers;
};
class ServerInfoHandle;
@@ -120,7 +111,7 @@ class ServerInfoHandle {
ServerInfoHandle &operator=(const ServerInfoHandle &) = delete;
/// @}
- /// Remove the file.
+ /// Remove the file on disk, if one is tracked.
void Remove();
private:
diff --git a/lldb/include/lldb/Protocol/MCP/Transport.h b/lldb/include/lldb/Protocol/MCP/Transport.h
index 47c2ccfc44dfe..55b2e8fa0a7f2 100644
--- a/lldb/include/lldb/Protocol/MCP/Transport.h
+++ b/lldb/include/lldb/Protocol/MCP/Transport.h
@@ -10,22 +10,95 @@
#define LLDB_PROTOCOL_MCP_TRANSPORT_H
#include "lldb/Host/JSONTransport.h"
+#include "lldb/Protocol/MCP/MCPError.h"
#include "lldb/Protocol/MCP/Protocol.h"
#include "lldb/lldb-forward.h"
#include "llvm/ADT/FunctionExtras.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Error.h"
+
+namespace lldb_private {
+/// Specializations of the JSONTransport protocol functions for MCP.
+/// @{
+template <>
+inline lldb_protocol::mcp::Request
+make_request(int64_t id, llvm::StringRef method,
+ std::optional<llvm::json::Value> params) {
+ return lldb_protocol::mcp::Request{id, method.str(), params};
+}
+template <>
+inline lldb_protocol::mcp::Response
+make_response(const lldb_protocol::mcp::Request &req, llvm::Error error) {
+ lldb_protocol::mcp::Error protocol_error;
+ llvm::handleAllErrors(
+ std::move(error),
+ [&](const lldb_protocol::mcp::MCPError &err) {
+ protocol_error = err.toProtocolError();
+ },
+ [&](const llvm::ErrorInfoBase &err) {
+ protocol_error.code = lldb_protocol::mcp::MCPError::kInternalError;
+ protocol_error.message = err.message();
+ });
+
+ return lldb_protocol::mcp::Response{req.id, std::move(protocol_error)};
+}
+template <>
+inline lldb_protocol::mcp::Response
+make_response(const lldb_protocol::mcp::Request &req,
+ llvm::json::Value result) {
+ return lldb_protocol::mcp::Response{req.id, std::move(result)};
+}
+template <>
+inline lldb_protocol::mcp::Notification
+make_event(llvm::StringRef method, std::optional<llvm::json::Value> params) {
+ return lldb_protocol::mcp::Notification{method.str(), params};
+}
+template <>
+inline llvm::Expected<llvm::json::Value>
+get_result(const lldb_protocol::mcp::Response &resp) {
+ if (const lldb_protocol::mcp::Error *error =
+ std::get_if<lldb_protocol::mcp::Error>(&resp.result))
+ return llvm::make_error<lldb_protocol::mcp::MCPError>(error->message,
+ error->code);
+ return std::get<llvm::json::Value>(resp.result);
+}
+template <> inline int64_t get_id(const lldb_protocol::mcp::Response &resp) {
+ return std::get<int64_t>(resp.id);
+}
+template <>
+inline llvm::StringRef get_method(const lldb_protocol::mcp::Request &req) {
+ return req.method;
+}
+template <>
+inline llvm::StringRef get_method(const lldb_protocol::mcp::Notification &evt) {
+ return evt.method;
+}
+template <>
+inline llvm::json::Value get_params(const lldb_protocol::mcp::Request &req) {
+ return req.params;
+}
+template <>
+inline llvm::json::Value
+get_params(const lldb_protocol::mcp::Notification &evt) {
+ return evt.params;
+}
+/// @}
+
+} // end namespace lldb_private
namespace lldb_protocol::mcp {
/// Generic transport that uses the MCP protocol.
-using MCPTransport = lldb_private::Transport<Request, Response, Notification>;
+using MCPTransport =
+ lldb_private::JSONTransport<int64_t, Request, Response, Notification>;
/// Generic logging callback, to allow the MCP server / client / transport layer
/// to be independent of the lldb log implementation.
using LogCallback = llvm::unique_function<void(llvm::StringRef message)>;
class Transport final
- : public lldb_private::JSONRPCTransport<Request, Response, Notification> {
+ : public lldb_private::JSONRPCTransport<int64_t, Request, Response,
+ Notification> {
public:
Transport(lldb::IOObjectSP in, lldb::IOObjectSP out,
LogCallback log_callback = {});
diff --git a/lldb/source/Host/common/JSONTransport.cpp b/lldb/source/Host/common/JSONTransport.cpp
index c4b42eafc85d3..f809ef478c8f7 100644
--- a/lldb/source/Host/common/JSONTransport.cpp
+++ b/lldb/source/Host/common/JSONTransport.cpp
@@ -30,3 +30,13 @@ void TransportUnhandledContentsError::log(llvm::raw_ostream &OS) const {
std::error_code TransportUnhandledContentsError::convertToErrorCode() const {
return std::make_error_code(std::errc::bad_message);
}
+
+char InvalidParams::ID;
+
+void InvalidParams::log(llvm::raw_ostream &OS) const {
+ OS << "invalid parameters for method '" << m_method << "': '" << m_context
+ << "'";
+}
+std::error_code InvalidParams::convertToErrorCode() const {
+ return std::make_error_code(std::errc::invalid_argument);
+}
diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp
index d3af3cf25c4a1..46a7a96cc5fc0 100644
--- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp
+++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp
@@ -52,11 +52,6 @@ llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() {
}
void ProtocolServerMCP::Extend(lldb_protocol::mcp::Server &server) const {
- server.AddNotificationHandler("notifications/initialized",
- [](const lldb_protocol::mcp::Notification &) {
- LLDB_LOG(GetLog(LLDBLog::Host),
- "MCP initialization complete");
- });
server.AddTool(
std::make_unique<CommandTool>("command", "Run an lldb command."));
server.AddTool(std::make_unique<DebuggerListTool>(
@@ -66,7 +61,7 @@ void ProtocolServerMCP::Extend(lldb_protocol::mcp::Server &server) const {
void ProtocolServerMCP::AcceptCallback(std::unique_ptr<Socket> socket) {
Log *log = GetLog(LLDBLog::Host);
- std::string client_name = llvm::formatv("client_{0}", m_instances.size() + 1);
+ std::string client_name = llvm::formatv("client_{0}", m_client_count++);
LLDB_LOG(log, "New MCP client connected: {0}", client_name);
lldb::IOObjectSP io_sp = std::move(socket);
@@ -74,16 +69,9 @@ void ProtocolServerMCP::AcceptCallback(std::unique_ptr<Socket> socket) {
io_sp, io_sp, [client_name](llvm::StringRef message) {
LLDB_LOG(GetLog(LLDBLog::Host), "{0}: {1}", client_name, message);
});
- auto instance_up = std::make_unique<lldb_protocol::mcp::Server>(
- std::string(kName), std::string(kVersion), std::move(transport_up),
- m_loop);
- Extend(*instance_up);
- llvm::Error error = instance_up->Run();
- if (error) {
- LLDB_LOG_ERROR(log, std::move(error), "Failed to run MCP server: {0}");
- return;
- }
- m_instances.push_back(std::move(instance_up));
+
+ if (auto error = m_server->Accept(m_loop, std::move(transport_up)))
+ LLDB_LOG_ERROR(log, std::move(error), "{0}:");
}
llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
@@ -114,13 +102,20 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
llvm::join(m_listener->GetListeningConnectionURI(), ", ");
ServerInfo info{listening_uris[0]};
- llvm::Expected<ServerInfoHandle> handle = ServerInfo::Write(info);
- if (!handle)
- return handle.takeError();
+ llvm::Expected<ServerInfoHandle> server_info_handle = ServerInfo::Write(info);
+ if (!server_info_handle)
+ return server_info_handle.takeError();
+
+ m_client_count = 0;
+ m_server = std::make_unique<lldb_protocol::mcp::Server>(
+ std::string(kName), std::string(kVersion), [](StringRef message) {
+ LLDB_LOG(GetLog(LLDBLog::Host), "MCP Server: {0}", message);
+ });
+ Extend(*m_server);
m_running = true;
- m_server_info_handle = std::move(*handle);
- m_listen_handlers = std::move(*handles);
+ m_server_info_handle = std::move(*server_info_handle);
+ m_accept_handles = std::move(*handles);
m_loop_thread = std::thread([=] {
llvm::set_thread_name("protocol-server.mcp");
m_loop.Run();
@@ -145,9 +140,10 @@ llvm::Error ProtocolServerMCP::Stop() {
if (m_loop_thread.joinable())
m_loop_thread.join();
+ m_accept_handles.clear();
+
+ m_server.reset(nullptr);
m_server_info_handle.Remove();
- m_listen_handlers.clear();
- m_instances.clear();
return llvm::Error::success();
}
diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h
index 0251664a2acc4..d34b22e29765f 100644
--- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h
+++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h
@@ -12,19 +12,23 @@
#include "lldb/Core/ProtocolServer.h"
#include "lldb/Host/MainLoop.h"
#include "lldb/Host/Socket.h"
-#include "lldb/Protocol/MCP/Protocol.h"
#include "lldb/Protocol/MCP/Server.h"
#include <thread>
namespace lldb_private::mcp {
class ProtocolServerMCP : public ProtocolServer {
+
+ using ServerUP = std::unique_ptr<lldb_protocol::mcp::Server>;
+
+ using ReadHandleUP = MainLoop::ReadHandleUP;
+
public:
ProtocolServerMCP();
- virtual ~ProtocolServerMCP() override;
+ ~ProtocolServerMCP() override;
- virtual llvm::Error Start(ProtocolServer::Connection connection) override;
- virtual llvm::Error Stop() override;
+ llvm::Error Start(ProtocolServer::Connection connection) override;
+ llvm::Error Stop() override;
static void Initialize();
static void Terminate();
@@ -48,16 +52,18 @@ class ProtocolServerMCP : public ProtocolServer {
bool m_running = false;
- lldb_protocol::mcp::ServerInfoHandle m_server_info_handle;
lldb_private::MainLoop m_loop;
std::thread m_loop_thread;
+ unsigned m_client_count = 0;
std::mutex m_mutex;
std::unique_ptr<Socket> m_listener;
+ std::vector<ReadHandleUP> m_accept_handles;
- std::vector<MainLoopBase::ReadHandleUP> m_listen_handlers;
- std::vector<std::unique_ptr<lldb_protocol::mcp::Server>> m_instances;
+ ServerUP m_server;
+ lldb_protocol::mcp::ServerInfoHandle m_server_info_handle;
};
+
} // namespace lldb_private::mcp
#endif
diff --git a/lldb/source/Protocol/MCP/Server.cpp b/lldb/source/Protocol/MCP/Server.cpp
index a08874e7321af..d3c970e6b7efc 100644
--- a/lldb/source/Protocol/MCP/Server.cpp
+++ b/lldb/source/Protocol/MCP/Server.cpp
@@ -13,6 +13,7 @@
#include "lldb/Host/JSONTransport.h"
#include "lldb/Protocol/MCP/MCPError.h"
#include "lldb/Protocol/MCP/Protocol.h"
+#include "lldb/Protocol/MCP/Transport.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/JSON.h"
@@ -109,48 +110,9 @@ Expected<std::vector<ServerInfo>> ServerInfo::Load() {
return infos;
}
-Server::Server(std::string name, std::string version,
- std::unique_ptr<MCPTransport> transport_up,
- lldb_private::MainLoop &loop)
+Server::Server(std::string name, std::string version, LogCallback log_callback)
: m_name(std::move(name)), m_version(std::move(version)),
- m_transport_up(std::move(transport_up)), m_loop(loop) {
- AddRequestHandlers();
-}
-
-void Server::AddRequestHandlers() {
- AddRequestHandler("initialize", std::bind(&Server::InitializeHandler, this,
- std::placeholders::_1));
- AddRequestHandler("tools/list", std::bind(&Server::ToolsListHandler, this,
- std::placeholders::_1));
- AddRequestHandler("tools/call", std::bind(&Server::ToolsCallHandler, this,
- std::placeholders::_1));
- AddRequestHandler("resources/list", std::bind(&Server::ResourcesListHandler,
- this, std::placeholders::_1));
- AddRequestHandler("resources/read", std::bind(&Server::ResourcesReadHandler,
- this, std::placeholders::_1));
-}
-
-llvm::Expected<Response> Server::Handle(const Request &request) {
- auto it = m_request_handlers.find(request.method);
- if (it != m_request_handlers.end()) {
- llvm::Expected<Response> response = it->second(request);
- if (!response)
- return response;
- response->id = request.id;
- return *response;
- }
-
- return llvm::make_error<MCPError>(
- llvm::formatv("no handler for request: {0}", request.method).str());
-}
-
-void Server::Handle(const Notification &notification) {
- auto it = m_notification_handlers.find(notification.method);
- if (it != m_notification_handlers.end()) {
- it->second(notification);
- return;
- }
-}
+ m_log_callback(std::move(log_callback)) {}
void Server::AddTool(std::unique_ptr<Tool> tool) {
if (!tool)
@@ -165,48 +127,62 @@ void Server::AddResourceProvider(
m_resource_providers.push_back(std::move(resource_provider));
}
-void Server::AddRequestHandler(llvm::StringRef method, RequestHandler handler) {
- m_request_handlers[method] = std::move(handler);
-}
+MCPTransport::BinderUP Server::Bind(MCPTransport &transport) {
+ MCPTransport::BinderUP binder =
+ std::make_unique<MCPTransport::Binder>(transport);
+ binder->bind<InitializeResult, InitializeParams>(
+ "initialize", &Server::InitializeHandler, this);
+ binder->bind<ListToolsResult, void>("tools/list", &Server::ToolsListHandler,
+ this);
+ binder->bind<CallToolResult, CallToolParams>("tools/call",
+ &Server::ToolsCallHandler, this);
+ binder->bind<ListResourcesResult, void>("resources/list",
+ &Server::ResourcesListHandler, this);
+ binder->bind<ReadResourceResult, ReadResourceParams>(
+ "resources/read", &Server::ResourcesReadHandler, this);
+ binder->bind<void>("notifications/initialized",
+ [this]() { Log("MCP initialization complete"); });
+ return binder;
+}
+
+llvm::Error Server::Accept(MainLoop &loop, MCPTransportUP transport) {
+ MCPTransport::BinderUP binder = Bind(*transport);
+ MCPTransport *transport_ptr = transport.get();
+ binder->disconnected([this, transport_ptr]() {
+ assert(m_instances.find(transport_ptr) != m_instances.end() &&
+ "Client not found in m_instances");
+ m_instances.erase(transport_ptr);
+ });
+
+ auto handle = transport->RegisterMessageHandler(loop, *binder);
+ if (!handle)
+ return handle.takeError();
-void Server::AddNotificationHandler(llvm::StringRef method,
- NotificationHandler handler) {
- m_notification_handlers[method] = std::move(handler);
+ m_instances[transport_ptr] =
+ Client{std::move(*handle), std::move(transport), std::move(binder)};
+ return llvm::Error::success();
}
-llvm::Expected<Response> Server::InitializeHandler(const Request &request) {
- Response response;
+Expected<InitializeResult>
+Server::InitializeHandler(const InitializeParams &request) {
InitializeResult result;
result.protocolVersion = mcp::kProtocolVersion;
result.capabilities = GetCapabilities();
result.serverInfo.name = m_name;
result.serverInfo.version = m_version;
- response.result = std::move(result);
- return response;
+ return result;
}
-llvm::Expected<Response> Server::ToolsListHandler(const Request &request) {
- Response response;
-
+llvm::Expected<ListToolsResult> Server::ToolsListHandler() {
ListToolsResult result;
for (const auto &tool : m_tools)
result.tools.emplace_back(tool.second->GetDefinition());
- response.result = std::move(result);
-
- return response;
+ return result;
}
-llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) {
- Response response;
-
- if (!request.params)
- return llvm::createStringError("no tool parameters");
- CallToolParams params;
- json::Path::Root root("params");
- if (!fromJSON(request.params, params, root))
- return root.getError();
-
+llvm::Expected<CallToolResult>
+Server::ToolsCallHandler(const CallToolParams &params) {
llvm::StringRef tool_name = params.name;
if (tool_name.empty())
return llvm::createStringError("no tool name");
@@ -223,125 +199,50 @@ llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) {
if (!text_result)
return text_result.takeError();
- response.result = toJSON(*text_result);
-
- return response;
+ return text_result;
}
-llvm::Expected<Response> Server::ResourcesListHandler(const Request &request) {
- Response response;
-
+llvm::Expected<ListResourcesResult> Server::ResourcesListHandler() {
ListResourcesResult result;
for (std::unique_ptr<ResourceProvider> &resource_provider_up :
m_resource_providers)
for (const Resource &resource : resource_provider_up->GetResources())
result.resources.push_back(resource);
- response.result = std::move(result);
-
- return response;
+ return result;
}
-llvm::Expected<Response> Server::ResourcesReadHandler(const Request &request) {
- Response response;
-
- if (!request.params)
- return llvm::createStringError("no resource parameters");
-
- ReadResourceParams params;
- json::Path::Root root("params");
- if (!fromJSON(request.params, params, root))
- return root.getError();
-
- llvm::StringRef uri_str = params.uri;
+Expected<ReadResourceResult>
+Server::ResourcesReadHandler(const ReadResourceParams &params) {
+ StringRef uri_str = params.uri;
if (uri_str.empty())
- return llvm::createStringError("no resource uri");
+ return createStringError("no resource uri");
for (std::unique_ptr<ResourceProvider> &resource_provider_up :
m_resource_providers) {
- llvm::Expected<ReadResourceResult> result =
+ Expected<ReadResourceResult> result =
resource_provider_up->ReadResource(uri_str);
if (result.errorIsA<UnsupportedURI>()) {
- llvm::consumeError(result.takeError());
+ consumeError(result.takeError());
continue;
}
if (!result)
return result.takeError();
- Response response;
- response.result = std::move(*result);
- return response;
+ return *result;
}
return make_error<MCPError>(
- llvm::formatv("no resource handler for uri: {0}", uri_str).str(),
+ formatv("no resource handler for uri: {0}", uri_str).str(),
MCPError::kResourceNotFound);
}
ServerCapabilities Server::GetCapabilities() {
lldb_protocol::mcp::ServerCapabilities capabilities;
capabilities.supportsToolsList = true;
+ capabilities.supportsResourcesList = true;
// FIXME: Support sending notifications when a debugger/target are
// added/removed.
- capabilities.supportsResourcesList = false;
+ capabilities.supportsResourcesSubscribe = false;
return capabilities;
}
-
-llvm::Error Server::Run() {
- auto handle = m_transport_up->RegisterMessageHandler(m_loop, *this);
- if (!handle)
- return handle.takeError();
-
- lldb_private::Status status = m_loop.Run();
- if (status.Fail())
- return status.takeError();
-
- return llvm::Error::success();
-}
-
-void Server::Received(const Request &request) {
- auto SendResponse = [this](const Response &response) {
- if (llvm::Error error = m_transport_up->Send(response))
- m_transport_up->Log(llvm::toString(std::move(error)));
- };
-
- llvm::Expected<Response> response = Handle(request);
- if (response)
- return SendResponse(*response);
-
- lldb_protocol::mcp::Error protocol_error;
- llvm::handleAllErrors(
- response.takeError(),
- [&](const MCPError &err) { protocol_error = err.toProtocolError(); },
- [&](const llvm::ErrorInfoBase &err) {
- protocol_error.code = MCPError::kInternalError;
- protocol_error.message = err.message();
- });
- Response error_response;
- error_response.id = request.id;
- error_response.result = std::move(protocol_error);
- SendResponse(error_response);
-}
-
-void Server::Received(const Response &response) {
- m_transport_up->Log("unexpected MCP message: response");
-}
-
-void Server::Received(const Notification &notification) {
- Handle(notification);
-}
-
-void Server::OnError(llvm::Error error) {
- m_transport_up->Log(llvm::toString(std::move(error)));
- TerminateLoop();
-}
-
-void Server::OnClosed() {
- m_transport_up->Log("EOF");
- TerminateLoop();
-}
-
-void Server::TerminateLoop() {
- m_loop.AddPendingCallback(
- [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); });
-}
diff --git a/lldb/tools/lldb-dap/DAP.h b/lldb/tools/lldb-dap/DAP.h
index 71681fd4b51ed..0c921e5b72d74 100644
--- a/lldb/tools/lldb-dap/DAP.h
+++ b/lldb/tools/lldb-dap/DAP.h
@@ -79,10 +79,10 @@ enum DAPBroadcasterBits {
enum class ReplMode { Variable = 0, Command, Auto };
using DAPTransport =
- lldb_private::Transport<protocol::Request, protocol::Response,
- protocol::Event>;
+ lldb_private::JSONTransport<protocol::Id, protocol::Request,
+ protocol::Response, protocol::Event>;
-struct DAP final : private DAPTransport::MessageHandler {
+struct DAP final : public DAPTransport::MessageHandler {
/// Path to the lldb-dap binary itself.
static llvm::StringRef debug_adapter_path;
diff --git a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h
index 0a9ef538a7398..92e41b1dbf595 100644
--- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h
+++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h
@@ -30,6 +30,8 @@ namespace lldb_dap::protocol {
// MARK: Base Protocol
+using Id = int64_t;
+
/// A client or debug adapter initiated request.
struct Request {
/// Sequence number of the message (also known as message ID). The `seq` for
@@ -39,7 +41,7 @@ struct Request {
/// associate requests with their corresponding responses. For protocol
/// messages of type `request` the sequence number can be used to cancel the
/// request.
- int64_t seq;
+ Id seq;
/// The command to execute.
std::string command;
@@ -76,7 +78,7 @@ enum ResponseMessage : unsigned {
/// Response for a request.
struct Response {
/// Sequence number of the corresponding request.
- int64_t request_seq;
+ Id request_seq;
/// The command requested.
std::string command;
diff --git a/lldb/tools/lldb-dap/Transport.h b/lldb/tools/lldb-dap/Transport.h
index 4a9dd76c2303e..6462c155eb9af 100644
--- a/lldb/tools/lldb-dap/Transport.h
+++ b/lldb/tools/lldb-dap/Transport.h
@@ -24,9 +24,9 @@ namespace lldb_dap {
/// A transport class that performs the Debug Adapter Protocol communication
/// with the client.
-class Transport final
- : public lldb_private::HTTPDelimitedJSONTransport<
- protocol::Request, protocol::Response, protocol::Event> {
+class Transport final : public lldb_private::HTTPDelimitedJSONTransport<
+ protocol::Id, protocol::Request, protocol::Response,
+ protocol::Event> {
public:
Transport(llvm::StringRef client_name, lldb_dap::Log *log,
lldb::IOObjectSP input, lldb::IOObjectSP output);
diff --git a/lldb/unittests/DAP/DAPTest.cpp b/lldb/unittests/DAP/DAPTest.cpp
index 2090fe6896d6b..4fd6cd546e6fa 100644
--- a/lldb/unittests/DAP/DAPTest.cpp
+++ b/lldb/unittests/DAP/DAPTest.cpp
@@ -9,13 +9,10 @@
#include "DAP.h"
#include "Protocol/ProtocolBase.h"
#include "TestBase.h"
-#include "llvm/Testing/Support/Error.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <optional>
-using namespace llvm;
-using namespace lldb;
using namespace lldb_dap;
using namespace lldb_dap_tests;
using namespace lldb_dap::protocol;
@@ -24,18 +21,7 @@ using namespace testing;
class DAPTest : public TransportBase {};
TEST_F(DAPTest, SendProtocolMessages) {
- DAP dap{
- /*log=*/nullptr,
- /*default_repl_mode=*/ReplMode::Auto,
- /*pre_init_commands=*/{},
- /*no_lldbinit=*/false,
- /*client_name=*/"test_client",
- /*transport=*/*transport,
- /*loop=*/loop,
- };
- dap.Send(Event{/*event=*/"my-event", /*body=*/std::nullopt});
- loop.AddPendingCallback(
- [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); });
- EXPECT_CALL(client, Received(IsEvent("my-event", std::nullopt)));
- ASSERT_THAT_ERROR(dap.Loop(), llvm::Succeeded());
+ dap->Send(Event{/*event=*/"my-event", /*body=*/std::nullopt});
+ EXPECT_CALL(client, Received(IsEvent("my-event")));
+ Run();
}
diff --git a/lldb/unittests/DAP/Handler/DisconnectTest.cpp b/lldb/unittests/DAP/Handler/DisconnectTest.cpp
index c6ff1f90b01d5..88d6e9a69eca3 100644
--- a/lldb/unittests/DAP/Handler/DisconnectTest.cpp
+++ b/lldb/unittests/DAP/Handler/DisconnectTest.cpp
@@ -31,7 +31,7 @@ TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminated) {
DisconnectRequestHandler handler(*dap);
ASSERT_THAT_ERROR(handler.Run(std::nullopt), Succeeded());
EXPECT_CALL(client, Received(IsEvent("terminated", _)));
- RunOnce();
+ Run();
}
TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminateCommands) {
@@ -53,5 +53,5 @@ TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminateCommands) {
EXPECT_CALL(client, Received(Output("(lldb) script print(2)\n")));
EXPECT_CALL(client, Received(Output("Running terminateCommands:\n")));
EXPECT_CALL(client, Received(IsEvent("terminated", _)));
- RunOnce();
+ Run();
}
diff --git a/lldb/unittests/DAP/TestBase.cpp b/lldb/unittests/DAP/TestBase.cpp
index ba7baf2103799..3721e09d8b699 100644
--- a/lldb/unittests/DAP/TestBase.cpp
+++ b/lldb/unittests/DAP/TestBase.cpp
@@ -32,23 +32,9 @@ using lldb_private::FileSystem;
using lldb_private::MainLoop;
using lldb_private::Pipe;
-Expected<MainLoop::ReadHandleUP>
-TestTransport::RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) {
- Expected<lldb::FileUP> dummy_file = FileSystem::Instance().Open(
- FileSpec(FileSystem::DEV_NULL), File::eOpenOptionReadWrite);
- if (!dummy_file)
- return dummy_file.takeError();
- m_dummy_file = std::move(*dummy_file);
- lldb_private::Status status;
- auto handle = loop.RegisterReadObject(
- m_dummy_file, [](lldb_private::MainLoopBase &) {}, status);
- if (status.Fail())
- return status.takeError();
- return handle;
-}
+void TransportBase::SetUp() {
+ std::tie(to_client, to_server) = TestDAPTransport::createPair();
-void DAPTestBase::SetUp() {
- TransportBase::SetUp();
std::error_code EC;
log = std::make_unique<Log>("-", EC);
dap = std::make_unique<DAP>(
@@ -57,16 +43,30 @@ void DAPTestBase::SetUp() {
/*pre_init_commands=*/std::vector<std::string>(),
/*no_lldbinit=*/false,
/*client_name=*/"test_client",
- /*transport=*/*transport, /*loop=*/loop);
+ /*transport=*/*to_client, /*loop=*/loop);
+
+ auto server_handle = to_server->RegisterMessageHandler(loop, *dap.get());
+ EXPECT_THAT_EXPECTED(server_handle, Succeeded());
+ handles[0] = std::move(*server_handle);
+
+ auto client_handle = to_client->RegisterMessageHandler(loop, client);
+ EXPECT_THAT_EXPECTED(client_handle, Succeeded());
+ handles[1] = std::move(*client_handle);
}
+void TransportBase::Run() {
+ loop.AddPendingCallback(
+ [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); });
+ EXPECT_THAT_ERROR(loop.Run().takeError(), llvm::Succeeded());
+}
+
+void DAPTestBase::SetUp() { TransportBase::SetUp(); }
+
void DAPTestBase::TearDown() {
- if (core) {
+ if (core)
ASSERT_THAT_ERROR(core->discard(), Succeeded());
- }
- if (binary) {
+ if (binary)
ASSERT_THAT_ERROR(binary->discard(), Succeeded());
- }
}
void DAPTestBase::SetUpTestSuite() {
diff --git a/lldb/unittests/DAP/TestBase.h b/lldb/unittests/DAP/TestBase.h
index c19eead4e37e7..aaeab3b3d2cd9 100644
--- a/lldb/unittests/DAP/TestBase.h
+++ b/lldb/unittests/DAP/TestBase.h
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "DAP.h"
+#include "DAPLog.h"
#include "Protocol/ProtocolBase.h"
#include "TestingSupport/Host/JSONTransportTestUtilities.h"
#include "TestingSupport/SubsystemRAII.h"
@@ -14,66 +15,41 @@
#include "lldb/Host/HostInfo.h"
#include "lldb/Host/MainLoop.h"
#include "lldb/Host/MainLoopBase.h"
-#include "lldb/lldb-forward.h"
#include "llvm/ADT/StringRef.h"
-#include "llvm/Support/Error.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/JSON.h"
-#include "llvm/Testing/Support/Error.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <memory>
+#include <optional>
+
+/// Helpers for gtest printing.
+namespace lldb_dap::protocol {
+
+inline void PrintTo(const Request &req, std::ostream *os) {
+ *os << llvm::formatv("{0}", toJSON(req)).str();
+}
+
+inline void PrintTo(const Response &resp, std::ostream *os) {
+ *os << llvm::formatv("{0}", toJSON(resp)).str();
+}
+
+inline void PrintTo(const Event &evt, std::ostream *os) {
+ *os << llvm::formatv("{0}", toJSON(evt)).str();
+}
+
+inline void PrintTo(const Message &message, std::ostream *os) {
+ return std::visit([os](auto &&message) { return PrintTo(message, os); },
+ message);
+}
+
+} // namespace lldb_dap::protocol
namespace lldb_dap_tests {
-class TestTransport final
- : public lldb_private::Transport<lldb_dap::protocol::Request,
- lldb_dap::protocol::Response,
- lldb_dap::protocol::Event> {
-public:
- using Message = lldb_private::Transport<lldb_dap::protocol::Request,
- lldb_dap::protocol::Response,
- lldb_dap::protocol::Event>::Message;
-
- TestTransport(lldb_private::MainLoop &loop, MessageHandler &handler)
- : m_loop(loop), m_handler(handler) {}
-
- llvm::Error Send(const lldb_dap::protocol::Event &e) override {
- m_loop.AddPendingCallback([this, e](lldb_private::MainLoopBase &) {
- this->m_handler.Received(e);
- });
- return llvm::Error::success();
- }
-
- llvm::Error Send(const lldb_dap::protocol::Request &r) override {
- m_loop.AddPendingCallback([this, r](lldb_private::MainLoopBase &) {
- this->m_handler.Received(r);
- });
- return llvm::Error::success();
- }
-
- llvm::Error Send(const lldb_dap::protocol::Response &r) override {
- m_loop.AddPendingCallback([this, r](lldb_private::MainLoopBase &) {
- this->m_handler.Received(r);
- });
- return llvm::Error::success();
- }
-
- llvm::Expected<lldb_private::MainLoop::ReadHandleUP>
- RegisterMessageHandler(lldb_private::MainLoop &loop,
- MessageHandler &handler) override;
-
- void Log(llvm::StringRef message) override {
- log_messages.emplace_back(message);
- }
-
- std::vector<std::string> log_messages;
-
-private:
- lldb_private::MainLoop &m_loop;
- MessageHandler &m_handler;
- lldb::FileSP m_dummy_file;
-};
+using TestDAPTransport =
+ TestTransport<int64_t, lldb_dap::protocol::Request,
+ lldb_dap::protocol::Response, lldb_dap::protocol::Event>;
/// A base class for tests that need transport configured for communicating DAP
/// messages.
@@ -82,22 +58,38 @@ class TransportBase : public testing::Test {
lldb_private::SubsystemRAII<lldb_private::FileSystem, lldb_private::HostInfo>
subsystems;
lldb_private::MainLoop loop;
- std::unique_ptr<TestTransport> transport;
- MockMessageHandler<lldb_dap::protocol::Request, lldb_dap::protocol::Response,
- lldb_dap::protocol::Event>
+ lldb_private::MainLoop::ReadHandleUP handles[2];
+
+ std::unique_ptr<lldb_dap::Log> log;
+
+ std::unique_ptr<TestDAPTransport> to_client;
+ MockMessageHandler<int64_t, lldb_dap::protocol::Request,
+ lldb_dap::protocol::Response, lldb_dap::protocol::Event>
client;
- void SetUp() override {
- transport = std::make_unique<TestTransport>(loop, client);
- }
+ std::unique_ptr<TestDAPTransport> to_server;
+ std::unique_ptr<lldb_dap::DAP> dap;
+
+ void SetUp() override;
+
+ void Run();
};
/// A matcher for a DAP event.
-template <typename M1, typename M2>
+template <typename EventMatcher, typename BodyMatcher>
inline testing::Matcher<const lldb_dap::protocol::Event &>
-IsEvent(const M1 &m1, const M2 &m2) {
- return testing::AllOf(testing::Field(&lldb_dap::protocol::Event::event, m1),
- testing::Field(&lldb_dap::protocol::Event::body, m2));
+IsEvent(const EventMatcher &event_matcher, const BodyMatcher &body_matcher) {
+ return testing::AllOf(
+ testing::Field(&lldb_dap::protocol::Event::event, event_matcher),
+ testing::Field(&lldb_dap::protocol::Event::body, body_matcher));
+}
+
+template <typename EventMatcher>
+inline testing::Matcher<const lldb_dap::protocol::Event &>
+IsEvent(const EventMatcher &event_matcher) {
+ return testing::AllOf(
+ testing::Field(&lldb_dap::protocol::Event::event, event_matcher),
+ testing::Field(&lldb_dap::protocol::Event::body, std::nullopt));
}
/// Matches an "output" event.
@@ -110,8 +102,6 @@ inline auto Output(llvm::StringRef o, llvm::StringRef cat = "console") {
/// A base class for tests that interact with a `lldb_dap::DAP` instance.
class DAPTestBase : public TransportBase {
protected:
- std::unique_ptr<lldb_dap::Log> log;
- std::unique_ptr<lldb_dap::DAP> dap;
std::optional<llvm::sys::fs::TempFile> core;
std::optional<llvm::sys::fs::TempFile> binary;
@@ -126,12 +116,6 @@ class DAPTestBase : public TransportBase {
bool GetDebuggerSupportsTarget(llvm::StringRef platform);
void CreateDebugger();
void LoadCore();
-
- void RunOnce() {
- loop.AddPendingCallback(
- [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); });
- ASSERT_THAT_ERROR(dap->Loop(), llvm::Succeeded());
- }
};
} // namespace lldb_dap_tests
diff --git a/lldb/unittests/Host/JSONTransportTest.cpp b/lldb/unittests/Host/JSONTransportTest.cpp
index 445674f402252..0228e13b61909 100644
--- a/lldb/unittests/Host/JSONTransportTest.cpp
+++ b/lldb/unittests/Host/JSONTransportTest.cpp
@@ -9,6 +9,7 @@
#include "lldb/Host/JSONTransport.h"
#include "TestingSupport/Host/JSONTransportTestUtilities.h"
#include "TestingSupport/Host/PipeTestUtilities.h"
+#include "TestingSupport/SubsystemRAII.h"
#include "lldb/Host/File.h"
#include "lldb/Host/MainLoop.h"
#include "lldb/Host/MainLoopBase.h"
@@ -25,6 +26,7 @@
#include <chrono>
#include <cstddef>
#include <memory>
+#include <optional>
#include <string>
using namespace llvm;
@@ -32,20 +34,35 @@ using namespace lldb_private;
using testing::_;
using testing::HasSubstr;
using testing::InSequence;
+using testing::Ref;
+
+namespace llvm::json {
+static bool fromJSON(const Value &V, Value &T, Path P) {
+ T = V;
+ return true;
+}
+} // namespace llvm::json
namespace {
namespace test_protocol {
struct Req {
+ int id = 0;
std::string name;
+ std::optional<json::Value> params;
};
-json::Value toJSON(const Req &T) { return json::Object{{"req", T.name}}; }
+json::Value toJSON(const Req &T) {
+ return json::Object{{"name", T.name}, {"id", T.id}, {"params", T.params}};
+}
bool fromJSON(const json::Value &V, Req &T, json::Path P) {
json::ObjectMapper O(V, P);
- return O && O.map("req", T.name);
+ return O && O.map("name", T.name) && O.map("id", T.id) &&
+ O.map("params", T.params);
+}
+bool operator==(const Req &a, const Req &b) {
+ return a.name == b.name && a.id == b.id && a.params == b.params;
}
-bool operator==(const Req &a, const Req &b) { return a.name == b.name; }
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Req &V) {
OS << toJSON(V);
return OS;
@@ -58,14 +75,19 @@ void PrintTo(const Req &message, std::ostream *os) {
}
struct Resp {
- std::string name;
+ int id = 0;
+ std::optional<json::Value> result;
};
-json::Value toJSON(const Resp &T) { return json::Object{{"resp", T.name}}; }
+json::Value toJSON(const Resp &T) {
+ return json::Object{{"id", T.id}, {"result", T.result}};
+}
bool fromJSON(const json::Value &V, Resp &T, json::Path P) {
json::ObjectMapper O(V, P);
- return O && O.map("resp", T.name);
+ return O && O.map("id", T.id) && O.map("result", T.result);
+}
+bool operator==(const Resp &a, const Resp &b) {
+ return a.id == b.id && a.result == b.result;
}
-bool operator==(const Resp &a, const Resp &b) { return a.name == b.name; }
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Resp &V) {
OS << toJSON(V);
return OS;
@@ -79,11 +101,14 @@ void PrintTo(const Resp &message, std::ostream *os) {
struct Evt {
std::string name;
+ std::optional<json::Value> params;
};
-json::Value toJSON(const Evt &T) { return json::Object{{"evt", T.name}}; }
+json::Value toJSON(const Evt &T) {
+ return json::Object{{"name", T.name}, {"params", T.params}};
+}
bool fromJSON(const json::Value &V, Evt &T, json::Path P) {
json::ObjectMapper O(V, P);
- return O && O.map("evt", T.name);
+ return O && O.map("name", T.name) && O.map("params", T.params);
}
bool operator==(const Evt &a, const Evt &b) { return a.name == b.name; }
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Evt &V) {
@@ -107,41 +132,64 @@ bool fromJSON(const json::Value &V, Message &msg, json::Path P) {
P.report("expected object");
return false;
}
- if (O->get("req")) {
- Req R;
- if (!fromJSON(V, R, P))
+
+ if (O->find("id") == O->end()) {
+ Evt E;
+ if (!fromJSON(V, E, P))
return false;
- msg = std::move(R);
+ msg = std::move(E);
return true;
}
- if (O->get("resp")) {
- Resp R;
+
+ if (O->get("name")) {
+ Req R;
if (!fromJSON(V, R, P))
return false;
msg = std::move(R);
return true;
}
- if (O->get("evt")) {
- Evt E;
- if (!fromJSON(V, E, P))
- return false;
- msg = std::move(E);
- return true;
- }
- P.report("unknown message type");
- return false;
+ Resp R;
+ if (!fromJSON(V, R, P))
+ return false;
+
+ msg = std::move(R);
+ return true;
}
-} // namespace test_protocol
+struct MyFnParams {
+ int a = 0;
+ int b = 0;
+};
+json::Value toJSON(const MyFnParams &T) {
+ return json::Object{{"a", T.a}, {"b", T.b}};
+}
+bool fromJSON(const json::Value &V, MyFnParams &T, json::Path P) {
+ json::ObjectMapper O(V, P);
+ return O && O.map("a", T.a) && O.map("b", T.b);
+}
-template <typename T, typename Req, typename Resp, typename Evt>
-class JSONTransportTest : public PipePairTest {
+struct MyFnResult {
+ int c = 0;
+};
+json::Value toJSON(const MyFnResult &T) { return json::Object{{"c", T.c}}; }
+bool fromJSON(const json::Value &V, MyFnResult &T, json::Path P) {
+ json::ObjectMapper O(V, P);
+ return O && O.map("c", T.c);
+}
+
+using Transport = TestTransport<int, Req, Resp, Evt>;
+using MessageHandler = MockMessageHandler<int, Req, Resp, Evt>;
+
+} // namespace test_protocol
+template <typename T> class JSONTransportTest : public PipePairTest {
protected:
- MockMessageHandler<Req, Resp, Evt> message_handler;
+ SubsystemRAII<FileSystem> subsystems;
+
+ test_protocol::MessageHandler message_handler;
std::unique_ptr<T> transport;
MainLoop loop;
@@ -191,8 +239,8 @@ class JSONTransportTest : public PipePairTest {
};
class TestHTTPDelimitedJSONTransport final
- : public HTTPDelimitedJSONTransport<test_protocol::Req, test_protocol::Resp,
- test_protocol::Evt> {
+ : public HTTPDelimitedJSONTransport<
+ int, test_protocol::Req, test_protocol::Resp, test_protocol::Evt> {
public:
using HTTPDelimitedJSONTransport::HTTPDelimitedJSONTransport;
@@ -204,9 +252,7 @@ class TestHTTPDelimitedJSONTransport final
};
class HTTPDelimitedJSONTransportTest
- : public JSONTransportTest<TestHTTPDelimitedJSONTransport,
- test_protocol::Req, test_protocol::Resp,
- test_protocol::Evt> {
+ : public JSONTransportTest<TestHTTPDelimitedJSONTransport> {
public:
using JSONTransportTest::JSONTransportTest;
@@ -222,7 +268,7 @@ class HTTPDelimitedJSONTransportTest
};
class TestJSONRPCTransport final
- : public JSONRPCTransport<test_protocol::Req, test_protocol::Resp,
+ : public JSONRPCTransport<int, test_protocol::Req, test_protocol::Resp,
test_protocol::Evt> {
public:
using JSONRPCTransport::JSONRPCTransport;
@@ -234,9 +280,7 @@ class TestJSONRPCTransport final
std::vector<std::string> log_messages;
};
-class JSONRPCTransportTest
- : public JSONTransportTest<TestJSONRPCTransport, test_protocol::Req,
- test_protocol::Resp, test_protocol::Evt> {
+class JSONRPCTransportTest : public JSONTransportTest<TestJSONRPCTransport> {
public:
using JSONTransportTest::JSONTransportTest;
@@ -248,8 +292,71 @@ class JSONRPCTransportTest
}
};
+class TestTransportBinder : public testing::Test {
+protected:
+ SubsystemRAII<FileSystem> subsystems;
+
+ std::unique_ptr<test_protocol::Transport> to_remote;
+ std::unique_ptr<test_protocol::Transport> from_remote;
+ std::unique_ptr<test_protocol::Transport::Binder> binder;
+ test_protocol::MessageHandler remote;
+ MainLoop loop;
+
+ void SetUp() override {
+ std::tie(to_remote, from_remote) = test_protocol::Transport::createPair();
+ binder = std::make_unique<test_protocol::Transport::Binder>(*to_remote);
+
+ auto binder_handle = to_remote->RegisterMessageHandler(loop, remote);
+ EXPECT_THAT_EXPECTED(binder_handle, Succeeded());
+
+ auto remote_handle = from_remote->RegisterMessageHandler(loop, *binder);
+ EXPECT_THAT_EXPECTED(remote_handle, Succeeded());
+ }
+
+ void Run() {
+ loop.AddPendingCallback([](auto &loop) { loop.RequestTermination(); });
+ EXPECT_THAT_ERROR(loop.Run().takeError(), Succeeded());
+ }
+};
+
} // namespace
+namespace lldb_private {
+using namespace test_protocol;
+template <>
+inline test_protocol::Req make_request(int id, llvm::StringRef method,
+ std::optional<json::Value> params) {
+ return test_protocol::Req{id, method.str(), params};
+}
+template <> inline Resp make_response(const Req &req, llvm::Error error) {
+ llvm::consumeError(std::move(error));
+ return Resp{req.id, std::nullopt};
+}
+template <> inline Resp make_response(const Req &req, json::Value result) {
+ return Resp{req.id, std::move(result)};
+}
+template <>
+inline Evt make_event(llvm::StringRef method,
+ std::optional<json::Value> params) {
+ return Evt{method.str(), params};
+}
+
+template <> inline llvm::Expected<json::Value> get_result(const Resp &resp) {
+ return resp.result;
+}
+
+template <> inline int get_id(const Resp &resp) { return resp.id; }
+template <> inline llvm::StringRef get_method(const Req &req) {
+ return req.name;
+}
+template <> inline llvm::StringRef get_method(const Evt &evt) {
+ return evt.name;
+}
+template <> inline json::Value get_params(const Req &req) { return req.params; }
+template <> inline json::Value get_params(const Evt &evt) { return evt.params; }
+
+} // namespace lldb_private
+
// Failing on Windows, see https://github.com/llvm/llvm-project/issues/153446.
#ifndef _WIN32
using namespace test_protocol;
@@ -269,35 +376,47 @@ TEST_F(HTTPDelimitedJSONTransportTest, MalformedRequests) {
}
+// A single request written to the input is delivered to the handler.
TEST_F(HTTPDelimitedJSONTransportTest, Read) {
- Write(Req{"foo"});
- EXPECT_CALL(message_handler, Received(Req{"foo"}));
+ Write(Req{6, "foo", std::nullopt});
+ EXPECT_CALL(message_handler, Received(Req{6, "foo", std::nullopt}));
ASSERT_THAT_ERROR(Run(), Succeeded());
}
+// A request, an event and a response written together are delivered in
+// order (enforced by InSequence).
TEST_F(HTTPDelimitedJSONTransportTest, ReadMultipleMessagesInSingleWrite) {
InSequence seq;
- Write(Message{Req{"one"}}, Message{Evt{"two"}}, Message{Resp{"three"}});
- EXPECT_CALL(message_handler, Received(Req{"one"}));
- EXPECT_CALL(message_handler, Received(Evt{"two"}));
- EXPECT_CALL(message_handler, Received(Resp{"three"}));
+ Write(
+ Message{
+ Req{6, "one", std::nullopt},
+ },
+ Message{
+ Evt{"two", std::nullopt},
+ },
+ Message{
+ Resp{2, std::nullopt},
+ });
+ EXPECT_CALL(message_handler, Received(Req{6, "one", std::nullopt}));
+ EXPECT_CALL(message_handler, Received(Evt{"two", std::nullopt}));
+ EXPECT_CALL(message_handler, Received(Resp{2, std::nullopt}));
ASSERT_THAT_ERROR(Run(), Succeeded());
}
+// A request whose name is twice kReadBufferSize long spans several read
+// chunks and must still arrive as one message.
TEST_F(HTTPDelimitedJSONTransportTest, ReadAcrossMultipleChunks) {
std::string long_str = std::string(
- HTTPDelimitedJSONTransport<Req, Resp, Evt>::kReadBufferSize * 2, 'x');
- Write(Req{long_str});
- EXPECT_CALL(message_handler, Received(Req{long_str}));
+ HTTPDelimitedJSONTransport<int, test_protocol::Req, test_protocol::Resp,
+ test_protocol::Evt>::kReadBufferSize *
+ 2,
+ 'x');
+ Write(Req{5, long_str, std::nullopt});
+ EXPECT_CALL(message_handler, Received(Req{5, long_str, std::nullopt}));
ASSERT_THAT_ERROR(Run(), Succeeded());
}
TEST_F(HTTPDelimitedJSONTransportTest, ReadPartialMessage) {
- std::string message = Encode(Req{"foo"});
+ std::string message = Encode(Req{5, "foo", std::nullopt});
auto split_at = message.size() / 2;
std::string part1 = message.substr(0, split_at);
std::string part2 = message.substr(split_at);
- EXPECT_CALL(message_handler, Received(Req{"foo"}));
+ EXPECT_CALL(message_handler, Received(Req{5, "foo", std::nullopt}));
ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded());
loop.AddPendingCallback(
@@ -309,12 +428,12 @@ TEST_F(HTTPDelimitedJSONTransportTest, ReadPartialMessage) {
}
TEST_F(HTTPDelimitedJSONTransportTest, ReadWithZeroByteWrites) {
- std::string message = Encode(Req{"foo"});
+ std::string message = Encode(Req{6, "foo", std::nullopt});
auto split_at = message.size() / 2;
std::string part1 = message.substr(0, split_at);
std::string part2 = message.substr(split_at);
- EXPECT_CALL(message_handler, Received(Req{"foo"}));
+ EXPECT_CALL(message_handler, Received(Req{6, "foo", std::nullopt}));
ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded());
@@ -366,20 +485,21 @@ TEST_F(HTTPDelimitedJSONTransportTest, InvalidTransport) {
}
+// Each sent message is framed by a Content-Length header giving the byte
+// length of its compact JSON body; absent params serialize as null.
TEST_F(HTTPDelimitedJSONTransportTest, Write) {
- ASSERT_THAT_ERROR(transport->Send(Req{"foo"}), Succeeded());
- ASSERT_THAT_ERROR(transport->Send(Resp{"bar"}), Succeeded());
- ASSERT_THAT_ERROR(transport->Send(Evt{"baz"}), Succeeded());
+ ASSERT_THAT_ERROR(transport->Send(Req{7, "foo", std::nullopt}), Succeeded());
+ ASSERT_THAT_ERROR(transport->Send(Resp{5, "bar"}), Succeeded());
+ ASSERT_THAT_ERROR(transport->Send(Evt{"baz", std::nullopt}), Succeeded());
output.CloseWriteFileDescriptor();
char buf[1024];
Expected<size_t> bytes_read =
output.Read(buf, sizeof(buf), std::chrono::milliseconds(1));
ASSERT_THAT_EXPECTED(bytes_read, Succeeded());
- ASSERT_EQ(StringRef(buf, *bytes_read), StringRef("Content-Length: 13\r\n\r\n"
- R"({"req":"foo"})"
- "Content-Length: 14\r\n\r\n"
- R"({"resp":"bar"})"
- "Content-Length: 13\r\n\r\n"
- R"({"evt":"baz"})"));
+ ASSERT_EQ(StringRef(buf, *bytes_read),
+ StringRef("Content-Length: 35\r\n\r\n"
+ R"({"id":7,"name":"foo","params":null})"
+ "Content-Length: 23\r\n\r\n"
+ R"({"id":5,"result":"bar"})"
+ "Content-Length: 28\r\n\r\n"
+ R"({"name":"baz","params":null})"));
}
TEST_F(JSONRPCTransportTest, MalformedRequests) {
@@ -395,17 +515,18 @@ TEST_F(JSONRPCTransportTest, MalformedRequests) {
}
+// A single request written to the input is delivered to the handler.
TEST_F(JSONRPCTransportTest, Read) {
- Write(Message{Req{"foo"}});
- EXPECT_CALL(message_handler, Received(Req{"foo"}));
+ Write(Message{Req{1, "foo", std::nullopt}});
+ EXPECT_CALL(message_handler, Received(Req{1, "foo", std::nullopt}));
ASSERT_THAT_ERROR(Run(), Succeeded());
}
+// A request, an event and a response written together are delivered in
+// order (enforced by InSequence).
TEST_F(JSONRPCTransportTest, ReadMultipleMessagesInSingleWrite) {
InSequence seq;
- Write(Message{Req{"one"}}, Message{Evt{"two"}}, Message{Resp{"three"}});
- EXPECT_CALL(message_handler, Received(Req{"one"}));
- EXPECT_CALL(message_handler, Received(Evt{"two"}));
- EXPECT_CALL(message_handler, Received(Resp{"three"}));
+ Write(Message{Req{1, "one", std::nullopt}}, Message{Evt{"two", std::nullopt}},
+ Message{Resp{3, "three"}});
+ EXPECT_CALL(message_handler, Received(Req{1, "one", std::nullopt}));
+ EXPECT_CALL(message_handler, Received(Evt{"two", std::nullopt}));
+ EXPECT_CALL(message_handler, Received(Resp{3, "three"}));
ASSERT_THAT_ERROR(Run(), Succeeded());
}
@@ -413,19 +534,22 @@ TEST_F(JSONRPCTransportTest, ReadAcrossMultipleChunks) {
// Use a string longer than the chunk size to ensure we split the message
// across the chunk boundary.
std::string long_str =
- std::string(JSONTransport<Req, Resp, Evt>::kReadBufferSize * 2, 'x');
- Write(Req{long_str});
- EXPECT_CALL(message_handler, Received(Req{long_str}));
+ std::string(IOTransport<int, test_protocol::Req, test_protocol::Resp,
+ test_protocol::Evt>::kReadBufferSize *
+ 2,
+ 'x');
+ Write(Req{42, long_str, std::nullopt});
+ EXPECT_CALL(message_handler, Received(Req{42, long_str, std::nullopt}));
ASSERT_THAT_ERROR(Run(), Succeeded());
}
TEST_F(JSONRPCTransportTest, ReadPartialMessage) {
- std::string message = R"({"req": "foo"})"
+ std::string message = R"({"id":42,"name":"foo","params":null})"
"\n";
std::string part1 = message.substr(0, 7);
std::string part2 = message.substr(7);
- EXPECT_CALL(message_handler, Received(Req{"foo"}));
+ EXPECT_CALL(message_handler, Received(Req{42, "foo", std::nullopt}));
ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded());
loop.AddPendingCallback(
@@ -455,20 +579,21 @@ TEST_F(JSONRPCTransportTest, ReaderWithUnhandledData) {
}
+// Each sent message is written as compact JSON followed by a newline; absent
+// params serialize as null.
TEST_F(JSONRPCTransportTest, Write) {
- ASSERT_THAT_ERROR(transport->Send(Req{"foo"}), Succeeded());
- ASSERT_THAT_ERROR(transport->Send(Resp{"bar"}), Succeeded());
- ASSERT_THAT_ERROR(transport->Send(Evt{"baz"}), Succeeded());
+ ASSERT_THAT_ERROR(transport->Send(Req{11, "foo", std::nullopt}), Succeeded());
+ ASSERT_THAT_ERROR(transport->Send(Resp{14, "bar"}), Succeeded());
+ ASSERT_THAT_ERROR(transport->Send(Evt{"baz", std::nullopt}), Succeeded());
output.CloseWriteFileDescriptor();
char buf[1024];
Expected<size_t> bytes_read =
output.Read(buf, sizeof(buf), std::chrono::milliseconds(1));
ASSERT_THAT_EXPECTED(bytes_read, Succeeded());
- ASSERT_EQ(StringRef(buf, *bytes_read), StringRef(R"({"req":"foo"})"
- "\n"
- R"({"resp":"bar"})"
- "\n"
- R"({"evt":"baz"})"
- "\n"));
+ ASSERT_EQ(StringRef(buf, *bytes_read),
+ StringRef(R"({"id":11,"name":"foo","params":null})"
+ "\n"
+ R"({"id":14,"result":"bar"})"
+ "\n"
+ R"({"name":"baz","params":null})"
+ "\n"));
}
TEST_F(JSONRPCTransportTest, InvalidTransport) {
@@ -477,4 +602,59 @@ TEST_F(JSONRPCTransportTest, InvalidTransport) {
FailedWithMessage("IO object is not valid."));
}
+// Out-bound binding request handler.
+TEST_F(TestTransportBinder, OutBoundRequests) {
+ // Held by shared_ptr so the reply callback never refers to a dead stack
+ // slot, in case the binder retains it past the end of the test body.
+ auto replied = std::make_shared<bool>(false);
+ auto addFn = binder->bind<MyFnResult, MyFnParams>("add");
+ addFn(MyFnParams{1, 2}, [replied](Expected<MyFnResult> result) {
+ *replied = true;
+ // ASSERT (not EXPECT) so a failed Expected is never dereferenced below.
+ ASSERT_THAT_EXPECTED(result, Succeeded());
+ EXPECT_EQ(result->c, 3);
+ });
+ EXPECT_CALL(remote, Received(Req{1, "add", MyFnParams{1, 2}}));
+ // Queue a reply that will be sent during 'Run'.
+ EXPECT_THAT_ERROR(from_remote->Send(Resp{1, toJSON(MyFnResult{3})}),
+ Succeeded());
+ Run();
+ // Without this the test would also pass if the reply were never delivered.
+ EXPECT_TRUE(*replied);
+}
+
+// In-bound binding request handler.
+TEST_F(TestTransportBinder, InBoundRequests) {
+ // The trailing argument (2) is forwarded to the handler as captured_param,
+ // so the expected result is 3 + 4 + 2 = 9.
+ binder->bind<MyFnResult, MyFnParams>(
+ "add",
+ [](const int captured_param,
+ const MyFnParams &params) -> Expected<MyFnResult> {
+ return MyFnResult{params.a + params.b + captured_param};
+ },
+ 2);
+ EXPECT_THAT_ERROR(from_remote->Send(Req{2, "add", MyFnParams{3, 4}}),
+ Succeeded());
+ EXPECT_CALL(remote, Received(Resp{2, MyFnResult{9}}));
+ Run();
+}
+
+// Out-bound binding event handler: invoking the bound function emits an Evt
+// named "evt" carrying the given params.
+TEST_F(TestTransportBinder, OutBoundEvents) {
+ EXPECT_CALL(remote, Received(Evt{"evt", MyFnParams{1, 2}}));
+ auto notify = binder->bind<MyFnParams>("evt");
+ notify(MyFnParams{1, 2});
+ Run();
+}
+
+// In-bound binding event handler.
+TEST_F(TestTransportBinder, InBoundEvents) {
+ bool called = false;
+ // The trailing argument (42) is forwarded to the handler as captured_arg.
+ binder->bind<MyFnParams>(
+ "evt",
+ [&](const int captured_arg, const MyFnParams &params) {
+ EXPECT_EQ(captured_arg, 42);
+ EXPECT_EQ(params.a, 3);
+ EXPECT_EQ(params.b, 4);
+ called = true;
+ },
+ 42);
+ EXPECT_THAT_ERROR(from_remote->Send(Evt{"evt", MyFnParams{3, 4}}),
+ Succeeded());
+ Run();
+ EXPECT_TRUE(called);
+}
+
#endif
diff --git a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp
index f686255c6d41d..0958af87a9402 100644
--- a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp
+++ b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp
@@ -6,9 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#include "ProtocolMCPTestUtilities.h"
+#include "ProtocolMCPTestUtilities.h" // IWYU pragma: keep
#include "TestingSupport/Host/JSONTransportTestUtilities.h"
-#include "TestingSupport/Host/PipeTestUtilities.h"
#include "TestingSupport/SubsystemRAII.h"
#include "lldb/Host/FileSystem.h"
#include "lldb/Host/HostInfo.h"
@@ -28,20 +27,21 @@
#include "llvm/Testing/Support/Error.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
-#include <chrono>
-#include <condition_variable>
+#include <future>
+#include <memory>
+#include <optional>
using namespace llvm;
using namespace lldb;
using namespace lldb_private;
using namespace lldb_protocol::mcp;
+using testing::_;
namespace {
-class TestServer : public Server {
-public:
- using Server::Server;
-};
+/// Wraps \p result in a Response with the given \p id (1 by default).
+template <typename T> Response make_response(T &&result, Id id = 1) {
+ return Response{id, std::forward<T>(result)};
+}
/// Test tool that returns it argument as text.
class TestTool : public Tool {
@@ -118,175 +118,209 @@ class FailTool : public Tool {
}
};
-class ProtocolServerMCPTest : public PipePairTest {
+class TestServer : public Server {
+public:
+ using Server::Bind;
+ using Server::Server;
+};
+
+using Transport = TestTransport<int64_t, lldb_protocol::mcp::Request,
+ lldb_protocol::mcp::Response,
+ lldb_protocol::mcp::Notification>;
+
+class ProtocolServerMCPTest : public testing::Test {
public:
SubsystemRAII<FileSystem, HostInfo, Socket> subsystems;
- std::unique_ptr<lldb_protocol::mcp::Transport> transport_up;
- std::unique_ptr<TestServer> server_up;
MainLoop loop;
- MockMessageHandler<Request, Response, Notification> message_handler;
+ lldb_private::MainLoop::ReadHandleUP handles[2];
- llvm::Error Write(llvm::StringRef message) {
- llvm::Expected<json::Value> value = json::parse(message);
- if (!value)
- return value.takeError();
- return transport_up->Write(*value);
- }
+ std::unique_ptr<Transport> to_server;
+ Transport::BinderUP binder;
+ std::unique_ptr<TestServer> server_up;
- llvm::Error Write(json::Value value) { return transport_up->Write(value); }
+ std::unique_ptr<Transport> to_client;
+ MockMessageHandler<int64_t, Request, Response, Notification> client;
- /// Run the transport MainLoop and return any messages received.
- llvm::Error
- Run(std::chrono::milliseconds timeout = std::chrono::milliseconds(200)) {
- loop.AddCallback([](MainLoopBase &loop) { loop.RequestTermination(); },
- timeout);
- auto handle = transport_up->RegisterMessageHandler(loop, message_handler);
- if (!handle)
- return handle.takeError();
+ std::vector<std::string> logged_messages;
- return server_up->Run();
+ /// Runs the MainLoop a single time, executing any pending callbacks.
+ void Run() {
+ loop.AddPendingCallback(
+ [](MainLoopBase &loop) { loop.RequestTermination(); });
+ EXPECT_THAT_ERROR(loop.Run().takeError(), Succeeded());
}
void SetUp() override {
- PipePairTest::SetUp();
-
- transport_up = std::make_unique<lldb_protocol::mcp::Transport>(
- std::make_shared<NativeFile>(input.GetReadFileDescriptor(),
- File::eOpenOptionReadOnly,
- NativeFile::Unowned),
- std::make_shared<NativeFile>(output.GetWriteFileDescriptor(),
- File::eOpenOptionWriteOnly,
- NativeFile::Unowned));
+ std::tie(to_client, to_server) = Transport::createPair();
server_up = std::make_unique<TestServer>(
"lldb-mcp", "0.1.0",
- std::make_unique<lldb_protocol::mcp::Transport>(
- std::make_shared<NativeFile>(output.GetReadFileDescriptor(),
- File::eOpenOptionReadOnly,
- NativeFile::Unowned),
- std::make_shared<NativeFile>(input.GetWriteFileDescriptor(),
- File::eOpenOptionWriteOnly,
- NativeFile::Unowned)),
- loop);
+ [this](StringRef msg) { logged_messages.push_back(msg.str()); });
+ binder = server_up->Bind(*to_client);
+ auto server_handle = to_server->RegisterMessageHandler(loop, *binder);
+ // ASSERT so a failed registration is never dereferenced below.
+ ASSERT_THAT_EXPECTED(server_handle, Succeeded());
+ binder->error([](llvm::Error error) {
+ // toString consumes the error; merely printing it would leave it
+ // unchecked and abort in builds with LLVM_ENABLE_ABI_BREAKING_CHECKS.
+ llvm::errs() << "Server transport error: "
+ << llvm::toString(std::move(error)) << "\n";
+ });
+ handles[0] = std::move(*server_handle);
+
+ auto client_handle = to_client->RegisterMessageHandler(loop, client);
+ ASSERT_THAT_EXPECTED(client_handle, Succeeded());
+ handles[1] = std::move(*client_handle);
+ }
+
+ /// Sends \p method with \p params to the server, runs the loop once and
+ /// returns the JSON of the Response the client received. Returns an error
+ /// (rather than blocking forever) if no Response arrived during that run.
+ template <typename Result, typename Params>
+ Expected<json::Value> Call(StringRef method, const Params &params) {
+ std::optional<Response> received;
+ Request req = make_request<int64_t, lldb_protocol::mcp::Request>(
+ /*id=*/1, method, toJSON(params));
+ EXPECT_THAT_ERROR(to_server->Send(req), Succeeded());
+ EXPECT_CALL(client, Received(testing::An<const Response &>()))
+ .WillOnce([&](const Response &resp) { received = resp; });
+ Run();
+ if (!received)
+ return llvm::createStringError("no response received for request '%s'",
+ method.str().c_str());
+ return toJSON(*received);
+ }
+
+ /// Invokes \p fn, runs the loop once and returns the JSON of the reply it
+ /// delivered (or its error). Fails if the reply was never delivered.
+ template <typename Result>
+ Expected<json::Value>
+ Capture(llvm::unique_function<void(Reply<Result>)> &fn) {
+ std::optional<llvm::Expected<Result>> captured;
+ fn([&captured](llvm::Expected<Result> result) {
+ captured.emplace(std::move(result));
+ });
+ Run();
+ if (!captured)
+ return llvm::createStringError("reply callback was not invoked");
+ llvm::Expected<Result> result = std::move(*captured);
+ if (!result)
+ return result.takeError();
+ return toJSON(*result);
+ }
+
+ /// Like Capture above, for bound functions that take \p params.
+ template <typename Result, typename Params>
+ Expected<json::Value>
+ Capture(llvm::unique_function<void(const Params &, Reply<Result>)> &fn,
+ const Params &params) {
+ std::optional<llvm::Expected<Result>> captured;
+ fn(params, [&captured](llvm::Expected<Result> result) {
+ captured.emplace(std::move(result));
+ });
+ Run();
+ if (!captured)
+ return llvm::createStringError("reply callback was not invoked");
+ llvm::Expected<Result> result = std::move(*captured);
+ if (!result)
+ return result.takeError();
+ return toJSON(*result);
}
};
template <typename T>
-Request make_request(StringLiteral method, T &&params, Id id = 1) {
- return Request{id, method.str(), toJSON(std::forward<T>(params))};
-}
-
-template <typename T> Response make_response(T &&result, Id id = 1) {
- return Response{id, std::forward<T>(result)};
+/// gmock matcher that checks the argument equals toJSON(x).
+inline testing::internal::EqMatcher<llvm::json::Value> HasJSON(T x) {
+ // Use the public testing::Eq factory instead of constructing gmock's
+ // internal matcher type directly.
+ return testing::Eq(toJSON(x));
}
} // namespace
+// The initialize request is answered with the server's name/version and its
+// tools/resources list capabilities.
TEST_F(ProtocolServerMCPTest, Initialization) {
- Request request = make_request(
- "initialize", InitializeParams{/*protocolVersion=*/"2024-11-05",
- /*capabilities=*/{},
- /*clientInfo=*/{"lldb-unit", "0.1.0"}});
- Response response = make_response(
- InitializeResult{/*protocolVersion=*/"2024-11-05",
- /*capabilities=*/{/*supportsToolsList=*/true},
- /*serverInfo=*/{"lldb-mcp", "0.1.0"}});
-
- ASSERT_THAT_ERROR(Write(request), Succeeded());
- EXPECT_CALL(message_handler, Received(response));
- EXPECT_THAT_ERROR(Run(), Succeeded());
+ EXPECT_THAT_EXPECTED(
+ (Call<InitializeResult, InitializeParams>(
+ "initialize",
+ InitializeParams{/*protocolVersion=*/"2024-11-05",
+ /*capabilities=*/{},
+ /*clientInfo=*/{"lldb-unit", "0.1.0"}})),
+ HasValue(make_response(
+ InitializeResult{/*protocolVersion=*/"2024-11-05",
+ /*capabilities=*/
+ {
+ /*supportsToolsList=*/true,
+ /*supportsResourcesList=*/true,
+ },
+ /*serverInfo=*/{"lldb-mcp", "0.1.0"}})));
}
+// tools/list reports each registered tool with its name, description and an
+// object input schema.
TEST_F(ProtocolServerMCPTest, ToolsList) {
server_up->AddTool(std::make_unique<TestTool>("test", "test tool"));
- Request request = make_request("tools/list", Void{}, /*id=*/"one");
-
ToolDefinition test_tool;
test_tool.name = "test";
test_tool.description = "test tool";
test_tool.inputSchema = json::Object{{"type", "object"}};
- Response response = make_response(ListToolsResult{{test_tool}}, /*id=*/"one");
-
- ASSERT_THAT_ERROR(Write(request), llvm::Succeeded());
- EXPECT_CALL(message_handler, Received(response));
- EXPECT_THAT_ERROR(Run(), Succeeded());
+ EXPECT_THAT_EXPECTED(Call<ListToolsResult>("tools/list", Void{}),
+ HasValue(make_response(ListToolsResult{{test_tool}})));
}
+// resources/list returns the resource exposed by TestResourceProvider.
TEST_F(ProtocolServerMCPTest, ResourcesList) {
server_up->AddResourceProvider(std::make_unique<TestResourceProvider>());
- Request request = make_request("resources/list", Void{});
- Response response = make_response(ListResourcesResult{
- {{/*uri=*/"lldb://foo/bar", /*name=*/"name",
- /*description=*/"description", /*mimeType=*/"application/json"}}});
-
- ASSERT_THAT_ERROR(Write(request), llvm::Succeeded());
- EXPECT_CALL(message_handler, Received(response));
- EXPECT_THAT_ERROR(Run(), Succeeded());
+ EXPECT_THAT_EXPECTED(Call<ListResourcesResult>("resources/list", Void{}),
+ HasValue(make_response(ListResourcesResult{{
+ {
+ /*uri=*/"lldb://foo/bar",
+ /*name=*/"name",
+ /*description=*/"description",
+ /*mimeType=*/"application/json",
+ },
+ }})));
}
+// tools/call on TestTool echoes its "arguments" value back as text content.
TEST_F(ProtocolServerMCPTest, ToolsCall) {
server_up->AddTool(std::make_unique<TestTool>("test", "test tool"));
- Request request = make_request(
- "tools/call", CallToolParams{/*name=*/"test", /*arguments=*/json::Object{
- {"arguments", "foo"},
- {"debugger_id", 0},
- }});
- Response response = make_response(CallToolResult{{{/*text=*/"foo"}}});
-
- ASSERT_THAT_ERROR(Write(request), llvm::Succeeded());
- EXPECT_CALL(message_handler, Received(response));
- EXPECT_THAT_ERROR(Run(), Succeeded());
+ EXPECT_THAT_EXPECTED(
+ (Call<CallToolResult, CallToolParams>("tools/call",
+ CallToolParams{
+ /*name=*/"test",
+ /*arguments=*/
+ json::Object{
+ {"arguments", "foo"},
+ {"debugger_id", 0},
+ },
+ })),
+ HasValue(make_response(CallToolResult{{{/*text=*/"foo"}}})));
}
+// When a tool returns an error, the reply is a protocol Error with
+// eErrorCodeInternalError and the tool's message.
TEST_F(ProtocolServerMCPTest, ToolsCallError) {
server_up->AddTool(std::make_unique<ErrorTool>("error", "error tool"));
- Request request = make_request(
- "tools/call", CallToolParams{/*name=*/"error", /*arguments=*/json::Object{
- {"arguments", "foo"},
- {"debugger_id", 0},
- }});
- Response response =
- make_response(lldb_protocol::mcp::Error{eErrorCodeInternalError,
- /*message=*/"error"});
-
- ASSERT_THAT_ERROR(Write(request), llvm::Succeeded());
- EXPECT_CALL(message_handler, Received(response));
- EXPECT_THAT_ERROR(Run(), Succeeded());
+ EXPECT_THAT_EXPECTED((Call<CallToolResult, CallToolParams>(
+ "tools/call", CallToolParams{
+ /*name=*/"error",
+ /*arguments=*/
+ json::Object{
+ {"arguments", "foo"},
+ {"debugger_id", 0},
+ },
+ })),
+ HasValue(make_response(lldb_protocol::mcp::Error{
+ eErrorCodeInternalError, "error"})));
}
+// A tool-level failure is a normal result with isError=true, not a protocol
+// error.
TEST_F(ProtocolServerMCPTest, ToolsCallFail) {
server_up->AddTool(std::make_unique<FailTool>("fail", "fail tool"));
- Request request = make_request(
- "tools/call", CallToolParams{/*name=*/"fail", /*arguments=*/json::Object{
- {"arguments", "foo"},
- {"debugger_id", 0},
- }});
- Response response =
- make_response(CallToolResult{{{/*text=*/"failed"}}, /*isError=*/true});
-
- ASSERT_THAT_ERROR(Write(request), llvm::Succeeded());
- EXPECT_CALL(message_handler, Received(response));
- EXPECT_THAT_ERROR(Run(), Succeeded());
+ EXPECT_THAT_EXPECTED((Call<CallToolResult, CallToolParams>(
+ "tools/call", CallToolParams{
+ /*name=*/"fail",
+ /*arguments=*/
+ json::Object{
+ {"arguments", "foo"},
+ {"debugger_id", 0},
+ },
+ })),
+ HasValue(make_response(CallToolResult{
+ {{/*text=*/"failed"}},
+ /*isError=*/true,
+ })));
}
+// The notifications/initialized notification makes the server log
+// "MCP initialization complete".
TEST_F(ProtocolServerMCPTest, NotificationInitialized) {
- bool handler_called = false;
- std::condition_variable cv;
-
- server_up->AddNotificationHandler(
- "notifications/initialized",
- [&](const Notification &notification) { handler_called = true; });
- llvm::StringLiteral request =
- R"json({"method":"notifications/initialized","jsonrpc":"2.0"})json";
-
- ASSERT_THAT_ERROR(Write(request), llvm::Succeeded());
- EXPECT_THAT_ERROR(Run(), Succeeded());
- EXPECT_TRUE(handler_called);
+ EXPECT_THAT_ERROR(to_server->Send(lldb_protocol::mcp::Notification{
+ "notifications/initialized",
+ std::nullopt,
+ }),
+ Succeeded());
+ Run();
+ EXPECT_THAT(logged_messages,
+ testing::Contains("MCP initialization complete"));
}
diff --git a/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h b/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h
index 5a9eb8e59f2b6..4dbcd614e400b 100644
--- a/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h
+++ b/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h
@@ -6,19 +6,105 @@
//
//===----------------------------------------------------------------------===//
-#ifndef LLDB_UNITTESTS_TESTINGSUPPORT_HOST_NATIVEPROCESSTESTUTILS_H
-#define LLDB_UNITTESTS_TESTINGSUPPORT_HOST_NATIVEPROCESSTESTUTILS_H
+#ifndef LLDB_UNITTESTS_TESTINGSUPPORT_HOST_JSONTRANSPORTTESTUTILITIES_H
+#define LLDB_UNITTESTS_TESTINGSUPPORT_HOST_JSONTRANSPORTTESTUTILITIES_H
+#include "lldb/Host/FileSystem.h"
#include "lldb/Host/JSONTransport.h"
+#include "lldb/Host/MainLoop.h"
+#include "lldb/Utility/FileSpec.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Testing/Support/Error.h"
#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include <cstddef>
+#include <memory>
+#include <utility>
-template <typename Req, typename Resp, typename Evt>
+/// In-memory JSONTransport for unit tests.
+///
+/// Nothing is serialized: Send() queues a pending callback on the MainLoop
+/// passed to RegisterMessageHandler(), and that callback hands the message to
+/// the MessageHandler registered on this same transport. A test therefore
+/// sends on the transport whose handler should receive the message.
+template <typename Id, typename Req, typename Resp, typename Evt>
+class TestTransport final
+ : public lldb_private::JSONTransport<Id, Req, Resp, Evt> {
+public:
+ using MessageHandler =
+ typename lldb_private::JSONTransport<Id, Req, Resp, Evt>::MessageHandler;
+
+ /// Convenience factory returning two independent transports.
+ static std::pair<std::unique_ptr<TestTransport<Id, Req, Resp, Evt>>,
+ std::unique_ptr<TestTransport<Id, Req, Resp, Evt>>>
+ createPair() {
+ std::unique_ptr<TestTransport<Id, Req, Resp, Evt>> transports[2] = {
+ std::make_unique<TestTransport<Id, Req, Resp, Evt>>(),
+ std::make_unique<TestTransport<Id, Req, Resp, Evt>>()};
+ return std::make_pair(std::move(transports[0]), std::move(transports[1]));
+ }
+
+ explicit TestTransport() {
+ llvm::Expected<lldb::FileUP> dummy_file =
+ lldb_private::FileSystem::Instance().Open(
+ lldb_private::FileSpec(lldb_private::FileSystem::DEV_NULL),
+ lldb_private::File::eOpenOptionReadWrite);
+ EXPECT_THAT_EXPECTED(dummy_file, llvm::Succeeded());
+ // Dereferencing a failed Expected is a programming error; on failure leave
+ // m_dummy_file null and rely on the failure reported above.
+ if (dummy_file)
+ m_dummy_file = std::move(*dummy_file);
+ }
+
+ llvm::Error Send(const Evt &evt) override { return Enqueue(evt); }
+
+ llvm::Error Send(const Req &req) override { return Enqueue(req); }
+
+ llvm::Error Send(const Resp &resp) override { return Enqueue(resp); }
+
+ /// Records \p loop and \p handler (only the first registration is kept) and
+ /// registers a dummy read object so the caller receives a ReadHandle.
+ llvm::Expected<lldb_private::MainLoop::ReadHandleUP>
+ RegisterMessageHandler(lldb_private::MainLoop &loop,
+ MessageHandler &handler) override {
+ if (!m_loop)
+ m_loop = &loop;
+ if (!m_handler)
+ m_handler = &handler;
+ lldb_private::Status status;
+ auto handle = loop.RegisterReadObject(
+ m_dummy_file, [](lldb_private::MainLoopBase &) {}, status);
+ if (status.Fail())
+ return status.takeError();
+ return handle;
+ }
+
+protected:
+ void Log(llvm::StringRef) override {}
+
+private:
+ /// Queues delivery of \p message to the registered handler as a pending
+ /// MainLoop callback. Returns an error instead of dereferencing null when
+ /// RegisterMessageHandler() has not been called yet.
+ template <typename Message> llvm::Error Enqueue(const Message &message) {
+ EXPECT_TRUE(m_loop && m_handler)
+ << "Send called before RegisterMessageHandler";
+ if (!m_loop || !m_handler)
+ return llvm::createStringError(
+ "Send called before RegisterMessageHandler");
+ m_loop->AddPendingCallback(
+ [this, message](lldb_private::MainLoopBase &) {
+ m_handler->Received(message);
+ });
+ return llvm::Error::success();
+ }
+
+ lldb_private::MainLoop *m_loop = nullptr;
+ MessageHandler *m_handler = nullptr;
+ // Dummy file for registering with the MainLoop.
+ lldb::FileSP m_dummy_file = nullptr;
+};
+
+/// gmock MessageHandler: every Received overload plus OnError and OnClosed
+/// is a mock method, so tests can set expectations on incoming messages.
+template <typename Id, typename Req, typename Resp, typename Evt>
class MockMessageHandler final
- : public lldb_private::Transport<Req, Resp, Evt>::MessageHandler {
+ : public lldb_private::JSONTransport<Id, Req, Resp, Evt>::MessageHandler {
public:
- MOCK_METHOD(void, Received, (const Evt &), (override));
MOCK_METHOD(void, Received, (const Req &), (override));
MOCK_METHOD(void, Received, (const Resp &), (override));
+ MOCK_METHOD(void, Received, (const Evt &), (override));
MOCK_METHOD(void, OnError, (llvm::Error), (override));
MOCK_METHOD(void, OnClosed, (), (override));
};
More information about the lldb-commits
mailing list