[Mlir-commits] [mlir] [mlir-lsp] Support outgoing requests (PR #90078)

Brian Gesiak llvmlistbot at llvm.org
Thu Apr 25 20:08:34 PDT 2024


https://github.com/modocache updated https://github.com/llvm/llvm-project/pull/90078

>From 529d2c3308dc605dbc9867850b8ca1fb8867ed2c Mon Sep 17 00:00:00 2001
From: Brian Gesiak <brian at modocache.io>
Date: Thu, 25 Apr 2024 09:34:55 -0400
Subject: [PATCH] [mlir-lsp] Support outgoing requests

Add support for outgoing requests to `lsp::MessageHandler`. Much like
`MessageHandler::outgoingNotification`, this allows for the message
handler to send outgoing messages via its JSON transport, but in this
case, those messages are requests, not notifications.

Requests receive responses (also referred to as "replies" in
`MLIRLspServerSupportLib`). These were previously unsupported, and
`lsp::MessageHandler` would log an error each time it processed a JSON
message that appeared to be a response (something with an "id" field,
but no "method" field). However, the `outgoingRequest` method now
handles response callbacks: an outgoing request with a given ID is set
up such that a callback function is invoked when a response with that ID
is received.
---
 .../mlir/Tools/lsp-server-support/Transport.h | 41 +++++++++++++++++++
 .../Tools/lsp-server-support/Transport.cpp    | 38 ++++++++++-------
 .../Tools/lsp-server-support/Transport.cpp    | 39 ++++++++++++++++++
 3 files changed, 103 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Tools/lsp-server-support/Transport.h b/mlir/include/mlir/Tools/lsp-server-support/Transport.h
index 44c71058cf717c..047d174234df8d 100644
--- a/mlir/include/mlir/Tools/lsp-server-support/Transport.h
+++ b/mlir/include/mlir/Tools/lsp-server-support/Transport.h
@@ -15,6 +15,7 @@
 #ifndef MLIR_TOOLS_LSPSERVERSUPPORT_TRANSPORT_H
 #define MLIR_TOOLS_LSPSERVERSUPPORT_TRANSPORT_H
 
+#include "mlir/Support/DebugStringHelper.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Tools/lsp-server-support/Logging.h"
@@ -100,6 +101,18 @@ using Callback = llvm::unique_function<void(llvm::Expected<T>)>;
 template <typename T>
 using OutgoingNotification = llvm::unique_function<void(const T &)>;
 
+/// An OutgoingRequest<T> is a function that sends a request with parameters
+/// of type T to the client, using the given JSON value as the request ID.
+template <typename T>
+using OutgoingRequest =
+    llvm::unique_function<void(const T &, llvm::json::Value id)>;
+
+/// An `OutgoingRequestCallback` is invoked with the original request's ID and
+/// the result JSON (or an error) once the client responds. It is a copyable
+/// `std::function` because a copy is stored for every request that is sent.
+using OutgoingRequestCallback =
+    std::function<void(llvm::json::Value, llvm::Expected<llvm::json::Value>)>;
+
 /// A handler used to process the incoming transport messages.
 class MessageHandler {
 public:
@@ -170,6 +183,26 @@ class MessageHandler {
     };
   }
 
+  /// Create an OutgoingRequest function that, when called with params and a
+  /// request ID, registers `callback` under that ID and sends the request via
+  /// the transport. The callback is invoked when a response with that ID is
+  /// received. The returned function captures `this`, so must not outlive it.
+  template <typename T>
+  OutgoingRequest<T> outgoingRequest(llvm::StringLiteral method,
+                                     OutgoingRequestCallback callback) {
+    return [&, method, callback](const T &params, llvm::json::Value id) {
+      {
+        std::lock_guard<std::mutex> lock(responseHandlersMutex);
+        responseHandlers.insert(
+            {debugString(id), std::make_pair(method.str(), callback)});
+      }
+
+      std::lock_guard<std::mutex> transportLock(transportOutputMutex);
+      Logger::info("--> {0}({1})", method, id);
+      transport.call(method, llvm::json::Value(params), id);
+    };
+  }
+
 private:
   template <typename HandlerT>
   using HandlerMap = llvm::StringMap<llvm::unique_function<HandlerT>>;
@@ -178,6 +211,14 @@ class MessageHandler {
   HandlerMap<void(llvm::json::Value, Callback<llvm::json::Value>)>
       methodHandlers;
 
+  /// A pair of (1) the original request's method name, and (2) the callback
+  /// function to be invoked for responses.
+  using ResponseHandlerTy = std::pair<std::string, OutgoingRequestCallback>;
+  /// Maps each pending request's `debugString(id)` to its response handler.
+  llvm::StringMap<ResponseHandlerTy> responseHandlers;
+  /// Mutex guarding all access (insert, lookup, erase) to `responseHandlers`.
+  std::mutex responseHandlersMutex;
+
   JSONTransport &transport;
 
   /// Mutex to guard sending output messages to the transport.
diff --git a/mlir/lib/Tools/lsp-server-support/Transport.cpp b/mlir/lib/Tools/lsp-server-support/Transport.cpp
index 339c5f3825165d..1e90ab32281f54 100644
--- a/mlir/lib/Tools/lsp-server-support/Transport.cpp
+++ b/mlir/lib/Tools/lsp-server-support/Transport.cpp
@@ -117,21 +117,29 @@ bool MessageHandler::onCall(llvm::StringRef method, llvm::json::Value params,
 
 bool MessageHandler::onReply(llvm::json::Value id,
                              llvm::Expected<llvm::json::Value> result) {
-  // TODO: Add support for reply callbacks when support for outgoing messages is
-  // added. For now, we just log an error on any replies received.
-  Callback<llvm::json::Value> replyHandler =
-      [&id](llvm::Expected<llvm::json::Value> result) {
-        Logger::error(
-            "received a reply with ID {0}, but there was no such call", id);
-        if (!result)
-          llvm::consumeError(result.takeError());
-      };
-
-  // Log and run the reply handler.
-  if (result)
-    replyHandler(std::move(result));
-  else
-    replyHandler(result.takeError());
+  // Look up the handler registered for this ID. If one exists, move it out of
+  // the map and erase the entry so that each handler runs at most once.
+  ResponseHandlerTy responseHandler;
+  {
+    std::lock_guard<std::mutex> responseHandlersLock(responseHandlersMutex);
+    auto it = responseHandlers.find(debugString(id));
+    if (it != responseHandlers.end()) {
+      responseHandler = std::move(it->second);
+      responseHandlers.erase(it);
+    }
+  }
+
+  // Invoke the handler if found; otherwise, log an error and drop the reply.
+  if (responseHandler.second) {
+    Logger::info("--> reply:{0}({1})", responseHandler.first, id);
+    responseHandler.second(std::move(id), std::move(result));
+  } else {
+    Logger::error(
+        "received a reply with ID {0}, but there was no such outgoing request",
+        id);
+    if (!result)
+      llvm::consumeError(result.takeError());
+  }
   return true;
 }
 
diff --git a/mlir/unittests/Tools/lsp-server-support/Transport.cpp b/mlir/unittests/Tools/lsp-server-support/Transport.cpp
index b46f02bc4b197b..ae87466318df12 100644
--- a/mlir/unittests/Tools/lsp-server-support/Transport.cpp
+++ b/mlir/unittests/Tools/lsp-server-support/Transport.cpp
@@ -125,4 +125,43 @@ TEST_F(TransportInputTest, OutgoingNotification) {
   notifyFn(CompletionList{});
   EXPECT_THAT(getOutput(), HasSubstr("\"method\":\"outgoing-notification\""));
 }
+
+TEST_F(TransportInputTest, ResponseHandlerNotFound) {
+  // A response with no registered handler is only reported via error logging,
+  // so this test has nothing to make expectations on; it checks that such a
+  // response is processed without crashing, printing the logged error.
+  Logger::setLogLevel(Logger::Level::Error);
+  writeInput("{\"jsonrpc\":\"2.0\",\"id\":81,\"result\":null}\n");
+  runTransport();
+}
+
+TEST_F(TransportInputTest, OutgoingRequest) {
+  // Make some outgoing requests.
+  int responseCallbackInvoked = 0;
+  auto callFn = getMessageHandler().outgoingRequest<CompletionList>(
+      "outgoing-request",
+      [&responseCallbackInvoked](llvm::json::Value id,
+                                 llvm::Expected<llvm::json::Value> value) {
+        // Only the response with ID 83 should reach this callback.
+        EXPECT_EQ(id, 83);
+        ASSERT_TRUE((bool)value);
+        EXPECT_EQ(debugString(*value), "{\"foo\":6}");
+        responseCallbackInvoked += 1;
+      });
+  callFn({}, 82);
+  callFn({}, 83);
+  callFn({}, 84);
+  EXPECT_THAT(getOutput(), HasSubstr("\"method\":\"outgoing-request\""));
+  EXPECT_THAT(getOutput(), HasSubstr("\"id\":84"));
+  EXPECT_EQ(responseCallbackInvoked, 0);
+
+  // One of the requests receives a response. The message handler handles this
+  // response by invoking the callback from above. Subsequent responses with the
+  // same ID are ignored.
+  writeInput("{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"foo\":6}}\n"
+             "// -----\n"
+             "{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"bar\":8}}\n");
+  runTransport();
+  EXPECT_EQ(responseCallbackInvoked, 1);
+}
 } // namespace



More information about the Mlir-commits mailing list