[Mlir-commits] [mlir] [mlir-lsp] Support outgoing requests (PR #90078)

Brian Gesiak llvmlistbot at llvm.org
Thu Apr 25 09:15:16 PDT 2024


https://github.com/modocache created https://github.com/llvm/llvm-project/pull/90078

> This patch is based on top of #90076. For easier code review,
> please take a look only at the top commit in this pull request.

Add support for outgoing requests to `lsp::MessageHandler`. Much like
`MessageHandler::outgoingNotification`, this allows for the message
handler to send outgoing messages via its JSON transport, but in this
case, those messages are requests, not notifications.

Requests receive responses (also referred to as "replies" in
`MLIRLspServerSupportLib`). These were previously unsupported, and
`lsp::MessageHandler` would log an error each time it processed a JSON
message that appeared to be a response (something with an "id" field,
but no "method" field). However, the `outgoingRequest` method now
handles response callbacks: an outgoing request with a given ID is set
up such that a callback function is invoked when a response with that ID
is received.

>From 3061bae59059d81784ae31f1704c170afb3c83ad Mon Sep 17 00:00:00 2001
From: Brian Gesiak <brian at modocache.io>
Date: Thu, 25 Apr 2024 09:34:35 -0400
Subject: [PATCH 1/2] [mlir-lsp] Rename `OutgoingNotification`

Rename `OutgoingNotification` to `OutgoingMessage`, since the same
callback function type will be used in a future commit to represent
outgoing requests, in addition to outgoing notifications.

No functional change here, but an additional test is added for outgoing
notifications.
---
 .../include/mlir/Tools/lsp-server-support/Transport.h | 11 ++++++-----
 mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp          |  2 +-
 mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp     |  2 +-
 mlir/lib/Tools/tblgen-lsp-server/LSPServer.cpp        |  2 +-
 mlir/unittests/Tools/lsp-server-support/Transport.cpp |  7 +++++++
 5 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Tools/lsp-server-support/Transport.h b/mlir/include/mlir/Tools/lsp-server-support/Transport.h
index 44c71058cf717c..c56e7219fff940 100644
--- a/mlir/include/mlir/Tools/lsp-server-support/Transport.h
+++ b/mlir/include/mlir/Tools/lsp-server-support/Transport.h
@@ -95,10 +95,10 @@ class JSONTransport {
 template <typename T>
 using Callback = llvm::unique_function<void(llvm::Expected<T>)>;
 
-/// An OutgoingNotification<T> is a function used for outgoing notifications
-/// send to the client.
+/// An OutgoingMessage<T> is a function used for outgoing requests or
+/// notifications to send to the client.
 template <typename T>
-using OutgoingNotification = llvm::unique_function<void(const T &)>;
+using OutgoingMessage = llvm::unique_function<void(const T &)>;
 
 /// A handler used to process the incoming transport messages.
 class MessageHandler {
@@ -160,9 +160,10 @@ class MessageHandler {
     };
   }
 
-  /// Create an OutgoingNotification object used for the given method.
+  /// Create an OutgoingMessage function that, when called, sends a notification
+  /// with the given method via the transport.
   template <typename T>
-  OutgoingNotification<T> outgoingNotification(llvm::StringLiteral method) {
+  OutgoingMessage<T> outgoingNotification(llvm::StringLiteral method) {
     return [&, method](const T &params) {
       std::lock_guard<std::mutex> transportLock(transportOutputMutex);
       Logger::info("--> {0}", method);
diff --git a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
index 0f23366f6fe80a..bd7f2a5dedc257 100644
--- a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
@@ -91,7 +91,7 @@ struct LSPServer {
 
   /// An outgoing notification used to send diagnostics to the client when they
   /// are ready to be processed.
-  OutgoingNotification<PublishDiagnosticsParams> publishDiagnostics;
+  OutgoingMessage<PublishDiagnosticsParams> publishDiagnostics;
 
   /// Used to indicate that the 'shutdown' request was received from the
   /// Language Server client.
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp
index f02372367e38c8..ffaa1c8d4de46f 100644
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp
@@ -104,7 +104,7 @@ struct LSPServer {
 
   /// An outgoing notification used to send diagnostics to the client when they
   /// are ready to be processed.
-  OutgoingNotification<PublishDiagnosticsParams> publishDiagnostics;
+  OutgoingMessage<PublishDiagnosticsParams> publishDiagnostics;
 
   /// Used to indicate that the 'shutdown' request was received from the
   /// Language Server client.
diff --git a/mlir/lib/Tools/tblgen-lsp-server/LSPServer.cpp b/mlir/lib/Tools/tblgen-lsp-server/LSPServer.cpp
index b62f68db9d60fa..bc312d18ea4037 100644
--- a/mlir/lib/Tools/tblgen-lsp-server/LSPServer.cpp
+++ b/mlir/lib/Tools/tblgen-lsp-server/LSPServer.cpp
@@ -72,7 +72,7 @@ struct LSPServer {
 
   /// An outgoing notification used to send diagnostics to the client when they
   /// are ready to be processed.
-  OutgoingNotification<PublishDiagnosticsParams> publishDiagnostics;
+  OutgoingMessage<PublishDiagnosticsParams> publishDiagnostics;
 
   /// Used to indicate that the 'shutdown' request was received from the
   /// Language Server client.
diff --git a/mlir/unittests/Tools/lsp-server-support/Transport.cpp b/mlir/unittests/Tools/lsp-server-support/Transport.cpp
index 48eae32a0fc3a4..b46f02bc4b197b 100644
--- a/mlir/unittests/Tools/lsp-server-support/Transport.cpp
+++ b/mlir/unittests/Tools/lsp-server-support/Transport.cpp
@@ -118,4 +118,11 @@ TEST_F(TransportInputTest, MethodNotFound) {
   EXPECT_THAT(getOutput(), HasSubstr("\"error\""));
   EXPECT_THAT(getOutput(), HasSubstr("\"message\":\"method not found: ack\""));
 }
+
+TEST_F(TransportInputTest, OutgoingNotification) {
+  auto notifyFn = getMessageHandler().outgoingNotification<CompletionList>(
+      "outgoing-notification");
+  notifyFn(CompletionList{});
+  EXPECT_THAT(getOutput(), HasSubstr("\"method\":\"outgoing-notification\""));
+}
 } // namespace

>From 782ba116c7dde92de015c2ab16081f0b9231a196 Mon Sep 17 00:00:00 2001
From: Brian Gesiak <brian at modocache.io>
Date: Thu, 25 Apr 2024 09:34:55 -0400
Subject: [PATCH 2/2] [mlir-lsp] Support outgoing requests

Add support for outgoing requests to `lsp::MessageHandler`. Much like
`MessageHandler::outgoingNotification`, this allows for the message
handler to send outgoing messages via its JSON transport, but in this
case, those messages are requests, not notifications.

Requests receive responses (also referred to as "replies" in
`MLIRLspServerSupportLib`). These were previously unsupported, and
`lsp::MessageHandler` would log an error each time it processed a JSON
message that appeared to be a response (something with an "id" field,
but no "method" field). However, the `outgoingRequest` method now
handles response callbacks: an outgoing request with a given ID is set
up such that a callback function is invoked when a response with that ID
is received.
---
 .../mlir/Tools/lsp-server-support/Transport.h | 43 +++++++++++++++++
 .../Tools/lsp-server-support/Transport.cpp    | 47 +++++++++++++------
 .../Tools/lsp-server-support/Transport.cpp    | 31 ++++++++++++
 3 files changed, 106 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Tools/lsp-server-support/Transport.h b/mlir/include/mlir/Tools/lsp-server-support/Transport.h
index c56e7219fff940..f353d9aa6cf688 100644
--- a/mlir/include/mlir/Tools/lsp-server-support/Transport.h
+++ b/mlir/include/mlir/Tools/lsp-server-support/Transport.h
@@ -171,7 +171,37 @@ class MessageHandler {
     };
   }
 
+  /// Create an OutgoingMessage function that, when called, sends a request with
+  /// the given method and ID via the transport. Should the outgoing request be
+  /// met with a response, the response callback is invoked to handle that
+  /// response.
+  ///
+  /// The callback is registered immediately and is erased once a response with
+  /// a matching ID is received, so it is invoked at most once. The returned
+  /// function refers to this MessageHandler and must not outlive it.
+  template <typename T>
+  OutgoingMessage<T> outgoingRequest(
+      llvm::StringLiteral method, llvm::json::Value id,
+      llvm::unique_function<void(llvm::Expected<llvm::json::Value>)> callback) {
+    {
+      std::lock_guard<std::mutex> responseHandlersLock(responseHandlersMutex);
+      responseHandlers.insert(
+          {getIDAsString(id),
+           std::make_pair(method.str(), std::move(callback))});
+    }
+
+    return [&, method, id](const T &params) {
+      std::lock_guard<std::mutex> transportLock(transportOutputMutex);
+      Logger::info("--> {0}", method);
+      transport.call(method, llvm::json::Value(params), id);
+    };
+  }
+
 private:
+  /// Returns a string representation of a message ID, which is specified as
+  /// `integer | string | null`.
+  static std::string getIDAsString(const llvm::json::Value &id);
+
   template <typename HandlerT>
   using HandlerMap = llvm::StringMap<llvm::unique_function<HandlerT>>;
 
@@ -179,6 +209,19 @@ class MessageHandler {
   HandlerMap<void(llvm::json::Value, Callback<llvm::json::Value>)>
       methodHandlers;
 
+  /// A pair of (1) the original request's method name, and (2) the callback
+  /// function to be invoked for responses.
+  using ResponseHandlerTy =
+      std::pair<std::string,
+                llvm::unique_function<void(llvm::Expected<llvm::json::Value>)>>;
+  /// A mapping from request/response ID to response handler. Entries are
+  /// erased once the corresponding response has been received.
+  llvm::StringMap<ResponseHandlerTy> responseHandlers;
+
+  /// Mutex to guard `responseHandlers`, which is written by `outgoingRequest`
+  /// and read/erased by `onReply`, potentially on different threads.
+  std::mutex responseHandlersMutex;
+
   JSONTransport &transport;
 
   /// Mutex to guard sending output messages to the transport.
diff --git a/mlir/lib/Tools/lsp-server-support/Transport.cpp b/mlir/lib/Tools/lsp-server-support/Transport.cpp
index 339c5f3825165d..e250e9e9b15e95 100644
--- a/mlir/lib/Tools/lsp-server-support/Transport.cpp
+++ b/mlir/lib/Tools/lsp-server-support/Transport.cpp
@@ -117,24 +117,41 @@ bool MessageHandler::onCall(llvm::StringRef method, llvm::json::Value params,
 
 bool MessageHandler::onReply(llvm::json::Value id,
                              llvm::Expected<llvm::json::Value> result) {
-  // TODO: Add support for reply callbacks when support for outgoing messages is
-  // added. For now, we just log an error on any replies received.
-  Callback<llvm::json::Value> replyHandler =
-      [&id](llvm::Expected<llvm::json::Value> result) {
-        Logger::error(
-            "received a reply with ID {0}, but there was no such call", id);
-        if (!result)
-          llvm::consumeError(result.takeError());
-      };
-
-  // Log and run the reply handler.
-  if (result)
-    replyHandler(std::move(result));
-  else
-    replyHandler(result.takeError());
+  // Move the matching response handler, if any, out of the mapping, so that it
+  // is invoked at most once and the mapping does not grow without bound.
+  ResponseHandlerTy responseHandler;
+  {
+    std::lock_guard<std::mutex> responseHandlersLock(responseHandlersMutex);
+    auto it = responseHandlers.find(getIDAsString(id));
+    if (it != responseHandlers.end()) {
+      responseHandler = std::move(it->second);
+      responseHandlers.erase(it);
+    }
+  }
+
+  // Invoke the handler without holding the lock, so that the handler may
+  // itself create new outgoing requests without deadlocking.
+  if (responseHandler.second) {
+    Logger::info("--> reply:{0}({1})", responseHandler.first, id);
+    responseHandler.second(std::move(result));
+  } else {
+    Logger::error(
+        "received a reply with ID {0}, but there was no such outgoing request",
+        id);
+    if (!result)
+      llvm::consumeError(result.takeError());
+  }
   return true;
 }
 
+std::string MessageHandler::getIDAsString(const llvm::json::Value &id) {
+  std::string result;
+  llvm::raw_string_ostream os(result);
+  os << id;
+  os.flush();
+  return result;
+}
+
 //===----------------------------------------------------------------------===//
 // JSONTransport
 //===----------------------------------------------------------------------===//
diff --git a/mlir/unittests/Tools/lsp-server-support/Transport.cpp b/mlir/unittests/Tools/lsp-server-support/Transport.cpp
index b46f02bc4b197b..a19b277b20d171 100644
--- a/mlir/unittests/Tools/lsp-server-support/Transport.cpp
+++ b/mlir/unittests/Tools/lsp-server-support/Transport.cpp
@@ -125,4 +125,35 @@ TEST_F(TransportInputTest, OutgoingNotification) {
   notifyFn(CompletionList{});
   EXPECT_THAT(getOutput(), HasSubstr("\"method\":\"outgoing-notification\""));
 }
+
+TEST_F(TransportInputTest, ResponseHandlerNotFound) {
+  // Unhandled responses are only reported via error logging. As a result, this
+  // test can't make any expectations -- but it prints the output anyway, by way
+  // of demonstration.
+  Logger::setLogLevel(Logger::Level::Error);
+  writeInput("{\"jsonrpc\":\"2.0\",\"id\":81,\"result\":null}\n");
+  runTransport();
+}
+
+TEST_F(TransportInputTest, OutgoingRequest) {
+  Logger::setLogLevel(Logger::Level::Debug);
+
+  // Make an outgoing request.
+  bool responseCallbackInvoked = false;
+  auto callFn = getMessageHandler().outgoingRequest<CompletionList>(
+      "outgoing-request", 82,
+      [&responseCallbackInvoked](llvm::Expected<llvm::json::Value> value) {
+        ASSERT_TRUE((bool)value);
+        responseCallbackInvoked = true;
+      });
+  callFn(CompletionList{});
+  EXPECT_THAT(getOutput(), HasSubstr("\"method\":\"outgoing-request\""));
+  EXPECT_FALSE(responseCallbackInvoked);
+
+  // The request receives a response. The message handler handles this response
+  // by invoking the callback from above.
+  writeInput("{\"jsonrpc\":\"2.0\",\"id\":82,\"result\":null}\n");
+  runTransport();
+  EXPECT_TRUE(responseCallbackInvoked);
+}
 } // namespace



More information about the Mlir-commits mailing list