[Mlir-commits] [mlir] [mlir-lsp] Support outgoing requests (PR #90078)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 25 09:15:46 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Brian Gesiak (modocache)
<details>
<summary>Changes</summary>
> This patch is based on top of #90076. For easier code review,
> please take a look only at the top commit in this pull request.
Add support for outgoing requests to `lsp::MessageHandler`. Much like
`MessageHandler::outgoingNotification`, this allows for the message
handler to send outgoing messages via its JSON transport, but in this
case, those messages are requests, not notifications.
Requests receive responses (also referred to as "replies" in
`MLIRLspServerSupportLib`). These were previously unsupported, and
`lsp::MessageHandler` would log an error each time it processed a JSON
message that appeared to be a response (something with an "id" field,
but no "method" field). However, the `outgoingRequest` method now
handles response callbacks: an outgoing request with a given ID is set
up such that a callback function is invoked when a response with that ID
is received.
---
Full diff: https://github.com/llvm/llvm-project/pull/90078.diff
6 Files Affected:
- (modified) mlir/include/mlir/Tools/lsp-server-support/Transport.h (+36-5)
- (modified) mlir/lib/Tools/lsp-server-support/Transport.cpp (+19-15)
- (modified) mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp (+1-1)
- (modified) mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp (+1-1)
- (modified) mlir/lib/Tools/tblgen-lsp-server/LSPServer.cpp (+1-1)
- (modified) mlir/unittests/Tools/lsp-server-support/Transport.cpp (+38)
``````````diff
diff --git a/mlir/include/mlir/Tools/lsp-server-support/Transport.h b/mlir/include/mlir/Tools/lsp-server-support/Transport.h
index 44c71058cf717c..f353d9aa6cf688 100644
--- a/mlir/include/mlir/Tools/lsp-server-support/Transport.h
+++ b/mlir/include/mlir/Tools/lsp-server-support/Transport.h
@@ -95,10 +95,10 @@ class JSONTransport {
template <typename T>
using Callback = llvm::unique_function<void(llvm::Expected<T>)>;
-/// An OutgoingNotification<T> is a function used for outgoing notifications
-/// send to the client.
+/// An OutgoingMessage<T> is a function used to send an outgoing request or
+/// notification, whose params are a `T`, to the client.
template <typename T>
-using OutgoingNotification = llvm::unique_function<void(const T &)>;
+using OutgoingMessage = llvm::unique_function<void(const T &)>;
/// A handler used to process the incoming transport messages.
class MessageHandler {
@@ -160,9 +160,10 @@ class MessageHandler {
};
}
- /// Create an OutgoingNotification object used for the given method.
+ /// Create an OutgoingMessage function that, when called, sends a notification
+ /// with the given method via the transport.
template <typename T>
- OutgoingNotification<T> outgoingNotification(llvm::StringLiteral method) {
+ OutgoingMessage<T> outgoingNotification(llvm::StringLiteral method) {
return [&, method](const T ¶ms) {
std::lock_guard<std::mutex> transportLock(transportOutputMutex);
Logger::info("--> {0}", method);
@@ -170,7 +171,29 @@ class MessageHandler {
};
}
+ /// Create an OutgoingMessage function that, when called, sends a request with
+ /// the given method and ID via the transport. `callback` is registered for
+ /// `id` right away and handles the response to that ID; if `id` already has
+ /// a handler registered, the existing one is kept and `callback` is dropped.
+ template <typename T>
+ OutgoingMessage<T> outgoingRequest(
+ llvm::StringLiteral method, llvm::json::Value id,
+ llvm::unique_function<void(llvm::Expected<llvm::json::Value>)> callback) {
+ responseHandlers.insert(
+ {getIDAsString(id), std::make_pair(method.str(), std::move(callback))});
+
+ return [&, method, id](const T &params) {
+ std::lock_guard<std::mutex> transportLock(transportOutputMutex);
+ Logger::info("--> {0}", method);
+ transport.call(method, llvm::json::Value(params), id);
+ };
+ }
+
private:
+ /// Returns a string representation of a message ID, which is specified as
+ /// `integer | string | null`; used as the key into `responseHandlers`.
+ static std::string getIDAsString(llvm::json::Value id);
+
template <typename HandlerT>
using HandlerMap = llvm::StringMap<llvm::unique_function<HandlerT>>;
@@ -178,6 +201,14 @@ class MessageHandler {
HandlerMap<void(llvm::json::Value, Callback<llvm::json::Value>)>
methodHandlers;
+ /// A pair of (1) the original request's method name, and (2) the callback
+ /// function to be invoked for responses.
+ using ResponseHandlerTy =
+ std::pair<std::string,
+ llvm::unique_function<void(llvm::Expected<llvm::json::Value>)>>;
+ /// Maps a request ID, rendered via getIDAsString, to its response handler.
+ llvm::StringMap<ResponseHandlerTy> responseHandlers;
+
JSONTransport &transport;
/// Mutex to guard sending output messages to the transport.
diff --git a/mlir/lib/Tools/lsp-server-support/Transport.cpp b/mlir/lib/Tools/lsp-server-support/Transport.cpp
index 339c5f3825165d..e250e9e9b15e95 100644
--- a/mlir/lib/Tools/lsp-server-support/Transport.cpp
+++ b/mlir/lib/Tools/lsp-server-support/Transport.cpp
@@ -117,24 +117,33 @@ bool MessageHandler::onCall(llvm::StringRef method, llvm::json::Value params,
bool MessageHandler::onReply(llvm::json::Value id,
llvm::Expected<llvm::json::Value> result) {
- // TODO: Add support for reply callbacks when support for outgoing messages is
- // added. For now, we just log an error on any replies received.
- Callback<llvm::json::Value> replyHandler =
- [&id](llvm::Expected<llvm::json::Value> result) {
- Logger::error(
- "received a reply with ID {0}, but there was no such call", id);
- if (!result)
- llvm::consumeError(result.takeError());
- };
-
- // Log and run the reply handler.
- if (result)
- replyHandler(std::move(result));
- else
- replyHandler(result.takeError());
+ auto it = responseHandlers.find(getIDAsString(id));
+ if (it != responseHandlers.end()) {
+ // A request gets exactly one response: move the handler out and erase its
+ // entry before invoking it, so the callback is released once it has run
+ // and a later reply with the same ID is reported by the else branch.
+ ResponseHandlerTy handler = std::move(it->second);
+ responseHandlers.erase(it);
+ Logger::info("--> reply:{0}({1})", handler.first, id);
+ handler.second(std::move(result));
+ } else {
+ Logger::error(
+ "received a reply with ID {0}, but there was no such outgoing request",
+ id);
+ if (!result)
+ llvm::consumeError(result.takeError());
+ }
return true;
}
+std::string MessageHandler::getIDAsString(llvm::json::Value id) {
+ // Serialize the ID to its JSON text so it can serve as a StringMap key.
+ std::string text;
+ llvm::raw_string_ostream stream(text);
+ stream << id;
+ return stream.str();
+}
+
//===----------------------------------------------------------------------===//
// JSONTransport
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
index 0f23366f6fe80a..bd7f2a5dedc257 100644
--- a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
@@ -91,7 +91,7 @@ struct LSPServer {
/// An outgoing notification used to send diagnostics to the client when they
/// are ready to be processed.
- OutgoingNotification<PublishDiagnosticsParams> publishDiagnostics;
+ OutgoingMessage<PublishDiagnosticsParams> publishDiagnostics;
/// Used to indicate that the 'shutdown' request was received from the
/// Language Server client.
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp
index f02372367e38c8..ffaa1c8d4de46f 100644
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp
@@ -104,7 +104,7 @@ struct LSPServer {
/// An outgoing notification used to send diagnostics to the client when they
/// are ready to be processed.
- OutgoingNotification<PublishDiagnosticsParams> publishDiagnostics;
+ OutgoingMessage<PublishDiagnosticsParams> publishDiagnostics;
/// Used to indicate that the 'shutdown' request was received from the
/// Language Server client.
diff --git a/mlir/lib/Tools/tblgen-lsp-server/LSPServer.cpp b/mlir/lib/Tools/tblgen-lsp-server/LSPServer.cpp
index b62f68db9d60fa..bc312d18ea4037 100644
--- a/mlir/lib/Tools/tblgen-lsp-server/LSPServer.cpp
+++ b/mlir/lib/Tools/tblgen-lsp-server/LSPServer.cpp
@@ -72,7 +72,7 @@ struct LSPServer {
/// An outgoing notification used to send diagnostics to the client when they
/// are ready to be processed.
- OutgoingNotification<PublishDiagnosticsParams> publishDiagnostics;
+ OutgoingMessage<PublishDiagnosticsParams> publishDiagnostics;
/// Used to indicate that the 'shutdown' request was received from the
/// Language Server client.
diff --git a/mlir/unittests/Tools/lsp-server-support/Transport.cpp b/mlir/unittests/Tools/lsp-server-support/Transport.cpp
index 48eae32a0fc3a4..a19b277b20d171 100644
--- a/mlir/unittests/Tools/lsp-server-support/Transport.cpp
+++ b/mlir/unittests/Tools/lsp-server-support/Transport.cpp
@@ -118,4 +118,42 @@ TEST_F(TransportInputTest, MethodNotFound) {
EXPECT_THAT(getOutput(), HasSubstr("\"error\""));
EXPECT_THAT(getOutput(), HasSubstr("\"message\":\"method not found: ack\""));
}
+
+TEST_F(TransportInputTest, OutgoingNotification) {
+ auto notifyFn = getMessageHandler().outgoingNotification<CompletionList>(
+ "outgoing-notification");
+ notifyFn(CompletionList{});
+ EXPECT_THAT(getOutput(), HasSubstr("\"method\":\"outgoing-notification\""));
+}
+
+TEST_F(TransportInputTest, ResponseHandlerNotFound) {
+ // Unhandled responses are only reported via error logging. As a result, this
+ // test can't make any expectations -- but it prints the output anyway, by way
+ // of demonstration.
+ Logger::setLogLevel(Logger::Level::Error);
+ writeInput("{\"jsonrpc\":\"2.0\",\"id\":81,\"result\":null}\n");
+ runTransport();
+}
+
+TEST_F(TransportInputTest, OutgoingRequest) {
+ Logger::setLogLevel(Logger::Level::Debug);
+
+ // Make an outgoing request.
+ bool responseCallbackInvoked = false;
+ auto callFn = getMessageHandler().outgoingRequest<CompletionList>(
+ "outgoing-request", 82,
+ [&responseCallbackInvoked](llvm::Expected<llvm::json::Value> value) {
+ ASSERT_TRUE(static_cast<bool>(value)) << llvm::toString(value.takeError());
+ responseCallbackInvoked = true;
+ });
+ callFn(CompletionList{});
+ EXPECT_THAT(getOutput(), HasSubstr("\"method\":\"outgoing-request\""));
+ EXPECT_FALSE(responseCallbackInvoked);
+
+ // The request receives a JSON-RPC response (a "result" member). The message
+ // handler handles this response by invoking the callback from above.
+ writeInput("{\"jsonrpc\":\"2.0\",\"id\":82,\"result\":null}\n");
+ runTransport();
+ EXPECT_TRUE(responseCallbackInvoked);
+}
} // namespace
``````````
</details>
https://github.com/llvm/llvm-project/pull/90078
More information about the Mlir-commits mailing list