[Lldb-commits] [lldb] [lldb] Creating a new Binder helper for JSONRPC transport. (PR #155315)
John Harrison via lldb-commits
lldb-commits at lists.llvm.org
Mon Aug 25 14:50:20 PDT 2025
https://github.com/ashgti created https://github.com/llvm/llvm-project/pull/155315
The `lldb_protocol::mcp::Binder` class is used to bind requests and notifications to specific handlers.
It supports both incoming and outgoing calls: incoming handlers are registered with a MessageHandler, and encoding/decoding helpers are generated for each call.
For an example, see the `lldb_protocol::mcp::Server` class, which has been greatly simplified.
>From 81643e70e88aa9cb91932071336ae817b1b2926d Mon Sep 17 00:00:00 2001
From: John Harrison <harjohn at google.com>
Date: Mon, 25 Aug 2025 14:40:08 -0700
Subject: [PATCH] [lldb] Creating a new Binder helper for JSONRPC transport.
The `lldb_protocol::mcp::Binder` class is used to bind requests and notifications to specific handlers.
It supports both incoming and outgoing calls: incoming handlers are registered with a MessageHandler, and encoding/decoding helpers are generated for each call.
For an example, see the `lldb_protocol::mcp::Server` class, which has been greatly simplified.
---
lldb/include/lldb/Protocol/MCP/Binder.h | 351 ++++++++++++++++++
lldb/include/lldb/Protocol/MCP/Protocol.h | 173 ++++++++-
lldb/include/lldb/Protocol/MCP/Resource.h | 2 +-
lldb/include/lldb/Protocol/MCP/Server.h | 74 ++--
lldb/include/lldb/Protocol/MCP/Tool.h | 9 +-
lldb/include/lldb/Protocol/MCP/Transport.h | 50 +++
.../Protocol/MCP/ProtocolServerMCP.cpp | 20 +-
.../Plugins/Protocol/MCP/ProtocolServerMCP.h | 4 +-
lldb/source/Plugins/Protocol/MCP/Resource.cpp | 10 +-
lldb/source/Plugins/Protocol/MCP/Resource.h | 6 +-
lldb/source/Plugins/Protocol/MCP/Tool.cpp | 26 +-
lldb/source/Plugins/Protocol/MCP/Tool.h | 7 +-
lldb/source/Protocol/MCP/Binder.cpp | 139 +++++++
lldb/source/Protocol/MCP/CMakeLists.txt | 3 +
lldb/source/Protocol/MCP/Protocol.cpp | 159 +++++++-
lldb/source/Protocol/MCP/Server.cpp | 255 +++----------
lldb/source/Protocol/MCP/Transport.cpp | 113 ++++++
lldb/unittests/Protocol/ProtocolMCPTest.cpp | 10 +-
.../ProtocolServer/ProtocolMCPServerTest.cpp | 78 ++--
19 files changed, 1160 insertions(+), 329 deletions(-)
create mode 100644 lldb/include/lldb/Protocol/MCP/Binder.h
create mode 100644 lldb/include/lldb/Protocol/MCP/Transport.h
create mode 100644 lldb/source/Protocol/MCP/Binder.cpp
create mode 100644 lldb/source/Protocol/MCP/Transport.cpp
diff --git a/lldb/include/lldb/Protocol/MCP/Binder.h b/lldb/include/lldb/Protocol/MCP/Binder.h
new file mode 100644
index 0000000000000..f9cebd940bfcb
--- /dev/null
+++ b/lldb/include/lldb/Protocol/MCP/Binder.h
@@ -0,0 +1,351 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLDB_PROTOCOL_MCP_BINDER_H
+#define LLDB_PROTOCOL_MCP_BINDER_H
+
+#include "lldb/Protocol/MCP/MCPError.h"
+#include "lldb/Protocol/MCP/Protocol.h"
+#include "lldb/Protocol/MCP/Transport.h"
+#include "lldb/Utility/Status.h"
+#include "llvm/ADT/FunctionExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/JSON.h"
+#include <functional>
+#include <future>
+#include <memory>
+#include <mutex>
+#include <optional>
+
+namespace lldb_protocol::mcp {
+
+/// A move-only callable with signature `T`.
+template <typename T> using Callback = llvm::unique_function<T>;
+
+/// Receives the result of a request, or the error that prevented one.
+template <typename T>
+using Reply = llvm::unique_function<void(llvm::Expected<T>)>;
+/// Sends a request with the given params; the outcome is passed to the Reply.
+template <typename Params, typename Result>
+using OutgoingRequest =
+    llvm::unique_function<void(const Params &, Reply<Result>)>;
+/// Sends a notification with the given params; notifications have no reply.
+template <typename Params>
+using OutgoingNotification = llvm::unique_function<void(const Params &)>;
+
+/// Invokes the outgoing request `fn` with `params` and waits for its reply.
+///
+/// `loop` is run on a helper thread until the reply arrives, and this call
+/// blocks until that thread has finished. Returns the reply's result, the
+/// loop's error if it fails before a reply arrives, or an error if the loop
+/// exits without a reply having been delivered.
+template <typename Params, typename Result>
+llvm::Expected<Result> AsyncInvoke(lldb_private::MainLoop &loop,
+                                   OutgoingRequest<Params, Result> &fn,
+                                   const Params &params) {
+  std::promise<llvm::Expected<Result>> result_promise;
+  std::future<llvm::Expected<Result>> result_future =
+      result_promise.get_future();
+  std::thread thr([&loop, &fn, params,
+                   result_promise = std::move(result_promise)]() mutable {
+    // A promise may only be satisfied once, so remember whether the reply
+    // already did. NOTE: assumes the reply is delivered on this thread, i.e.
+    // from within `loop` or synchronously from `fn`.
+    bool replied = false;
+    fn(params, [&loop, &result_promise,
+                &replied](llvm::Expected<Result> result) mutable {
+      replied = true;
+      result_promise.set_value(std::move(result));
+      loop.AddPendingCallback(
+          [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); });
+    });
+    llvm::Error error = loop.Run().takeError();
+    if (replied) {
+      // The reply already provided the result; nothing can carry this error.
+      llvm::consumeError(std::move(error));
+    } else if (error) {
+      result_promise.set_value(std::move(error));
+    } else {
+      // Satisfy the promise explicitly; destroying it unsatisfied would make
+      // `result_future.get()` report a broken promise.
+      result_promise.set_value(llvm::createStringError(
+          llvm::inconvertibleErrorCode(), "main loop exited without a reply"));
+    }
+  });
+  thr.join();
+  return result_future.get();
+}
+
+/// Binder collects a table of functions that handle calls.
+///
+/// The wrapper takes care of parsing request/notification params and
+/// serializing responses, so handlers work with typed values. Each method name
+/// may only be bound once (asserted in debug builds).
+class Binder {
+public:
+  explicit Binder(MCPTransport *transport) : m_handlers(transport) {}
+
+  Binder(const Binder &) = delete;
+  Binder &operator=(const Binder &) = delete;
+
+  /// Bind a handler on transport disconnect.
+  /// e.g. Bind.disconnected(&ThisModule::OnDisconnect, this);
+  /// Any `extra_args` are passed to the handler after the transport.
+  template <typename ThisT, typename... ExtraArgs>
+  void disconnected(void (ThisT::*handler)(MCPTransport *, ExtraArgs...),
+                    ThisT *_this, ExtraArgs... extra_args) {
+    m_handlers.m_disconnect_handler =
+        std::bind(handler, _this, std::placeholders::_1,
+                  std::forward<ExtraArgs>(extra_args)...);
+  }
+
+  /// Bind a handler on error when communicating with the transport.
+  /// e.g. Bind.error(&ThisModule::OnError, this);
+  /// Any `extra_args` are passed to the handler after the error.
+  template <typename ThisT, typename... ExtraArgs>
+  void error(void (ThisT::*handler)(MCPTransport *, llvm::Error, ExtraArgs...),
+             ThisT *_this, ExtraArgs... extra_args) {
+    m_handlers.m_error_handler =
+        std::bind(handler, _this, std::placeholders::_1, std::placeholders::_2,
+                  std::forward<ExtraArgs>(extra_args)...);
+  }
+
+  /// Bind a handler for a request.
+  /// e.g. Bind.request("peek", &ThisModule::peek, this);
+  /// Handler should be e.g. Expected<PeekResult> peek(const PeekParams&);
+  /// PeekParams must be JSON parsable and PeekResult must be serializable.
+  template <typename Result, typename Params, typename ThisT,
+            typename... ExtraArgs>
+  void request(llvm::StringLiteral method,
+               llvm::Expected<Result> (ThisT::*fn)(const Params &,
+                                                   ExtraArgs...),
+               ThisT *_this, ExtraArgs... extra_args) {
+    assert(m_handlers.m_request_handlers.find(method) ==
+               m_handlers.m_request_handlers.end() &&
+           "request already bound");
+    std::function<llvm::Expected<Result>(const Params &)> handler =
+        std::bind(fn, _this, std::placeholders::_1,
+                  std::forward<ExtraArgs>(extra_args)...);
+    m_handlers.m_request_handlers[method] =
+        [method, handler = std::move(handler)](
+            const Request &req, Callback<void(const Response &)> reply) {
+          Params params;
+          llvm::json::Path::Root root(method);
+          if (!fromJSON(req.params, params, root)) {
+            reply(MakeInvalidParamsResponse(method, root));
+            return;
+          }
+          llvm::Expected<Result> result = handler(params);
+          if (!result) {
+            reply(Response{0, MakeProtocolError(result.takeError())});
+            return;
+          }
+          reply(Response{0, *result});
+        };
+  }
+
+  /// Bind a std::function handler for an async request.
+  /// Handler should be e.g. `void peek(const PeekParams&, Reply<PeekResult>)`;
+  /// the response is sent when the handler invokes its Reply.
+  /// PeekParams must be JSON parsable and PeekResult must be serializable.
+  template <typename Result, typename Params, typename... ExtraArgs>
+  void asyncRequest(
+      llvm::StringLiteral method,
+      std::function<void(const Params &, ExtraArgs..., Reply<Result>)> fn,
+      ExtraArgs... extra_args) {
+    std::function<void(const Params &, Reply<Result>)> handler = std::bind(
+        fn, std::placeholders::_1, std::forward<ExtraArgs>(extra_args)...,
+        std::placeholders::_2);
+    BindAsyncRequest<Result, Params>(method, std::move(handler));
+  }
+
+  /// Bind a member function handler for an async request.
+  /// e.g. Bind.asyncRequest("peek", &ThisModule::peek, this);
+  /// Handler should be e.g. `void peek(const PeekParams&, Reply<PeekResult>);`
+  /// PeekParams must be JSON parsable and PeekResult must be serializable.
+  template <typename Result, typename Params, typename ThisT,
+            typename... ExtraArgs>
+  void asyncRequest(llvm::StringLiteral method,
+                    void (ThisT::*fn)(const Params &, ExtraArgs...,
+                                      Reply<Result>),
+                    ThisT *_this, ExtraArgs... extra_args) {
+    std::function<void(const Params &, Reply<Result>)> handler = std::bind(
+        fn, _this, std::placeholders::_1,
+        std::forward<ExtraArgs>(extra_args)..., std::placeholders::_2);
+    BindAsyncRequest<Result, Params>(method, std::move(handler));
+  }
+
+  /// Bind a member function handler for a notification.
+  /// e.g. Bind.notification("peek", &ThisModule::peek, this);
+  /// Handler should be e.g. void peek(const PeekParams&);
+  /// PeekParams must be JSON parsable. Notifications whose params fail to
+  /// decode are dropped without calling the handler.
+  template <typename Params, typename ThisT, typename... ExtraArgs>
+  void notification(llvm::StringLiteral method,
+                    void (ThisT::*fn)(const Params &, ExtraArgs...),
+                    ThisT *_this, ExtraArgs... extra_args) {
+    std::function<void(const Params &)> handler =
+        std::bind(fn, _this, std::placeholders::_1,
+                  std::forward<ExtraArgs>(extra_args)...);
+    // Delegate so both overloads share the "already bound" check.
+    notification<Params>(method, std::move(handler));
+  }
+
+  /// Bind a std::function handler for a notification.
+  /// Notifications whose params fail to decode are dropped without calling
+  /// the handler.
+  template <typename Params>
+  void notification(llvm::StringLiteral method,
+                    std::function<void(const Params &)> handler) {
+    assert(m_handlers.m_notification_handlers.find(method) ==
+               m_handlers.m_notification_handlers.end() &&
+           "notification already bound");
+    m_handlers.m_notification_handlers[method] =
+        [handler = std::move(handler)](const Notification &note) {
+          Params params;
+          llvm::json::Path::Root root;
+          if (!fromJSON(note.params, params, root))
+            return; // FIXME: log error?
+
+          handler(params);
+        };
+  }
+
+  /// Bind a function object to be used for outgoing requests.
+  /// e.g. OutgoingRequest<EditParams, EditResult> Edit =
+  ///          Bind.outgoingRequest<EditParams, EditResult>("edit");
+  /// EditParams must be JSON-serializable, EditResult must be parsable.
+  /// The returned function captures this Binder and must not outlive it.
+  template <typename Params, typename Result>
+  OutgoingRequest<Params, Result> outgoingRequest(llvm::StringLiteral method) {
+    return [this, method](const Params &params, Reply<Result> reply) {
+      Request request;
+      request.method = method;
+      request.params = toJSON(params);
+      m_handlers.Send(request, [reply = std::move(reply)](
+                                   const Response &resp) mutable {
+        if (const lldb_protocol::mcp::Error *err =
+                std::get_if<lldb_protocol::mcp::Error>(&resp.result)) {
+          reply(llvm::make_error<MCPError>(err->message, err->code));
+          return;
+        }
+        Result result;
+        llvm::json::Path::Root root;
+        if (!fromJSON(std::get<llvm::json::Value>(resp.result), result, root)) {
+          reply(llvm::make_error<MCPError>("parsing response failed: " +
+                                           llvm::toString(root.getError())));
+          return;
+        }
+        reply(std::move(result));
+      });
+    };
+  }
+
+  /// Bind a function object to be used for outgoing notifications.
+  /// e.g. OutgoingNotification<LogParams> Log =
+  ///          Bind.outgoingNotification<LogParams>("log");
+  /// LogParams must be JSON-serializable.
+  /// The returned function captures this Binder and must not outlive it.
+  template <typename Params>
+  OutgoingNotification<Params>
+  outgoingNotification(llvm::StringLiteral method) {
+    return [this, method](const Params &params) {
+      Notification note;
+      note.method = method;
+      note.params = toJSON(params);
+      m_handlers.Send(note);
+    };
+  }
+
+  /// Returns the MessageHandler that owns the handler tables populated
+  /// through this Binder.
+  operator MCPTransport::MessageHandler &() { return m_handlers; }
+
+private:
+  /// Registers `handler` for `method`: decodes the params, invokes the
+  /// handler, and turns the value passed to its Reply into a Response.
+  template <typename Result, typename Params>
+  void BindAsyncRequest(
+      llvm::StringLiteral method,
+      std::function<void(const Params &, Reply<Result>)> handler) {
+    assert(m_handlers.m_request_handlers.find(method) ==
+               m_handlers.m_request_handlers.end() &&
+           "request already bound");
+    m_handlers.m_request_handlers[method] =
+        [method, handler = std::move(handler)](
+            const Request &req, Callback<void(const Response &)> reply) {
+          Params params;
+          llvm::json::Path::Root root(method);
+          if (!fromJSON(req.params, params, root)) {
+            reply(MakeInvalidParamsResponse(method, root));
+            return;
+          }
+
+          handler(params, [reply = std::move(reply)](
+                              llvm::Expected<Result> result) mutable {
+            if (!result) {
+              reply(Response{0, MakeProtocolError(result.takeError())});
+              return;
+            }
+            reply(Response{0, toJSON(*result)});
+          });
+        };
+  }
+
+  /// Builds the error response sent when a message's params fail to decode.
+  static Response
+  MakeInvalidParamsResponse(llvm::StringRef method,
+                            const llvm::json::Path::Root &root) {
+    return Response{0, Error{eErrorCodeInvalidParams,
+                             "invalid params for " + method.str() + ": " +
+                                 llvm::toString(root.getError()),
+                             std::nullopt}};
+  }
+
+  /// Converts a handler failure into a protocol error. An MCPError keeps its
+  /// own code; any other error is reported as MCPError::kInternalError.
+  static Error MakeProtocolError(llvm::Error error) {
+    Error protocol_error;
+    llvm::handleAllErrors(
+        std::move(error),
+        [&](const MCPError &err) { protocol_error = err.toProtocolError(); },
+        [&](const llvm::ErrorInfoBase &err) {
+          protocol_error.code = MCPError::kInternalError;
+          protocol_error.message = err.message();
+        });
+    return protocol_error;
+  }
+
+  /// The MessageHandler exposed by this Binder. Its dispatch and send logic
+  /// is defined out of line (presumably in Binder.cpp).
+  class RawHandler final : public MCPTransport::MessageHandler {
+  public:
+    explicit RawHandler(MCPTransport *transport);
+
+    void Received(const Notification &note) override;
+    void Received(const Request &req) override;
+    void Received(const Response &resp) override;
+    void OnError(llvm::Error err) override;
+    void OnClosed() override;
+
+    /// Sends `req`; `response_handler` is presumably invoked once the matching
+    /// response arrives.
+    void Send(const Request &req,
+              Callback<void(const Response &)> response_handler);
+    void Send(const Notification &note);
+    void Send(const Response &resp);
+
+    friend class Binder;
+
+  private:
+    std::recursive_mutex m_mutex;
+    MCPTransport *m_transport;
+    // NOTE(review): presumably the source of ids for outgoing requests.
+    int m_seq = 0;
+    // Response callbacks of outgoing requests, presumably keyed by request id.
+    std::map<Id, Callback<void(const Response &)>> m_pending_responses;
+    llvm::StringMap<
+        Callback<void(const Request &, Callback<void(const Response &)>)>>
+        m_request_handlers;
+    llvm::StringMap<Callback<void(const Notification &)>>
+        m_notification_handlers;
+    Callback<void(MCPTransport *)> m_disconnect_handler;
+    Callback<void(MCPTransport *, llvm::Error)> m_error_handler;
+  };
+
+  RawHandler m_handlers;
+};
+using BinderUP = std::unique_ptr<Binder>;
+
+} // namespace lldb_protocol::mcp
+
+#endif
diff --git a/lldb/include/lldb/Protocol/MCP/Protocol.h b/lldb/include/lldb/Protocol/MCP/Protocol.h
index 49f9490221755..d21a5ef85ece6 100644
--- a/lldb/include/lldb/Protocol/MCP/Protocol.h
+++ b/lldb/include/lldb/Protocol/MCP/Protocol.h
@@ -14,10 +14,12 @@
#ifndef LLDB_PROTOCOL_MCP_PROTOCOL_H
#define LLDB_PROTOCOL_MCP_PROTOCOL_H
+#include "lldb/lldb-types.h"
#include "llvm/Support/JSON.h"
#include <optional>
#include <string>
#include <variant>
+#include <vector>
namespace lldb_protocol::mcp {
@@ -43,6 +45,12 @@ llvm::json::Value toJSON(const Request &);
bool fromJSON(const llvm::json::Value &, Request &, llvm::json::Path);
bool operator==(const Request &, const Request &);
+enum ErrorCode : signed {
+ eErrorCodeMethodNotFound = -32601,
+ eErrorCodeInvalidParams = -32602,
+ eErrorCodeInternalError = -32000,
+};
+
struct Error {
/// The error type that occurred.
int64_t code = 0;
@@ -147,6 +155,14 @@ struct Resource {
llvm::json::Value toJSON(const Resource &);
bool fromJSON(const llvm::json::Value &, Resource &, llvm::json::Path);
+/// The server's response to a `resources/list` request from the client.
+struct ResourcesListResult {
+  /// The resources advertised by the server.
+  std::vector<Resource> resources;
+};
+llvm::json::Value toJSON(const ResourcesListResult &);
+bool fromJSON(const llvm::json::Value &, ResourcesListResult &,
+              llvm::json::Path);
+
/// The contents of a specific resource or sub-resource.
struct ResourceContents {
/// The URI of this resource.
@@ -163,13 +179,23 @@ struct ResourceContents {
llvm::json::Value toJSON(const ResourceContents &);
bool fromJSON(const llvm::json::Value &, ResourceContents &, llvm::json::Path);
+/// Parameters of a `resources/read` request, sent from the client to the
+/// server to read a specific resource URI.
+struct ResourcesReadParams {
+  /// The URI of the resource to read. The URI can use any protocol; it is up to
+  /// the server how to interpret it.
+  std::string URI;
+};
+llvm::json::Value toJSON(const ResourcesReadParams &);
+bool fromJSON(const llvm::json::Value &, ResourcesReadParams &,
+              llvm::json::Path);
+
/// The server's response to a resources/read request from the client.
-struct ResourceResult {
+struct ResourcesReadResult {
std::vector<ResourceContents> contents;
};
-
-llvm::json::Value toJSON(const ResourceResult &);
-bool fromJSON(const llvm::json::Value &, ResourceResult &, llvm::json::Path);
+llvm::json::Value toJSON(const ResourcesReadResult &);
+bool fromJSON(const llvm::json::Value &, ResourcesReadResult &,
+ llvm::json::Path);
/// Text provided to or from an LLM.
struct TextContent {
@@ -204,6 +230,145 @@ bool fromJSON(const llvm::json::Value &, ToolDefinition &, llvm::json::Path);
using ToolArguments = std::variant<std::monostate, llvm::json::Value>;
+/// Describes the name and version of an MCP implementation, with an optional
+/// title for UI representation.
+///
+/// See
+/// https://modelcontextprotocol.io/specification/2025-06-18/schema#implementation
+struct Implementation {
+  /// Intended for programmatic or logical use, but used as a display name in
+  /// past specs or fallback (if title isn't present).
+  std::string name;
+
+  /// Intended for UI and end-user contexts — optimized to be human-readable and
+  /// easily understood, even by those unfamiliar with domain-specific
+  /// terminology.
+  ///
+  /// If not provided, the name should be used for display (except for Tool,
+  /// where annotations.title should be given precedence over using name, if
+  /// present). NOTE(review): an empty string presumably means "not provided";
+  /// confirm how toJSON/fromJSON treat it.
+  std::string title;
+
+  /// The version of this implementation.
+  std::string version;
+};
+llvm::json::Value toJSON(const Implementation &);
+bool fromJSON(const llvm::json::Value &, Implementation &, llvm::json::Path);
+
+/// Capabilities a client may support. The MCP schema defines the known
+/// capabilities, but this is not a closed set: any client can define its own,
+/// additional capabilities.
+///
+/// No client capabilities are modeled yet, so this struct is empty.
+struct ClientCapabilities {};
+llvm::json::Value toJSON(const ClientCapabilities &);
+bool fromJSON(const llvm::json::Value &, ClientCapabilities &,
+              llvm::json::Path);
+
+/// Capabilities that a server may support. The MCP schema defines the known
+/// capabilities, but this is not a closed set: any server can define its own,
+/// additional capabilities.
+struct ServerCapabilities {
+  /// Server features. NOTE(review): presumably each flag controls whether the
+  /// matching MCP capability is advertised by toJSON; confirm in Protocol.cpp.
+  bool supportsToolsList = false;
+  bool supportsResourcesList = false;
+  bool supportsResourcesSubscribe = false;
+
+  /// Utility features.
+  bool supportsCompletions = false;
+  bool supportsLogging = false;
+};
+llvm::json::Value toJSON(const ServerCapabilities &);
+bool fromJSON(const llvm::json::Value &, ServerCapabilities &,
+              llvm::json::Path);
+
+// Initialization
+
+/// Parameters of the `initialize` request, which the client sends to the
+/// server when it first connects, asking it to begin initialization.
+///
+/// @category initialize
+struct InitializeParams {
+  /// The latest version of the Model Context Protocol that the client supports.
+  /// The client MAY decide to support older versions as well.
+  std::string protocolVersion;
+
+  /// Capabilities supported by the client.
+  ClientCapabilities capabilities;
+
+  /// Name and version of the client.
+  Implementation clientInfo;
+};
+llvm::json::Value toJSON(const InitializeParams &);
+bool fromJSON(const llvm::json::Value &, InitializeParams &, llvm::json::Path);
+
+/// After receiving an initialize request from the client, the server sends this
+/// response.
+///
+/// @category initialize
+struct InitializeResult {
+  /// The version of the Model Context Protocol that the server wants to use.
+  /// This may not match the version that the client requested. If the client
+  /// cannot support this version, it MUST disconnect.
+  std::string protocolVersion;
+
+  /// Capabilities offered by the server.
+  ServerCapabilities capabilities;
+  /// Name and version of the server.
+  Implementation serverInfo;
+
+  /// Instructions describing how to use the server and its features.
+  ///
+  /// This can be used by clients to improve the LLM's understanding of
+  /// available tools, resources, etc. It can be thought of like a "hint" to the
+  /// model. For example, this information MAY be added to the system prompt.
+  std::string instructions;
+};
+llvm::json::Value toJSON(const InitializeResult &);
+bool fromJSON(const llvm::json::Value &, InitializeResult &, llvm::json::Path);
+
+/// Placeholder parameter type for requests that carry no parameters, e.g.
+/// `tools/list` and `resources/list`.
+using Void = std::monostate;
+llvm::json::Value toJSON(const Void &);
+bool fromJSON(const llvm::json::Value &, Void &, llvm::json::Path);
+
+/// The server's response to a `tools/list` request from the client.
+struct ToolsListResult {
+  /// Definitions of the tools offered by the server.
+  std::vector<ToolDefinition> tools;
+};
+llvm::json::Value toJSON(const ToolsListResult &);
+bool fromJSON(const llvm::json::Value &, ToolsListResult &, llvm::json::Path);
+
+// FIXME: Support other content types as needed.
+using ContentBlock = TextContent;
+
+/// Parameters of a `tools/call` request, used by the client to invoke a tool
+/// provided by the server.
+struct ToolsCallParams {
+  /// Name of the tool to invoke.
+  std::string name;
+  /// Arguments for the tool, if any.
+  std::optional<llvm::json::Value> arguments;
+};
+llvm::json::Value toJSON(const ToolsCallParams &);
+bool fromJSON(const llvm::json::Value &, ToolsCallParams &, llvm::json::Path);
+
+/// The server's response to a tool call.
+struct ToolsCallResult {
+  /// A list of content objects that represent the unstructured result of the
+  /// tool call.
+  std::vector<ContentBlock> content;
+
+  /// Whether the tool call ended in an error.
+  ///
+  /// If not set, this is assumed to be false (the call was successful).
+  ///
+  /// Any errors that originate from the tool SHOULD be reported inside the
+  /// result object, with `isError` set to true, not as an MCP protocol-level
+  /// error response. Otherwise, the LLM would not be able to see that an error
+  /// occurred and self-correct.
+  ///
+  /// However, any errors in finding the tool, an error indicating that the
+  /// server does not support tool calls, or any other exceptional conditions,
+  /// should be reported as an MCP error response.
+  bool isError = false;
+
+  /// An optional JSON object that represents the structured result of the tool
+  /// call.
+  std::optional<llvm::json::Value> structuredContent;
+};
+llvm::json::Value toJSON(const ToolsCallResult &);
+bool fromJSON(const llvm::json::Value &, ToolsCallResult &, llvm::json::Path);
+
} // namespace lldb_protocol::mcp
#endif
diff --git a/lldb/include/lldb/Protocol/MCP/Resource.h b/lldb/include/lldb/Protocol/MCP/Resource.h
index 4835d340cd4c6..8a3e3ca725eb5 100644
--- a/lldb/include/lldb/Protocol/MCP/Resource.h
+++ b/lldb/include/lldb/Protocol/MCP/Resource.h
@@ -20,7 +20,7 @@ class ResourceProvider {
virtual ~ResourceProvider() = default;
virtual std::vector<lldb_protocol::mcp::Resource> GetResources() const = 0;
- virtual llvm::Expected<lldb_protocol::mcp::ResourceResult>
+ virtual llvm::Expected<lldb_protocol::mcp::ResourcesReadResult>
ReadResource(llvm::StringRef uri) const = 0;
};
diff --git a/lldb/include/lldb/Protocol/MCP/Server.h b/lldb/include/lldb/Protocol/MCP/Server.h
index 382f9a4731dd4..d749f8d493153 100644
--- a/lldb/include/lldb/Protocol/MCP/Server.h
+++ b/lldb/include/lldb/Protocol/MCP/Server.h
@@ -9,82 +9,52 @@
#ifndef LLDB_PROTOCOL_MCP_SERVER_H
#define LLDB_PROTOCOL_MCP_SERVER_H
-#include "lldb/Host/JSONTransport.h"
#include "lldb/Host/MainLoop.h"
+#include "lldb/Protocol/MCP/Binder.h"
#include "lldb/Protocol/MCP/Protocol.h"
#include "lldb/Protocol/MCP/Resource.h"
#include "lldb/Protocol/MCP/Tool.h"
+#include "lldb/Protocol/MCP/Transport.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Error.h"
+#include <memory>
#include <mutex>
-namespace lldb_protocol::mcp {
-
-class MCPTransport final
- : public lldb_private::JSONRPCTransport<Request, Response, Notification> {
-public:
- using LogCallback = std::function<void(llvm::StringRef message)>;
-
- MCPTransport(lldb::IOObjectSP in, lldb::IOObjectSP out,
- std::string client_name, LogCallback log_callback = {})
- : JSONRPCTransport(in, out), m_client_name(std::move(client_name)),
- m_log_callback(log_callback) {}
- virtual ~MCPTransport() = default;
+namespace lldb_private::mcp {
+class ProtocolServerMCP;
+} // namespace lldb_private::mcp
- void Log(llvm::StringRef message) override {
- if (m_log_callback)
- m_log_callback(llvm::formatv("{0}: {1}", m_client_name, message).str());
- }
+namespace lldb_protocol::mcp {
-private:
- std::string m_client_name;
- LogCallback m_log_callback;
-};
+class Server {
+  // The MCP plugin extends the server through its internals (e.g. m_binder).
+  friend class lldb_private::mcp::ProtocolServerMCP;
-class Server : public MCPTransport::MessageHandler {
public:
Server(std::string name, std::string version,
std::unique_ptr<MCPTransport> transport_up,
lldb_private::MainLoop &loop);
~Server() = default;
- using NotificationHandler = std::function<void(const Notification &)>;
-
void AddTool(std::unique_ptr<Tool> tool);
void AddResourceProvider(std::unique_ptr<ResourceProvider> resource_provider);
- void AddNotificationHandler(llvm::StringRef method,
- NotificationHandler handler);
llvm::Error Run();
-protected:
- Capabilities GetCapabilities();
-
- using RequestHandler =
- std::function<llvm::Expected<Response>(const Request &)>;
-
- void AddRequestHandlers();
-
- void AddRequestHandler(llvm::StringRef method, RequestHandler handler);
+  /// Returns the binder used to register request and notification handlers.
+  Binder &GetBinder() { return m_binder; }
- llvm::Expected<std::optional<Message>> HandleData(llvm::StringRef data);
-
- llvm::Expected<Response> Handle(Request request);
- void Handle(Notification notification);
-
- llvm::Expected<Response> InitializeHandler(const Request &);
+protected:
+ ServerCapabilities GetCapabilities();
- llvm::Expected<Response> ToolsListHandler(const Request &);
- llvm::Expected<Response> ToolsCallHandler(const Request &);
+ llvm::Expected<InitializeResult>
+ InitializeHandler(const InitializeParams &request);
- llvm::Expected<Response> ResourcesListHandler(const Request &);
- llvm::Expected<Response> ResourcesReadHandler(const Request &);
+ llvm::Expected<ToolsListResult> ToolsListHandler(const Void &);
+ llvm::Expected<ToolsCallResult> ToolsCallHandler(const ToolsCallParams &);
- void Received(const Request &) override;
- void Received(const Response &) override;
- void Received(const Notification &) override;
- void OnError(llvm::Error) override;
- void OnClosed() override;
+ llvm::Expected<ResourcesListResult> ResourcesListHandler(const Void &);
+ llvm::Expected<ResourcesReadResult>
+ ResourcesReadHandler(const ResourcesReadParams &);
void TerminateLoop();
@@ -99,9 +69,7 @@ class Server : public MCPTransport::MessageHandler {
llvm::StringMap<std::unique_ptr<Tool>> m_tools;
std::vector<std::unique_ptr<ResourceProvider>> m_resource_providers;
-
- llvm::StringMap<RequestHandler> m_request_handlers;
- llvm::StringMap<NotificationHandler> m_notification_handlers;
+ Binder m_binder;
};
} // namespace lldb_protocol::mcp
diff --git a/lldb/include/lldb/Protocol/MCP/Tool.h b/lldb/include/lldb/Protocol/MCP/Tool.h
index 96669d1357166..26cbc943f0704 100644
--- a/lldb/include/lldb/Protocol/MCP/Tool.h
+++ b/lldb/include/lldb/Protocol/MCP/Tool.h
@@ -10,6 +10,8 @@
#define LLDB_PROTOCOL_MCP_TOOL_H
#include "lldb/Protocol/MCP/Protocol.h"
+#include "llvm/ADT/FunctionExtras.h"
+#include "llvm/Support/Error.h"
#include "llvm/Support/JSON.h"
#include <string>
@@ -20,8 +22,11 @@ class Tool {
Tool(std::string name, std::string description);
virtual ~Tool() = default;
- virtual llvm::Expected<lldb_protocol::mcp::TextResult>
- Call(const lldb_protocol::mcp::ToolArguments &args) = 0;
+  /// Callback that receives the outcome of a tool call: the result, or the
+  /// error that prevented one.
+  using Reply = llvm::unique_function<void(
+      llvm::Expected<lldb_protocol::mcp::ToolsCallResult>)>;
+
+  /// Runs the tool with \p args and delivers the outcome through \p reply.
+  virtual void Call(const lldb_protocol::mcp::ToolArguments &args,
+                    Reply reply) = 0;
virtual std::optional<llvm::json::Value> GetSchema() const {
return llvm::json::Object{{"type", "object"}};
diff --git a/lldb/include/lldb/Protocol/MCP/Transport.h b/lldb/include/lldb/Protocol/MCP/Transport.h
new file mode 100644
index 0000000000000..efbddc6d31d17
--- /dev/null
+++ b/lldb/include/lldb/Protocol/MCP/Transport.h
@@ -0,0 +1,50 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLDB_PROTOCOL_MCP_TRANSPORT_H
+#define LLDB_PROTOCOL_MCP_TRANSPORT_H
+
+#include "lldb/Host/JSONTransport.h"
+#include "lldb/Host/Socket.h"
+#include "lldb/Protocol/MCP/Protocol.h"
+#include "lldb/lldb-forward.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/raw_ostream.h"
+#include <memory>
+
+namespace lldb_protocol::mcp {
+
+/// Transport type used to exchange MCP requests, responses and notifications.
+using MCPTransport = lldb_private::Transport<Request, Response, Notification>;
+using MCPTransportUP = std::unique_ptr<MCPTransport>;
+
+/// NOTE(review): presumably the path of the socket used to reach an MCP
+/// server; confirm in Transport.cpp.
+llvm::StringRef CommunicationSocketPath();
+/// NOTE(review): presumably opens a connection to CommunicationSocketPath();
+/// confirm in Transport.cpp.
+llvm::Expected<lldb::IOObjectSP> Connect();
+
+/// JSON-RPC transport for MCP Request, Response and Notification messages,
+/// built on an input and an output IO object.
+class Transport final
+    : public lldb_private::JSONRPCTransport<Request, Response, Notification> {
+public:
+  /// Callback that receives the transport's log messages.
+  using LogCallback = std::function<void(llvm::StringRef message)>;
+
+  /// \param client_name Identifies the client; presumably used to prefix log
+  ///        messages, as the former MCPTransport did. Confirm in Transport.cpp.
+  /// \param log_callback Receives log messages; defaults to an empty callback.
+  Transport(lldb::IOObjectSP input, lldb::IOObjectSP output,
+            std::string client_name = "", LogCallback log_callback = {});
+
+  /// Presumably forwards \p message to the log callback, if one was set.
+  void Log(llvm::StringRef message) override;
+
+  /// NOTE(review): presumably connects to the MCP server socket (see
+  /// CommunicationSocketPath()); \p logger, if non-null, receives log output.
+  /// Confirm in Transport.cpp.
+  static llvm::Expected<MCPTransportUP>
+  Connect(llvm::raw_ostream *logger = nullptr);
+
+private:
+  std::string m_client_name;
+  LogCallback m_log_callback;
+};
+using TransportUP = std::unique_ptr<Transport>;
+
+} // namespace lldb_protocol::mcp
+
+#endif
diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp
index 57132534cf680..15558b4e7c914 100644
--- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp
+++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp
@@ -10,14 +10,10 @@
#include "Resource.h"
#include "Tool.h"
#include "lldb/Core/PluginManager.h"
-#include "lldb/Protocol/MCP/MCPError.h"
-#include "lldb/Protocol/MCP/Tool.h"
#include "lldb/Utility/LLDBLog.h"
#include "lldb/Utility/Log.h"
-#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Threading.h"
#include <thread>
-#include <variant>
using namespace lldb_private;
using namespace lldb_private::mcp;
@@ -50,12 +46,14 @@ llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() {
return "MCP Server.";
}
-void ProtocolServerMCP::Extend(lldb_protocol::mcp::Server &server) const {
- server.AddNotificationHandler("notifications/initialized",
- [](const lldb_protocol::mcp::Notification &) {
- LLDB_LOG(GetLog(LLDBLog::Host),
- "MCP initialization complete");
- });
+/// Handles the client's "notifications/initialized" notification (bound in
+/// Extend) by logging that initialization completed.
+///
+/// NOTE(review): Binder::notification decodes the notification's `params`
+/// into this parameter type before calling the handler, and silently drops
+/// the notification if decoding fails. Confirm that these params decode as a
+/// full Notification; otherwise `lldb_protocol::mcp::Void` may be the intended
+/// parameter type.
+void ProtocolServerMCP::OnInitialized(
+    const lldb_protocol::mcp::Notification &) {
+  LLDB_LOG(GetLog(LLDBLog::Host), "MCP initialization complete");
+}
+
+void ProtocolServerMCP::Extend(lldb_protocol::mcp::Server &server) {
+ server.m_binder.notification("notifications/initialized",
+ &ProtocolServerMCP::OnInitialized, this);
server.AddTool(
std::make_unique<CommandTool>("lldb_command", "Run an lldb command."));
server.AddResourceProvider(std::make_unique<DebuggerResourceProvider>());
@@ -67,7 +65,7 @@ void ProtocolServerMCP::AcceptCallback(std::unique_ptr<Socket> socket) {
LLDB_LOG(log, "New MCP client connected: {0}", client_name);
lldb::IOObjectSP io_sp = std::move(socket);
- auto transport_up = std::make_unique<lldb_protocol::mcp::MCPTransport>(
+ auto transport_up = std::make_unique<lldb_protocol::mcp::Transport>(
io_sp, io_sp, std::move(client_name), [&](llvm::StringRef message) {
LLDB_LOG(GetLog(LLDBLog::Host), "{0}", message);
});
diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h
index fc650ffe0dfa7..d35f7f678b2c4 100644
--- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h
+++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h
@@ -41,7 +41,9 @@ class ProtocolServerMCP : public ProtocolServer {
protected:
// This adds tools and resource providers that
// are specific to this server. Overridable by the unit tests.
- virtual void Extend(lldb_protocol::mcp::Server &server) const;
+ virtual void Extend(lldb_protocol::mcp::Server &server);
+
+ void OnInitialized(const lldb_protocol::mcp::Notification &);
private:
void AcceptCallback(std::unique_ptr<Socket> socket);
diff --git a/lldb/source/Plugins/Protocol/MCP/Resource.cpp b/lldb/source/Plugins/Protocol/MCP/Resource.cpp
index e94d2cdd65e07..b5f0a6569654b 100644
--- a/lldb/source/Plugins/Protocol/MCP/Resource.cpp
+++ b/lldb/source/Plugins/Protocol/MCP/Resource.cpp
@@ -124,7 +124,7 @@ DebuggerResourceProvider::GetResources() const {
return resources;
}
-llvm::Expected<lldb_protocol::mcp::ResourceResult>
+llvm::Expected<lldb_protocol::mcp::ResourcesReadResult>
DebuggerResourceProvider::ReadResource(llvm::StringRef uri) const {
auto [protocol, path] = uri.split("://");
@@ -161,7 +161,7 @@ DebuggerResourceProvider::ReadResource(llvm::StringRef uri) const {
return ReadDebuggerResource(uri, debugger_idx);
}
-llvm::Expected<lldb_protocol::mcp::ResourceResult>
+llvm::Expected<lldb_protocol::mcp::ResourcesReadResult>
DebuggerResourceProvider::ReadDebuggerResource(llvm::StringRef uri,
lldb::user_id_t debugger_id) {
lldb::DebuggerSP debugger_sp = Debugger::FindDebuggerWithID(debugger_id);
@@ -178,12 +178,12 @@ DebuggerResourceProvider::ReadDebuggerResource(llvm::StringRef uri,
contents.mimeType = kMimeTypeJSON;
contents.text = llvm::formatv("{0}", toJSON(debugger_resource));
- lldb_protocol::mcp::ResourceResult result;
+ lldb_protocol::mcp::ResourcesReadResult result;
result.contents.push_back(contents);
return result;
}
-llvm::Expected<lldb_protocol::mcp::ResourceResult>
+llvm::Expected<lldb_protocol::mcp::ResourcesReadResult>
DebuggerResourceProvider::ReadTargetResource(llvm::StringRef uri,
lldb::user_id_t debugger_id,
size_t target_idx) {
@@ -214,7 +214,7 @@ DebuggerResourceProvider::ReadTargetResource(llvm::StringRef uri,
contents.mimeType = kMimeTypeJSON;
contents.text = llvm::formatv("{0}", toJSON(target_resource));
- lldb_protocol::mcp::ResourceResult result;
+ lldb_protocol::mcp::ResourcesReadResult result;
result.contents.push_back(contents);
return result;
}
diff --git a/lldb/source/Plugins/Protocol/MCP/Resource.h b/lldb/source/Plugins/Protocol/MCP/Resource.h
index e2382a74f796b..0810f1fb0c4f4 100644
--- a/lldb/source/Plugins/Protocol/MCP/Resource.h
+++ b/lldb/source/Plugins/Protocol/MCP/Resource.h
@@ -23,7 +23,7 @@ class DebuggerResourceProvider : public lldb_protocol::mcp::ResourceProvider {
virtual std::vector<lldb_protocol::mcp::Resource>
GetResources() const override;
- virtual llvm::Expected<lldb_protocol::mcp::ResourceResult>
+ virtual llvm::Expected<lldb_protocol::mcp::ResourcesReadResult>
ReadResource(llvm::StringRef uri) const override;
private:
@@ -31,9 +31,9 @@ class DebuggerResourceProvider : public lldb_protocol::mcp::ResourceProvider {
static lldb_protocol::mcp::Resource GetTargetResource(size_t target_idx,
Target &target);
- static llvm::Expected<lldb_protocol::mcp::ResourceResult>
+ static llvm::Expected<lldb_protocol::mcp::ResourcesReadResult>
ReadDebuggerResource(llvm::StringRef uri, lldb::user_id_t debugger_id);
- static llvm::Expected<lldb_protocol::mcp::ResourceResult>
+ static llvm::Expected<lldb_protocol::mcp::ResourcesReadResult>
ReadTargetResource(llvm::StringRef uri, lldb::user_id_t debugger_id,
size_t target_idx);
};
diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.cpp b/lldb/source/Plugins/Protocol/MCP/Tool.cpp
index 143470702a6fd..dabf100874b62 100644
--- a/lldb/source/Plugins/Protocol/MCP/Tool.cpp
+++ b/lldb/source/Plugins/Protocol/MCP/Tool.cpp
@@ -14,6 +14,7 @@
using namespace lldb_private;
using namespace lldb_protocol;
using namespace lldb_private::mcp;
+using namespace lldb_protocol::mcp;
using namespace llvm;
namespace {
@@ -29,10 +30,10 @@ bool fromJSON(const llvm::json::Value &V, CommandToolArguments &A,
O.mapOptional("arguments", A.arguments);
}
-/// Helper function to create a TextResult from a string output.
-static lldb_protocol::mcp::TextResult createTextResult(std::string output,
- bool is_error = false) {
- lldb_protocol::mcp::TextResult text_result;
+/// Helper function to create a ToolsCallResult from a string output.
+static lldb_protocol::mcp::ToolsCallResult
+createTextResult(std::string output, bool is_error = false) {
+ lldb_protocol::mcp::ToolsCallResult text_result;
text_result.content.emplace_back(
lldb_protocol::mcp::TextContent{{std::move(output)}});
text_result.isError = is_error;
@@ -41,22 +42,23 @@ static lldb_protocol::mcp::TextResult createTextResult(std::string output,
} // namespace
-llvm::Expected<lldb_protocol::mcp::TextResult>
-CommandTool::Call(const lldb_protocol::mcp::ToolArguments &args) {
+namespace lldb_private::mcp {
+
+void CommandTool::Call(const ToolArguments &args, Reply reply) {
if (!std::holds_alternative<json::Value>(args))
- return createStringError("CommandTool requires arguments");
+ return reply(createStringError("CommandTool requires arguments"));
json::Path::Root root;
CommandToolArguments arguments;
if (!fromJSON(std::get<json::Value>(args), arguments, root))
- return root.getError();
+ return reply(root.getError());
lldb::DebuggerSP debugger_sp =
Debugger::FindDebuggerWithID(arguments.debugger_id);
if (!debugger_sp)
- return createStringError(
- llvm::formatv("no debugger with id {0}", arguments.debugger_id));
+ return reply(createStringError(
+ llvm::formatv("no debugger with id {0}", arguments.debugger_id)));
// FIXME: Disallow certain commands and their aliases.
CommandReturnObject result(/*colors=*/false);
@@ -75,7 +77,7 @@ CommandTool::Call(const lldb_protocol::mcp::ToolArguments &args) {
output += err_str;
}
- return createTextResult(output, !result.Succeeded());
+ reply(createTextResult(output, !result.Succeeded()));
}
std::optional<llvm::json::Value> CommandTool::GetSchema() const {
@@ -89,3 +91,5 @@ std::optional<llvm::json::Value> CommandTool::GetSchema() const {
{"required", std::move(required)}};
return schema;
}
+
+} // namespace lldb_private::mcp
diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.h b/lldb/source/Plugins/Protocol/MCP/Tool.h
index b7b1756eb38d7..4fc5884174e01 100644
--- a/lldb/source/Plugins/Protocol/MCP/Tool.h
+++ b/lldb/source/Plugins/Protocol/MCP/Tool.h
@@ -22,10 +22,11 @@ class CommandTool : public lldb_protocol::mcp::Tool {
using lldb_protocol::mcp::Tool::Tool;
~CommandTool() = default;
- virtual llvm::Expected<lldb_protocol::mcp::TextResult>
- Call(const lldb_protocol::mcp::ToolArguments &args) override;
+ void Call(const lldb_protocol::mcp::ToolArguments &,
+ llvm::unique_function<void(
+ llvm::Expected<lldb_protocol::mcp::ToolsCallResult>)>) override;
- virtual std::optional<llvm::json::Value> GetSchema() const override;
+ std::optional<llvm::json::Value> GetSchema() const override;
};
} // namespace lldb_private::mcp
diff --git a/lldb/source/Protocol/MCP/Binder.cpp b/lldb/source/Protocol/MCP/Binder.cpp
new file mode 100644
index 0000000000000..90ae39ba0e3f0
--- /dev/null
+++ b/lldb/source/Protocol/MCP/Binder.cpp
@@ -0,0 +1,146 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "lldb/Protocol/MCP/Binder.h"
+#include "lldb/Protocol/MCP/Protocol.h"
+#include "lldb/Protocol/MCP/Transport.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <atomic>
+#include <cassert>
+#include <mutex>
+#include <utility>
+
+using namespace llvm;
+
+namespace lldb_protocol::mcp {
+
+/// Function object to reply to a call.
+/// Each instance must be called exactly once, otherwise:
+/// - (in builds with assertions enabled) an assert fires;
+/// - if there was no reply, an internal error reply is sent on destruction;
+/// - if there were multiple replies, only the first is sent.
+class ReplyOnce {
+  std::atomic<bool> replied = {false};
+  const Id id;
+  MCPTransport *transport;               // Null when moved-from.
+  MCPTransport::MessageHandler *handler; // Null when moved-from.
+
+public:
+  ReplyOnce(const Id id, MCPTransport *transport,
+            MCPTransport::MessageHandler *handler)
+      : id(id), transport(transport), handler(handler) {
+    assert(handler);
+  }
+  ReplyOnce(ReplyOnce &&other)
+      : replied(other.replied.load()), id(other.id), transport(other.transport),
+        handler(other.handler) {
+    other.transport = nullptr;
+    other.handler = nullptr;
+  }
+  ReplyOnce &operator=(ReplyOnce &&) = delete;
+  ReplyOnce(const ReplyOnce &) = delete;
+  ReplyOnce &operator=(const ReplyOnce &) = delete;
+
+  ~ReplyOnce() {
+    if (transport && handler && !replied) {
+      assert(false && "must reply to all calls!");
+      (*this)(Response{id, Error{MCPError::kInternalError, "failed to reply",
+                                 std::nullopt}});
+    }
+  }
+
+  void operator()(const Response &resp) {
+    assert(transport && handler && "moved-from!");
+    if (replied.exchange(true)) {
+      assert(false && "must reply to each call only once!");
+      return;
+    }
+
+    // Always answer with the id of the originating request.
+    if (llvm::Error error = transport->Send(Response{id, resp.result}))
+      handler->OnError(std::move(error));
+  }
+};
+
+Binder::RawHandler::RawHandler(MCPTransport *transport)
+    : m_transport(transport) {}
+
+void Binder::RawHandler::Received(const Notification &note) {
+  std::scoped_lock<std::recursive_mutex> guard(m_mutex);
+  auto it = m_notification_handlers.find(note.method);
+  if (it == m_notification_handlers.end()) {
+    OnError(llvm::createStringError(
+        formatv("no handler for notification {0}", toJSON(note))));
+    return;
+  }
+  it->second(note);
+}
+
+void Binder::RawHandler::Received(const Request &req) {
+  // `reply` sends an error response on destruction if nobody replied.
+  ReplyOnce reply(req.id, m_transport, this);
+
+  std::scoped_lock<std::recursive_mutex> guard(m_mutex);
+  auto it = m_request_handlers.find(req.method);
+  if (it == m_request_handlers.end()) {
+    reply({req.id,
+           Error{eErrorCodeMethodNotFound, "method not found", std::nullopt}});
+    return;
+  }
+
+  it->second(req, std::move(reply));
+}
+
+void Binder::RawHandler::Received(const Response &resp) {
+  std::scoped_lock<std::recursive_mutex> guard(m_mutex);
+  auto it = m_pending_responses.find(resp.id);
+  if (it == m_pending_responses.end()) {
+    OnError(llvm::createStringError(
+        formatv("no pending request for {0}", toJSON(resp))));
+    return;
+  }
+
+  // Take the handler out of the map before invoking it: m_mutex is recursive,
+  // so the handler may re-enter Send() and insert into m_pending_responses,
+  // which may invalidate `it` unless the map guarantees iterator stability.
+  auto response_handler = std::move(it->second);
+  m_pending_responses.erase(it);
+  response_handler(resp);
+}
+
+void Binder::RawHandler::OnError(llvm::Error err) {
+  std::scoped_lock<std::recursive_mutex> guard(m_mutex);
+  if (m_error_handler)
+    m_error_handler(m_transport, std::move(err));
+}
+
+void Binder::RawHandler::OnClosed() {
+  std::scoped_lock<std::recursive_mutex> guard(m_mutex);
+  if (m_disconnect_handler)
+    m_disconnect_handler(m_transport);
+}
+
+void Binder::RawHandler::Send(
+    const Request &req,
+    llvm::unique_function<void(const Response &)> response_handler) {
+  std::lock_guard<std::recursive_mutex> guard(m_mutex);
+  Id id = ++m_seq;
+  if (llvm::Error err = m_transport->Send(Request{id, req.method, req.params}))
+    return OnError(std::move(err));
+  m_pending_responses[id] = std::move(response_handler);
+}
+
+void Binder::RawHandler::Send(const Notification &note) {
+  std::lock_guard<std::recursive_mutex> guard(m_mutex);
+  if (llvm::Error err = m_transport->Send(note))
+    return OnError(std::move(err));
+}
+
+} // namespace lldb_protocol::mcp
diff --git a/lldb/source/Protocol/MCP/CMakeLists.txt b/lldb/source/Protocol/MCP/CMakeLists.txt
index a73e7e6a7cab1..e6e8200833ffd 100644
--- a/lldb/source/Protocol/MCP/CMakeLists.txt
+++ b/lldb/source/Protocol/MCP/CMakeLists.txt
@@ -1,12 +1,15 @@
add_lldb_library(lldbProtocolMCP NO_PLUGIN_DEPENDENCIES
+ Binder.cpp
MCPError.cpp
Protocol.cpp
Server.cpp
Tool.cpp
+ Transport.cpp
LINK_COMPONENTS
Support
LINK_LIBS
lldbUtility
+ lldbHost
)
diff --git a/lldb/source/Protocol/MCP/Protocol.cpp b/lldb/source/Protocol/MCP/Protocol.cpp
index 65ddfaee70160..8a976bb797d32 100644
--- a/lldb/source/Protocol/MCP/Protocol.cpp
+++ b/lldb/source/Protocol/MCP/Protocol.cpp
@@ -228,11 +228,11 @@ bool fromJSON(const llvm::json::Value &V, ResourceContents &RC,
O.mapOptional("mimeType", RC.mimeType);
}
-llvm::json::Value toJSON(const ResourceResult &RR) {
+llvm::json::Value toJSON(const ResourcesReadResult &RR) {
return llvm::json::Object{{"contents", RR.contents}};
}
-bool fromJSON(const llvm::json::Value &V, ResourceResult &RR,
+bool fromJSON(const llvm::json::Value &V, ResourcesReadResult &RR,
llvm::json::Path P) {
llvm::json::ObjectMapper O(V, P);
return O && O.map("contents", RR.contents);
@@ -325,4 +325,173 @@ bool fromJSON(const llvm::json::Value &V, Message &M, llvm::json::Path P) {
return false;
}
+json::Value toJSON(const Implementation &I) {
+  json::Object result{{"name", I.name}, {"version", I.version}};
+
+  if (!I.title.empty())
+    result.insert({"title", I.title});
+
+  return result;
+}
+
+bool fromJSON(const json::Value &V, Implementation &I, json::Path P) {
+  json::ObjectMapper O(V, P);
+  return O && O.map("name", I.name) && O.mapOptional("title", I.title) &&
+         O.mapOptional("version", I.version);
+}
+
+json::Value toJSON(const ClientCapabilities &C) { return json::Object{}; }
+
+bool fromJSON(const json::Value &, ClientCapabilities &, json::Path) {
+  return true;
+}
+
+json::Value toJSON(const ServerCapabilities &C) {
+  json::Object result{};
+
+  if (C.supportsToolsList)
+    result.insert({"tools", json::Object{{"listChanged", true}}});
+
+  if (C.supportsResourcesList || C.supportsResourcesSubscribe) {
+    json::Object resources;
+    if (C.supportsResourcesList)
+      resources.insert({"listChanged", true});
+    if (C.supportsResourcesSubscribe)
+      resources.insert({"subscribe", true});
+    result.insert({"resources", std::move(resources)});
+  }
+
+  if (C.supportsCompletions)
+    result.insert({"completions", json::Object{}});
+
+  if (C.supportsLogging)
+    result.insert({"logging", json::Object{}});
+
+  return result;
+}
+
+bool fromJSON(const json::Value &V, ServerCapabilities &C, json::Path P) {
+  const json::Object *O = V.getAsObject();
+  if (!O) {
+    P.report("expected object");
+    return false;
+  }
+
+  // Decode every capability that toJSON can emit so the two round-trip.
+  if (O->find("tools") != O->end())
+    C.supportsToolsList = true;
+
+  if (const json::Object *resources = O->getObject("resources")) {
+    if (resources->getBoolean("listChanged").value_or(false))
+      C.supportsResourcesList = true;
+    if (resources->getBoolean("subscribe").value_or(false))
+      C.supportsResourcesSubscribe = true;
+  }
+
+  if (O->find("completions") != O->end())
+    C.supportsCompletions = true;
+
+  if (O->find("logging") != O->end())
+    C.supportsLogging = true;
+
+  return true;
+}
+
+json::Value toJSON(const InitializeParams &P) {
+ return json::Object{
+ {"protocolVersion", P.protocolVersion},
+ {"capabilities", P.capabilities},
+ {"clientInfo", P.clientInfo},
+ };
+}
+
+bool fromJSON(const json::Value &V, InitializeParams &I, json::Path P) {
+ json::ObjectMapper O(V, P);
+ return O && O.map("protocolVersion", I.protocolVersion) &&
+ O.map("capabilities", I.capabilities) &&
+ O.map("clientInfo", I.clientInfo);
+}
+
+json::Value toJSON(const InitializeResult &R) {
+ json::Object result{{"protocolVersion", R.protocolVersion},
+ {"capabilities", R.capabilities},
+ {"serverInfo", R.serverInfo}};
+
+ if (!R.instructions.empty())
+ result.insert({"instructions", R.instructions});
+
+ return result;
+}
+
+bool fromJSON(const json::Value &V, InitializeResult &R, json::Path P) {
+ json::ObjectMapper O(V, P);
+ return O && O.map("protocolVersion", R.protocolVersion) &&
+ O.map("capabilities", R.capabilities) &&
+ O.map("serverInfo", R.serverInfo) &&
+ O.mapOptional("instructions", R.instructions);
+}
+
+json::Value toJSON(const ToolsListResult &R) {
+ return json::Object{{"tools", R.tools}};
+}
+
+bool fromJSON(const json::Value &V, ToolsListResult &R, json::Path P) {
+ json::ObjectMapper O(V, P);
+ return O && O.map("tools", R.tools);
+}
+
+json::Value toJSON(const ToolsCallResult &R) {
+ json::Object result{{"content", R.content}};
+
+ if (R.isError)
+ result.insert({"isError", R.isError});
+ if (R.structuredContent)
+ result.insert({"structuredContent", *R.structuredContent});
+
+ return result;
+}
+
+bool fromJSON(const json::Value &V, ToolsCallResult &R, json::Path P) {
+ json::ObjectMapper O(V, P);
+ return O && O.map("content", R.content) &&
+ O.mapOptional("isError", R.isError) &&
+ mapRaw(V, "structuredContent", R.structuredContent, P);
+}
+
+json::Value toJSON(const ToolsCallParams &R) {
+ json::Object result{{"name", R.name}};
+
+ if (R.arguments)
+ result.insert({"arguments", *R.arguments});
+
+ return result;
+}
+
+bool fromJSON(const json::Value &V, ToolsCallParams &R, json::Path P) {
+ json::ObjectMapper O(V, P);
+ return O && O.map("name", R.name) && mapRaw(V, "arguments", R.arguments, P);
+}
+
+json::Value toJSON(const ResourcesReadParams &R) {
+ return json::Object{{"uri", R.URI}};
+}
+
+bool fromJSON(const json::Value &V, ResourcesReadParams &R, json::Path P) {
+ json::ObjectMapper O(V, P);
+ return O && O.map("uri", R.URI);
+}
+
+json::Value toJSON(const ResourcesListResult &R) {
+ return json::Object{{"resources", R.resources}};
+}
+
+bool fromJSON(const json::Value &V, ResourcesListResult &R, json::Path P) {
+ json::ObjectMapper O(V, P);
+ return O && O.map("resources", R.resources);
+}
+
+json::Value toJSON(const Void &R) { return json::Object{}; }
+
+bool fromJSON(const json::Value &V, Void &R, json::Path P) { return true; }
+
} // namespace lldb_protocol::mcp
diff --git a/lldb/source/Protocol/MCP/Server.cpp b/lldb/source/Protocol/MCP/Server.cpp
index 3713e8e46c5d6..a612967d5fa51 100644
--- a/lldb/source/Protocol/MCP/Server.cpp
+++ b/lldb/source/Protocol/MCP/Server.cpp
@@ -7,8 +7,16 @@
//===----------------------------------------------------------------------===//
#include "lldb/Protocol/MCP/Server.h"
+#include "lldb/Host/Socket.h"
+#include "lldb/Protocol/MCP/Binder.h"
#include "lldb/Protocol/MCP/MCPError.h"
+#include "lldb/Protocol/MCP/Protocol.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/Threading.h"
+#include <future>
+#include <memory>
+using namespace lldb_private;
using namespace lldb_protocol::mcp;
using namespace llvm;
@@ -16,83 +24,13 @@ Server::Server(std::string name, std::string version,
std::unique_ptr<MCPTransport> transport_up,
lldb_private::MainLoop &loop)
: m_name(std::move(name)), m_version(std::move(version)),
- m_transport_up(std::move(transport_up)), m_loop(loop) {
- AddRequestHandlers();
-}
-
-void Server::AddRequestHandlers() {
- AddRequestHandler("initialize", std::bind(&Server::InitializeHandler, this,
- std::placeholders::_1));
- AddRequestHandler("tools/list", std::bind(&Server::ToolsListHandler, this,
- std::placeholders::_1));
- AddRequestHandler("tools/call", std::bind(&Server::ToolsCallHandler, this,
- std::placeholders::_1));
- AddRequestHandler("resources/list", std::bind(&Server::ResourcesListHandler,
- this, std::placeholders::_1));
- AddRequestHandler("resources/read", std::bind(&Server::ResourcesReadHandler,
- this, std::placeholders::_1));
-}
-
-llvm::Expected<Response> Server::Handle(Request request) {
- auto it = m_request_handlers.find(request.method);
- if (it != m_request_handlers.end()) {
- llvm::Expected<Response> response = it->second(request);
- if (!response)
- return response;
- response->id = request.id;
- return *response;
- }
-
- return llvm::make_error<MCPError>(
- llvm::formatv("no handler for request: {0}", request.method).str());
-}
-
-void Server::Handle(Notification notification) {
- auto it = m_notification_handlers.find(notification.method);
- if (it != m_notification_handlers.end()) {
- it->second(notification);
- return;
- }
-}
-
-llvm::Expected<std::optional<Message>>
-Server::HandleData(llvm::StringRef data) {
- auto message = llvm::json::parse<Message>(/*JSON=*/data);
- if (!message)
- return message.takeError();
-
- if (const Request *request = std::get_if<Request>(&(*message))) {
- llvm::Expected<Response> response = Handle(*request);
-
- // Handle failures by converting them into an Error message.
- if (!response) {
- Error protocol_error;
- llvm::handleAllErrors(
- response.takeError(),
- [&](const MCPError &err) { protocol_error = err.toProtocolError(); },
- [&](const llvm::ErrorInfoBase &err) {
- protocol_error.code = MCPError::kInternalError;
- protocol_error.message = err.message();
- });
- Response error_response;
- error_response.id = request->id;
- error_response.result = std::move(protocol_error);
- return error_response;
- }
-
- return *response;
- }
-
- if (const Notification *notification =
- std::get_if<Notification>(&(*message))) {
- Handle(*notification);
- return std::nullopt;
- }
-
- if (std::get_if<Response>(&(*message)))
- return llvm::createStringError("unexpected MCP message: response");
-
- llvm_unreachable("all message types handled");
+ m_transport_up(std::move(transport_up)), m_loop(loop),
+ m_binder(m_transport_up.get()) {
+ m_binder.request("initialize", &Server::InitializeHandler, this);
+ m_binder.request("tools/list", &Server::ToolsListHandler, this);
+ m_binder.request("tools/call", &Server::ToolsCallHandler, this);
+ m_binder.request("resources/list", &Server::ResourcesListHandler, this);
+ m_binder.request("resources/read", &Server::ResourcesReadHandler, this);
}
void Server::AddTool(std::unique_ptr<Tool> tool) {
@@ -112,54 +50,30 @@ void Server::AddResourceProvider(
m_resource_providers.push_back(std::move(resource_provider));
}
-void Server::AddRequestHandler(llvm::StringRef method, RequestHandler handler) {
- std::lock_guard<std::mutex> guard(m_mutex);
- m_request_handlers[method] = std::move(handler);
-}
-
-void Server::AddNotificationHandler(llvm::StringRef method,
- NotificationHandler handler) {
- std::lock_guard<std::mutex> guard(m_mutex);
- m_notification_handlers[method] = std::move(handler);
-}
-
-llvm::Expected<Response> Server::InitializeHandler(const Request &request) {
- Response response;
- response.result = llvm::json::Object{
- {"protocolVersion", mcp::kProtocolVersion},
- {"capabilities", GetCapabilities()},
- {"serverInfo",
- llvm::json::Object{{"name", m_name}, {"version", m_version}}}};
- return response;
+Expected<InitializeResult>
+Server::InitializeHandler(const InitializeParams &request) {
+ InitializeResult result;
+ result.protocolVersion = mcp::kProtocolVersion;
+ result.capabilities = GetCapabilities();
+ result.serverInfo = Implementation{m_name, "", m_version};
+ return result;
}
-llvm::Expected<Response> Server::ToolsListHandler(const Request &request) {
- Response response;
+llvm::Expected<ToolsListResult> Server::ToolsListHandler(const Void &) {
+ ToolsListResult result;
- llvm::json::Array tools;
for (const auto &tool : m_tools)
- tools.emplace_back(toJSON(tool.second->GetDefinition()));
+ result.tools.emplace_back(tool.second->GetDefinition());
- response.result = llvm::json::Object{{"tools", std::move(tools)}};
-
- return response;
+ return result;
}
-llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) {
- Response response;
-
- if (!request.params)
- return llvm::createStringError("no tool parameters");
-
- const json::Object *param_obj = request.params->getAsObject();
- if (!param_obj)
- return llvm::createStringError("no tool parameters");
-
- const json::Value *name = param_obj->get("name");
- if (!name)
+llvm::Expected<ToolsCallResult>
+Server::ToolsCallHandler(const ToolsCallParams &params) {
+ if (params.name.empty())
return llvm::createStringError("no tool name");
- llvm::StringRef tool_name = name->getAsString().value_or("");
+ llvm::StringRef tool_name = params.name;
if (tool_name.empty())
return llvm::createStringError("no tool name");
@@ -168,56 +82,41 @@ llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) {
return llvm::createStringError(llvm::formatv("no tool \"{0}\"", tool_name));
ToolArguments tool_args;
- if (const json::Value *args = param_obj->get("arguments"))
- tool_args = *args;
+ if (params.arguments)
+ tool_args = *params.arguments;
- llvm::Expected<TextResult> text_result = it->second->Call(tool_args);
- if (!text_result)
- return text_result.takeError();
-
- response.result = toJSON(*text_result);
-
- return response;
+ std::promise<llvm::Expected<ToolsCallResult>> result_promise;
+ it->second->Call(tool_args,
+ [&result_promise](llvm::Expected<ToolsCallResult> result) {
+ result_promise.set_value(std::move(result));
+ });
+ return result_promise.get_future().get();
}
-llvm::Expected<Response> Server::ResourcesListHandler(const Request &request) {
- Response response;
-
- llvm::json::Array resources;
+llvm::Expected<ResourcesListResult> Server::ResourcesListHandler(const Void &) {
+ ResourcesListResult result;
std::lock_guard<std::mutex> guard(m_mutex);
for (std::unique_ptr<ResourceProvider> &resource_provider_up :
- m_resource_providers) {
+ m_resource_providers)
for (const Resource &resource : resource_provider_up->GetResources())
- resources.push_back(resource);
- }
- response.result = llvm::json::Object{{"resources", std::move(resources)}};
+ result.resources.push_back(resource);
- return response;
+ return result;
}
-llvm::Expected<Response> Server::ResourcesReadHandler(const Request &request) {
- Response response;
-
- if (!request.params)
- return llvm::createStringError("no resource parameters");
-
- const json::Object *param_obj = request.params->getAsObject();
- if (!param_obj)
- return llvm::createStringError("no resource parameters");
-
- const json::Value *uri = param_obj->get("uri");
- if (!uri)
- return llvm::createStringError("no resource uri");
+llvm::Expected<ResourcesReadResult>
+Server::ResourcesReadHandler(const ResourcesReadParams &params) {
+ ResourcesReadResult result;
- llvm::StringRef uri_str = uri->getAsString().value_or("");
+ llvm::StringRef uri_str = params.URI;
if (uri_str.empty())
return llvm::createStringError("no resource uri");
std::lock_guard<std::mutex> guard(m_mutex);
for (std::unique_ptr<ResourceProvider> &resource_provider_up :
m_resource_providers) {
- llvm::Expected<ResourceResult> result =
+ llvm::Expected<ResourcesReadResult> result =
resource_provider_up->ReadResource(uri_str);
if (result.errorIsA<UnsupportedURI>()) {
llvm::consumeError(result.takeError());
@@ -225,10 +124,7 @@ llvm::Expected<Response> Server::ResourcesReadHandler(const Request &request) {
}
if (!result)
return result.takeError();
-
- Response response;
- response.result = std::move(*result);
- return response;
+ return *result;
}
return make_error<MCPError>(
@@ -236,17 +132,18 @@ llvm::Expected<Response> Server::ResourcesReadHandler(const Request &request) {
MCPError::kResourceNotFound);
}
-Capabilities Server::GetCapabilities() {
- lldb_protocol::mcp::Capabilities capabilities;
- capabilities.tools.listChanged = true;
+ServerCapabilities Server::GetCapabilities() {
+ ServerCapabilities capabilities;
+ capabilities.supportsToolsList = true;
+ capabilities.supportsResourcesList = true;
// FIXME: Support sending notifications when a debugger/target are
// added/removed.
- capabilities.resources.listChanged = false;
+ // capabilities.supportsResourcesSubscribe = true;
return capabilities;
}
llvm::Error Server::Run() {
- auto handle = m_transport_up->RegisterMessageHandler(m_loop, *this);
+ auto handle = m_transport_up->RegisterMessageHandler(m_loop, m_binder);
if (!handle)
return handle.takeError();
@@ -257,48 +154,6 @@ llvm::Error Server::Run() {
return llvm::Error::success();
}
-void Server::Received(const Request &request) {
- auto SendResponse = [this](const Response &response) {
- if (llvm::Error error = m_transport_up->Send(response))
- m_transport_up->Log(llvm::toString(std::move(error)));
- };
-
- llvm::Expected<Response> response = Handle(request);
- if (response)
- return SendResponse(*response);
-
- lldb_protocol::mcp::Error protocol_error;
- llvm::handleAllErrors(
- response.takeError(),
- [&](const MCPError &err) { protocol_error = err.toProtocolError(); },
- [&](const llvm::ErrorInfoBase &err) {
- protocol_error.code = MCPError::kInternalError;
- protocol_error.message = err.message();
- });
- Response error_response;
- error_response.id = request.id;
- error_response.result = std::move(protocol_error);
- SendResponse(error_response);
-}
-
-void Server::Received(const Response &response) {
- m_transport_up->Log("unexpected MCP message: response");
-}
-
-void Server::Received(const Notification &notification) {
- Handle(notification);
-}
-
-void Server::OnError(llvm::Error error) {
- m_transport_up->Log(llvm::toString(std::move(error)));
- TerminateLoop();
-}
-
-void Server::OnClosed() {
- m_transport_up->Log("EOF");
- TerminateLoop();
-}
-
void Server::TerminateLoop() {
m_loop.AddPendingCallback(
[](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); });
diff --git a/lldb/source/Protocol/MCP/Transport.cpp b/lldb/source/Protocol/MCP/Transport.cpp
new file mode 100644
index 0000000000000..28cf754aef3e8
--- /dev/null
+++ b/lldb/source/Protocol/MCP/Transport.cpp
@@ -0,0 +1,140 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "lldb/Protocol/MCP/Transport.h"
+#include "lldb/Host/FileSystem.h"
+#include "lldb/Host/HostInfo.h"
+#include "lldb/Host/Socket.h"
+#include "lldb/Utility/FileSpec.h"
+#include "lldb/lldb-forward.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/Path.h"
+#include "llvm/Support/Program.h"
+#include "llvm/Support/Threading.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cerrno>
+#include <chrono>
+#include <memory>
+#include <string>
+#include <thread>
+#include <utility>
+#include <vector>
+
+using namespace llvm;
+using namespace lldb;
+using namespace lldb_private;
+
+namespace lldb_protocol::mcp {
+
+/// Upper bound on how long to wait for the lldb-mcp server to create its
+/// socket and accept a connection.
+static constexpr std::chrono::seconds kConnectTimeout(30);
+
+/// Launches a detached `lldb-mcp --server` process and waits until it has
+/// created the communication socket.
+static Expected<sys::ProcessInfo> StartServer() {
+  static llvm::once_flag f;
+  static FileSpec candidate;
+  llvm::call_once(f, [] {
+    HostInfo::Initialize();
+    candidate = HostInfo::GetSupportExeDir();
+    candidate.AppendPathComponent("lldb-mcp");
+  });
+
+  if (!FileSystem::Instance().Exists(candidate))
+    return createStringError("lldb-mcp executable not found");
+
+  // `args` holds StringRefs, so the path must outlive it; binding them to
+  // the temporary returned by GetPath() would leave them dangling.
+  const std::string program = candidate.GetPath();
+  std::vector<StringRef> args = {program, "--server"};
+  sys::ProcessInfo proc =
+      sys::ExecuteNoWait(program, args, std::nullopt, {}, 0, nullptr, nullptr,
+                         nullptr, /*DetachProcess=*/true);
+  if (proc.Pid == sys::ProcessInfo::InvalidPid)
+    return createStringError("Failed to start server: " + program);
+
+  // Bound the wait so a server that dies during startup does not hang us.
+  StringRef socket_path = CommunicationSocketPath();
+  const auto deadline = std::chrono::steady_clock::now() + kConnectTimeout;
+  while (!sys::fs::exists(socket_path)) {
+    if (std::chrono::steady_clock::now() >= deadline)
+      return createStringError("timed out waiting for lldb-mcp socket: " +
+                               socket_path.str());
+    std::this_thread::sleep_for(std::chrono::milliseconds(10));
+  }
+  return proc;
+}
+
+Transport::Transport(lldb::IOObjectSP input, lldb::IOObjectSP output,
+                     std::string client_name, LogCallback log_callback)
+    : JSONRPCTransport(std::move(input), std::move(output)),
+      m_client_name(std::move(client_name)),
+      m_log_callback(std::move(log_callback)) {}
+
+void Transport::Log(llvm::StringRef message) {
+  if (m_log_callback)
+    m_log_callback(llvm::formatv("{0}: {1}", m_client_name, message).str());
+}
+
+llvm::StringRef CommunicationSocketPath() {
+  // llvm::call_once requires llvm::once_flag, which is not std::once_flag on
+  // every platform.
+  static llvm::once_flag f;
+  static SmallString<256> socket_path;
+  llvm::call_once(f, [] {
+    // Keep the call outside assert() so it still runs when assertions are
+    // compiled out.
+    bool has_home = sys::path::home_directory(socket_path);
+    assert(has_home && "failed to get home directory");
+    (void)has_home;
+    sys::path::append(socket_path, ".lldb-mcp-sock");
+  });
+  return socket_path.str();
+}
+
+Expected<IOObjectSP> Connect() {
+  StringRef socket_path = CommunicationSocketPath();
+  if (!sys::fs::exists(socket_path))
+    if (llvm::Error err = StartServer().takeError())
+      return err;
+
+  Socket::SocketProtocol protocol = Socket::ProtocolUnixDomain;
+  Status error;
+  std::unique_ptr<Socket> socket = Socket::Create(protocol, error);
+  if (error.Fail())
+    return error.takeError();
+  const auto deadline = std::chrono::steady_clock::now() + kConnectTimeout;
+  while (std::chrono::steady_clock::now() < deadline) {
+    Status connect_error = socket->Connect(socket_path);
+    if (connect_error.Success())
+      return socket;
+    // The server may still be starting; retry only on transient failures.
+    if (connect_error.GetError() != ECONNREFUSED &&
+        connect_error.GetError() != ENOENT)
+      return connect_error.takeError();
+    std::this_thread::sleep_for(std::chrono::milliseconds(100));
+  }
+
+  return createStringError("failed to connect to lldb-mcp multiplexer");
+}
+
+Expected<MCPTransportUP> Transport::Connect(llvm::raw_ostream *logger) {
+  Expected<IOObjectSP> maybe_sock = lldb_protocol::mcp::Connect();
+  if (!maybe_sock)
+    return maybe_sock.takeError();
+
+  return std::make_unique<Transport>(*maybe_sock, *maybe_sock, "client",
+                                     [logger](StringRef msg) {
+                                       if (logger)
+                                         *logger << msg << "\n";
+                                     });
+}
+
+} // namespace lldb_protocol::mcp
diff --git a/lldb/unittests/Protocol/ProtocolMCPTest.cpp b/lldb/unittests/Protocol/ProtocolMCPTest.cpp
index ea19922522ffe..45024b5ca9f3d 100644
--- a/lldb/unittests/Protocol/ProtocolMCPTest.cpp
+++ b/lldb/unittests/Protocol/ProtocolMCPTest.cpp
@@ -277,10 +277,11 @@ TEST(ProtocolMCPTest, ResourceResult) {
contents2.text = "Second resource content";
contents2.mimeType = "application/json";
- ResourceResult result;
+ ResourcesReadResult result;
result.contents = {contents1, contents2};
- llvm::Expected<ResourceResult> deserialized_result = roundtripJSON(result);
+ llvm::Expected<ResourcesReadResult> deserialized_result =
+ roundtripJSON(result);
ASSERT_THAT_EXPECTED(deserialized_result, llvm::Succeeded());
ASSERT_EQ(result.contents.size(), deserialized_result->contents.size());
@@ -297,9 +298,10 @@ TEST(ProtocolMCPTest, ResourceResult) {
}
TEST(ProtocolMCPTest, ResourceResultEmpty) {
- ResourceResult result;
+ ResourcesReadResult result;
- llvm::Expected<ResourceResult> deserialized_result = roundtripJSON(result);
+ llvm::Expected<ResourcesReadResult> deserialized_result =
+ roundtripJSON(result);
ASSERT_THAT_EXPECTED(deserialized_result, llvm::Succeeded());
EXPECT_TRUE(deserialized_result->contents.empty());
diff --git a/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp b/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp
index 83a42bfb6970c..91c47c2229320 100644
--- a/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp
+++ b/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp
@@ -19,6 +19,7 @@
#include "lldb/Host/MainLoopBase.h"
#include "lldb/Host/Socket.h"
#include "lldb/Host/common/TCPSocket.h"
+#include "lldb/Protocol/MCP/Binder.h"
#include "lldb/Protocol/MCP/MCPError.h"
#include "lldb/Protocol/MCP/Protocol.h"
#include "llvm/Support/Error.h"
@@ -36,18 +37,34 @@ using namespace lldb_private;
using namespace lldb_protocol::mcp;
using testing::_;
+namespace lldb_protocol::mcp {
+// GoogleTest printers, found via ADL, so that assertion failures involving
+// protocol messages print each message as its JSON encoding.
+void PrintTo(const Request &req, std::ostream *os) {
+  *os << formatv("{0}", toJSON(req)).str();
+}
+void PrintTo(const Response &resp, std::ostream *os) {
+  *os << formatv("{0}", toJSON(resp)).str();
+}
+void PrintTo(const Notification &note, std::ostream *os) {
+  *os << formatv("{0}", toJSON(note)).str();
+}
+// Dispatches to the overload for whichever alternative the message holds.
+void PrintTo(const Message &message, std::ostream *os) {
+  std::visit([os](const auto &alternative) { PrintTo(alternative, os); },
+             message);
+}
+} // namespace lldb_protocol::mcp
+
namespace {
class TestProtocolServerMCP : public lldb_private::mcp::ProtocolServerMCP {
public:
using ProtocolServerMCP::GetSocket;
using ProtocolServerMCP::ProtocolServerMCP;
- using ExtendCallback =
- std::function<void(lldb_protocol::mcp::Server &server)>;
+ using ExtendCallback = std::function<void(
+ lldb_protocol::mcp::Server &server, lldb_protocol::mcp::Binder &binder)>;
- virtual void Extend(lldb_protocol::mcp::Server &server) const override {
+ void Extend(lldb_protocol::mcp::Server &server) override {
if (m_extend_callback)
- m_extend_callback(server);
+ m_extend_callback(server, server.GetBinder());
};
void Extend(ExtendCallback callback) { m_extend_callback = callback; }
@@ -55,7 +72,7 @@ class TestProtocolServerMCP : public lldb_private::mcp::ProtocolServerMCP {
ExtendCallback m_extend_callback;
};
-using Message = typename Transport<Request, Response, Notification>::Message;
+using Message = typename lldb_protocol::mcp::Transport::Message;
class TestJSONTransport final
: public lldb_private::JSONRPCTransport<Request, Response, Notification> {
@@ -74,7 +91,8 @@ class TestTool : public Tool {
public:
using Tool::Tool;
- llvm::Expected<TextResult> Call(const ToolArguments &args) override {
+ void Call(const ToolArguments &args,
+ Callback<void(llvm::Expected<ToolsCallResult>)> reply) override {
std::string argument;
if (const json::Object *args_obj =
std::get<json::Value>(args).getAsObject()) {
@@ -83,9 +101,9 @@ class TestTool : public Tool {
}
}
- TextResult text_result;
+ ToolsCallResult text_result;
text_result.content.emplace_back(TextContent{{argument}});
- return text_result;
+ reply(text_result);
}
};
@@ -105,7 +123,7 @@ class TestResourceProvider : public ResourceProvider {
return resources;
}
- llvm::Expected<ResourceResult>
+ llvm::Expected<ResourcesReadResult>
ReadResource(llvm::StringRef uri) const override {
if (uri != "lldb://foo/bar")
return llvm::make_error<UnsupportedURI>(uri.str());
@@ -115,7 +133,7 @@ class TestResourceProvider : public ResourceProvider {
contents.mimeType = "application/json";
contents.text = "foobar";
- ResourceResult result;
+ ResourcesReadResult result;
result.contents.push_back(contents);
return result;
}
@@ -126,8 +144,9 @@ class ErrorTool : public Tool {
public:
using Tool::Tool;
- llvm::Expected<TextResult> Call(const ToolArguments &args) override {
- return llvm::createStringError("error");
+ // Always fails: the reply receives an llvm::Error with the message "error"
+ // instead of a ToolsCallResult. The arguments are ignored.
+ void Call(const ToolArguments &args,
+ Callback<void(llvm::Expected<ToolsCallResult>)> reply) override {
+ reply(llvm::createStringError("error"));
}
};
@@ -136,11 +155,12 @@ class FailTool : public Tool {
public:
using Tool::Tool;
- llvm::Expected<TextResult> Call(const ToolArguments &args) override {
- TextResult text_result;
+ // Replies with a ToolsCallResult (not an llvm::Error) whose content is the
+ // text "failed" and whose isError flag is set. The arguments are ignored.
+ void Call(const ToolArguments &args,
+ Callback<void(llvm::Expected<ToolsCallResult>)> reply) override {
+ ToolsCallResult text_result;
+ text_result.content.emplace_back(TextContent{{"failed"}});
+ text_result.isError = true;
+ reply(text_result);
}
};
@@ -191,7 +211,7 @@ class ProtocolServerMCPTest : public ::testing::Test {
connection.protocol = Socket::SocketProtocol::ProtocolTcp;
connection.name = llvm::formatv("{0}:0", k_localhost).str();
m_server_up = std::make_unique<TestProtocolServerMCP>();
- m_server_up->Extend([&](auto &server) {
+ m_server_up->Extend([&](auto &server, Binder &binder) {
server.AddTool(std::make_unique<TestTool>("test", "test tool"));
server.AddResourceProvider(std::make_unique<TestResourceProvider>());
});
@@ -225,7 +245,7 @@ TEST_F(ProtocolServerMCPTest, Initialization) {
llvm::StringLiteral request =
R"json({"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"lldb-unit","version":"0.1.0"}},"jsonrpc":"2.0","id":1})json";
llvm::StringLiteral response =
- R"json({"id":1,"jsonrpc":"2.0","result":{"capabilities":{"resources":{"listChanged":false,"subscribe":false},"tools":{"listChanged":true}},"protocolVersion":"2024-11-05","serverInfo":{"name":"lldb-mcp","version":"0.1.0"}}})json";
+ R"json({"id":1,"jsonrpc":"2.0","result":{"capabilities":{"resources":{"listChanged":true},"tools":{"listChanged":true}},"protocolVersion":"2024-11-05","serverInfo":{"name":"lldb-mcp","version":"0.1.0"}}})json";
ASSERT_THAT_ERROR(Write(request), Succeeded());
llvm::Expected<Response> expected_resp = json::parse<Response>(response);
@@ -271,7 +291,7 @@ TEST_F(ProtocolServerMCPTest, ToolsCall) {
llvm::StringLiteral request =
R"json({"method":"tools/call","params":{"name":"test","arguments":{"arguments":"foo","debugger_id":0}},"jsonrpc":"2.0","id":11})json";
llvm::StringLiteral response =
- R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"foo","type":"text"}],"isError":false}})json";
+ R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"foo","type":"text"}]}})json";
ASSERT_THAT_ERROR(Write(request), llvm::Succeeded());
llvm::Expected<Response> expected_resp = json::parse<Response>(response);
@@ -281,7 +301,7 @@ TEST_F(ProtocolServerMCPTest, ToolsCall) {
}
TEST_F(ProtocolServerMCPTest, ToolsCallError) {
- m_server_up->Extend([&](auto &server) {
+ m_server_up->Extend([&](auto &server, auto &binder) {
server.AddTool(std::make_unique<ErrorTool>("error", "error tool"));
});
@@ -298,7 +318,7 @@ TEST_F(ProtocolServerMCPTest, ToolsCallError) {
}
TEST_F(ProtocolServerMCPTest, ToolsCallFail) {
- m_server_up->Extend([&](auto &server) {
+ m_server_up->Extend([&](auto &server, auto &binder) {
server.AddTool(std::make_unique<FailTool>("fail", "fail tool"));
});
@@ -319,15 +339,15 @@ TEST_F(ProtocolServerMCPTest, NotificationInitialized) {
std::condition_variable cv;
std::mutex mutex;
- m_server_up->Extend([&](auto &server) {
- server.AddNotificationHandler("notifications/initialized",
- [&](const Notification ¬ification) {
- {
- std::lock_guard<std::mutex> lock(mutex);
- handler_called = true;
- }
- cv.notify_all();
- });
+ m_server_up->Extend([&](auto &server, auto &binder) {
+ binder.template notification<Void>(
+ "notifications/initialized", [&](const Void &) {
+ {
+ std::lock_guard<std::mutex> lock(mutex);
+ handler_called = true;
+ }
+ cv.notify_all();
+ });
});
llvm::StringLiteral request =
R"json({"method":"notifications/initialized","jsonrpc":"2.0"})json";
More information about the lldb-commits
mailing list