[Mlir-commits] [mlir] e24a7bb - [mlir-lsp] Support outgoing requests (#90078)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 29 15:43:13 PDT 2024


Author: Brian Gesiak
Date: 2024-04-29T18:43:08-04:00
New Revision: e24a7bbf4515213f44d410bfc41b3dff27c49c86

URL: https://github.com/llvm/llvm-project/commit/e24a7bbf4515213f44d410bfc41b3dff27c49c86
DIFF: https://github.com/llvm/llvm-project/commit/e24a7bbf4515213f44d410bfc41b3dff27c49c86.diff

LOG: [mlir-lsp] Support outgoing requests (#90078)

Add support for outgoing requests to `lsp::MessageHandler`. Much like
`MessageHandler::outgoingNotification`, this allows the message
handler to send outgoing messages via its JSON transport, but in this
case, those messages are requests, not notifications.

Requests receive responses (also referred to as "replies" in
`MLIRLspServerSupportLib`). These were previously unsupported, and
`lsp::MessageHandler` would log an error each time it processed a JSON
message that appeared to be a response (something with an "id" field,
but no "method" field). Now, `outgoingRequest` registers a callback for
each outgoing request, keyed by the request's ID, and
`MessageHandler::onReply` invokes that callback when a response with the
matching ID is received.

Added: 
    

Modified: 
    mlir/include/mlir/Tools/lsp-server-support/Transport.h
    mlir/lib/Tools/lsp-server-support/Transport.cpp
    mlir/unittests/Tools/lsp-server-support/Transport.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Tools/lsp-server-support/Transport.h b/mlir/include/mlir/Tools/lsp-server-support/Transport.h
index 44c71058cf717c..047d174234df8d 100644
--- a/mlir/include/mlir/Tools/lsp-server-support/Transport.h
+++ b/mlir/include/mlir/Tools/lsp-server-support/Transport.h
@@ -15,6 +15,7 @@
 #ifndef MLIR_TOOLS_LSPSERVERSUPPORT_TRANSPORT_H
 #define MLIR_TOOLS_LSPSERVERSUPPORT_TRANSPORT_H
 
+#include "mlir/Support/DebugStringHelper.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Tools/lsp-server-support/Logging.h"
@@ -100,6 +101,18 @@ using Callback = llvm::unique_function<void(llvm::Expected<T>)>;
 template <typename T>
 using OutgoingNotification = llvm::unique_function<void(const T &)>;
 
+/// An OutgoingRequest<T> is a function that sends a request to the client,
+/// taking the request parameters (of type T) and the request's ID.
+template <typename T>
+using OutgoingRequest =
+    llvm::unique_function<void(const T &, llvm::json::Value id)>;
+
+/// An `OutgoingRequestCallback` is invoked when an outgoing request to the
+/// client receives a response. It is passed the original request's ID and the
+/// result JSON as an `llvm::Expected`; any error it holds must be consumed.
+using OutgoingRequestCallback =
+    std::function<void(llvm::json::Value, llvm::Expected<llvm::json::Value>)>;
+
 /// A handler used to process the incoming transport messages.
 class MessageHandler {
 public:
@@ -170,6 +183,26 @@ class MessageHandler {
     };
   }
 
+  /// Create an OutgoingRequest function that, when called, registers
+  /// `callback` under the request's ID and then sends a request with the given
+  /// method via the transport. The callback runs when a response with that ID
+  /// arrives; an ID already awaiting a response keeps its original callback.
+  template <typename T>
+  OutgoingRequest<T> outgoingRequest(llvm::StringLiteral method,
+                                     OutgoingRequestCallback callback) {
+    return [&, method, callback](const T &params, llvm::json::Value id) {
+      {
+        std::lock_guard<std::mutex> lock(responseHandlersMutex);
+        responseHandlers.insert(
+            {debugString(id), std::make_pair(method.str(), callback)});
+      }
+
+      std::lock_guard<std::mutex> transportLock(transportOutputMutex);
+      Logger::info("--> {0}({1})", method, id);
+      transport.call(method, llvm::json::Value(params), id);
+    };
+  }
+
 private:
   template <typename HandlerT>
   using HandlerMap = llvm::StringMap<llvm::unique_function<HandlerT>>;
@@ -178,6 +211,14 @@ class MessageHandler {
   HandlerMap<void(llvm::json::Value, Callback<llvm::json::Value>)>
       methodHandlers;
 
+  /// A pair of (1) the original request's method name, and (2) the callback
+  /// function to be invoked for responses.
+  using ResponseHandlerTy = std::pair<std::string, OutgoingRequestCallback>;
+  /// A mapping from the request ID (rendered via `debugString`) to handler.
+  llvm::StringMap<ResponseHandlerTy> responseHandlers;
+  /// Mutex guarding all access (insert, find, erase) to `responseHandlers`.
+  std::mutex responseHandlersMutex;
+
   JSONTransport &transport;
 
   /// Mutex to guard sending output messages to the transport.

diff  --git a/mlir/lib/Tools/lsp-server-support/Transport.cpp b/mlir/lib/Tools/lsp-server-support/Transport.cpp
index 339c5f3825165d..1e90ab32281f54 100644
--- a/mlir/lib/Tools/lsp-server-support/Transport.cpp
+++ b/mlir/lib/Tools/lsp-server-support/Transport.cpp
@@ -117,21 +117,29 @@ bool MessageHandler::onCall(llvm::StringRef method, llvm::json::Value params,
 
 bool MessageHandler::onReply(llvm::json::Value id,
                              llvm::Expected<llvm::json::Value> result) {
-  // TODO: Add support for reply callbacks when support for outgoing messages is
-  // added. For now, we just log an error on any replies received.
-  Callback<llvm::json::Value> replyHandler =
-      [&id](llvm::Expected<llvm::json::Value> result) {
-        Logger::error(
-            "received a reply with ID {0}, but there was no such call", id);
-        if (!result)
-          llvm::consumeError(result.takeError());
-      };
-
-  // Log and run the reply handler.
-  if (result)
-    replyHandler(std::move(result));
-  else
-    replyHandler(result.takeError());
+  // Look up the handler registered for this ID; if found, move it out and
+  // erase it, so a duplicate response with the same ID is reported as unknown.
+  ResponseHandlerTy responseHandler;
+  {
+    std::lock_guard<std::mutex> responseHandlersLock(responseHandlersMutex);
+    auto it = responseHandlers.find(debugString(id));
+    if (it != responseHandlers.end()) {
+      responseHandler = std::move(it->second);
+      responseHandlers.erase(it);
+    }
+  }
+
+  // Invoke the handler if found; otherwise log and consume any error result.
+  if (responseHandler.second) {
+    Logger::info("--> reply:{0}({1})", responseHandler.first, id);
+    responseHandler.second(std::move(id), std::move(result));
+  } else {
+    Logger::error(
+        "received a reply with ID {0}, but there was no such outgoing request",
+        id);
+    if (!result)
+      llvm::consumeError(result.takeError());
+  }
   return true;
 }
 

diff  --git a/mlir/unittests/Tools/lsp-server-support/Transport.cpp b/mlir/unittests/Tools/lsp-server-support/Transport.cpp
index a086964cd3660c..fee21840595232 100644
--- a/mlir/unittests/Tools/lsp-server-support/Transport.cpp
+++ b/mlir/unittests/Tools/lsp-server-support/Transport.cpp
@@ -131,4 +131,42 @@ TEST_F(TransportInputTest, OutgoingNotification) {
   notifyFn(CompletionList{});
   EXPECT_THAT(getOutput(), HasSubstr("\"method\":\"outgoing-notification\""));
 }
+
+TEST_F(TransportInputTest, ResponseHandlerNotFound) {
+  // Unhandled responses are only reported via error logging. As a result, this
+  // test can't make any expectations -- but it prints the output anyway, by way
+  // of demonstration.
+  Logger::setLogLevel(Logger::Level::Error);
+  writeInput("{\"jsonrpc\":\"2.0\",\"id\":81,\"result\":null}\n");
+  runTransport();
+}
+
+TEST_F(TransportInputTest, OutgoingRequest) {
+  // Make some outgoing requests.
+  int responseCallbackInvoked = 0;
+  auto callFn = getMessageHandler().outgoingRequest<CompletionList>(
+      "outgoing-request",
+      [&responseCallbackInvoked](llvm::json::Value id,
+                                 llvm::Expected<llvm::json::Value> value) {
+        // Make expectations on the expected response.
+        EXPECT_EQ(id, 83);
+        ASSERT_TRUE((bool)value);
+        EXPECT_EQ(debugString(*value), "{\"foo\":6}");
+        responseCallbackInvoked += 1;
+      });
+  callFn({}, 82);
+  callFn({}, 83);
+  callFn({}, 84);
+  EXPECT_THAT(getOutput(), HasSubstr("\"method\":\"outgoing-request\""));
+  EXPECT_EQ(responseCallbackInvoked, 0);
+
+  // One of the requests receives a response. The message handler handles this
+  // response by invoking the callback from above. Subsequent responses with the
+  // same ID are ignored.
+  writeInput("{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"foo\":6}}\n"
+             "// -----\n"
+             "{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"bar\":8}}\n");
+  runTransport();
+  EXPECT_EQ(responseCallbackInvoked, 1);
+}
 } // namespace


        


More information about the Mlir-commits mailing list