[Mlir-commits] [mlir] [mlir-lsp] Support outgoing requests (PR #90078)

Brian Gesiak llvmlistbot at llvm.org
Thu Apr 25 12:46:55 PDT 2024


https://github.com/modocache updated https://github.com/llvm/llvm-project/pull/90078

>From 1f75c15006beca6da50018caa13ff9e90635283a Mon Sep 17 00:00:00 2001
From: Brian Gesiak <brian at modocache.io>
Date: Thu, 25 Apr 2024 09:34:55 -0400
Subject: [PATCH] [mlir-lsp] Support outgoing requests

Add support for outgoing requests to `lsp::MessageHandler`. Much like
`MessageHandler::outgoingNotification`, this allows the message
handler to send outgoing messages via its JSON transport; in this
case, however, those messages are requests rather than notifications.

Requests receive responses (also referred to as "replies" in
`MLIRLspServerSupportLib`). Responses were previously unsupported:
`lsp::MessageHandler` logged an error each time it processed a JSON
message that appeared to be a response (one with an "id" field but no
"method" field). The new `outgoingRequest` method registers a response
callback for each outgoing request, so that when a response with that
request's ID is received, the callback is invoked.
---
 .../mlir/Tools/lsp-server-support/Transport.h | 27 ++++++++++++++++
 .../Tools/lsp-server-support/Transport.cpp    | 26 +++++++---------
 .../Tools/lsp-server-support/Transport.cpp    | 31 +++++++++++++++++++
 3 files changed, 69 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Tools/lsp-server-support/Transport.h b/mlir/include/mlir/Tools/lsp-server-support/Transport.h
index b973a2e267251a..973b4f7e84ed41 100644
--- a/mlir/include/mlir/Tools/lsp-server-support/Transport.h
+++ b/mlir/include/mlir/Tools/lsp-server-support/Transport.h
@@ -15,6 +15,7 @@
 #ifndef MLIR_TOOLS_LSPSERVERSUPPORT_TRANSPORT_H
 #define MLIR_TOOLS_LSPSERVERSUPPORT_TRANSPORT_H
 
+#include "mlir/Support/DebugStringHelper.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Tools/lsp-server-support/Logging.h"
@@ -171,6 +172,41 @@ class MessageHandler {
     };
   }
 
+  /// Create an OutgoingMessage function that, when called, sends a request with
+  /// the given method and ID via the transport. Should the outgoing request be
+  /// met with a response, the response callback is invoked to handle that
+  /// response.
+  ///
+  /// The callback is registered as soon as `outgoingRequest` is called (not
+  /// when the returned function is invoked), keyed by the string form of `id`.
+  /// It is invoked at most once, when a response with that ID is received, and
+  /// is then discarded. If a request with the same ID is still pending, an
+  /// error is logged and the earlier callback is kept.
+  template <typename T>
+  OutgoingMessage<T> outgoingRequest(
+      llvm::StringLiteral method, llvm::json::Value id,
+      llvm::unique_function<void(llvm::Expected<llvm::json::Value>)> callback) {
+    {
+      // `onReply` looks up and erases entries and may run concurrently with
+      // this function, so every access to the map is guarded.
+      std::lock_guard<std::mutex> responseHandlersLock(responseHandlersMutex);
+      bool inserted =
+          responseHandlers
+              .insert({debugString(id),
+                       std::make_pair(method.str(), std::move(callback))})
+              .second;
+      if (!inserted)
+        Logger::error("an outgoing request with ID {0} is already pending",
+                      id);
+    }
+
+    return [&, method, id](const T &params) {
+      std::lock_guard<std::mutex> transportLock(transportOutputMutex);
+      Logger::info("--> {0}", method);
+      transport.call(method, llvm::json::Value(params), id);
+    };
+  }
+
 private:
   template <typename HandlerT>
   using HandlerMap = llvm::StringMap<llvm::unique_function<HandlerT>>;
@@ -179,6 +215,17 @@ class MessageHandler {
   HandlerMap<void(llvm::json::Value, Callback<llvm::json::Value>)>
       methodHandlers;
 
+  /// A pair of (1) the original request's method name, and (2) the callback
+  /// function to be invoked for responses.
+  using ResponseHandlerTy =
+      std::pair<std::string,
+                llvm::unique_function<void(llvm::Expected<llvm::json::Value>)>>;
+  /// A mapping from request/response ID to response handler.
+  llvm::StringMap<ResponseHandlerTy> responseHandlers;
+  /// Mutex to guard `responseHandlers`, which is accessed both when outgoing
+  /// requests are created and when responses are received.
+  std::mutex responseHandlersMutex;
+
   JSONTransport &transport;
 
   /// Mutex to guard sending output messages to the transport.
diff --git a/mlir/lib/Tools/lsp-server-support/Transport.cpp b/mlir/lib/Tools/lsp-server-support/Transport.cpp
index 339c5f3825165d..acaad3f35a5990 100644
--- a/mlir/lib/Tools/lsp-server-support/Transport.cpp
+++ b/mlir/lib/Tools/lsp-server-support/Transport.cpp
@@ -117,21 +117,30 @@ bool MessageHandler::onCall(llvm::StringRef method, llvm::json::Value params,
 
 bool MessageHandler::onReply(llvm::json::Value id,
                              llvm::Expected<llvm::json::Value> result) {
-  // TODO: Add support for reply callbacks when support for outgoing messages is
-  // added. For now, we just log an error on any replies received.
-  Callback<llvm::json::Value> replyHandler =
-      [&id](llvm::Expected<llvm::json::Value> result) {
-        Logger::error(
-            "received a reply with ID {0}, but there was no such call", id);
-        if (!result)
-          llvm::consumeError(result.takeError());
-      };
-
-  // Log and run the reply handler.
-  if (result)
-    replyHandler(std::move(result));
-  else
-    replyHandler(result.takeError());
+  // Look up the handler for this ID. If there is one, move it out of the map
+  // and erase its entry: a request receives at most one response, and
+  // invoking the callback after releasing the lock allows it to issue further
+  // outgoing requests without deadlocking.
+  ResponseHandlerTy responseHandler;
+  {
+    std::lock_guard<std::mutex> responseHandlersLock(responseHandlersMutex);
+    auto it = responseHandlers.find(debugString(id));
+    if (it != responseHandlers.end()) {
+      responseHandler = std::move(it->second);
+      responseHandlers.erase(it);
+    }
+  }
+
+  if (responseHandler.second) {
+    Logger::info("--> reply:{0}({1})", responseHandler.first, id);
+    responseHandler.second(std::move(result));
+  } else {
+    Logger::error(
+        "received a reply with ID {0}, but there was no such outgoing request",
+        id);
+    if (!result)
+      llvm::consumeError(result.takeError());
+  }
   return true;
 }
 
diff --git a/mlir/unittests/Tools/lsp-server-support/Transport.cpp b/mlir/unittests/Tools/lsp-server-support/Transport.cpp
index b46f02bc4b197b..a19b277b20d171 100644
--- a/mlir/unittests/Tools/lsp-server-support/Transport.cpp
+++ b/mlir/unittests/Tools/lsp-server-support/Transport.cpp
@@ -125,4 +125,35 @@ TEST_F(TransportInputTest, OutgoingNotification) {
   notifyFn(CompletionList{});
   EXPECT_THAT(getOutput(), HasSubstr("\"method\":\"outgoing-notification\""));
 }
+
+TEST_F(TransportInputTest, ResponseHandlerNotFound) {
+  // A response whose ID matches no outgoing request is only reported via error
+  // logging, so this test makes no expectations; it exercises that path and
+  // prints the output anyway, by way of demonstration.
+  Logger::setLogLevel(Logger::Level::Error);
+  writeInput("{\"jsonrpc\":\"2.0\",\"id\":81,\"result\":null}\n");
+  runTransport();
+}
+
+TEST_F(TransportInputTest, OutgoingRequest) {
+  Logger::setLogLevel(Logger::Level::Debug);
+
+  // Make an outgoing request with ID 82.
+  bool responseCallbackInvoked = false;
+  auto callFn = getMessageHandler().outgoingRequest<CompletionList>(
+      "outgoing-request", 82,
+      [&responseCallbackInvoked](llvm::Expected<llvm::json::Value> value) {
+        ASSERT_TRUE(static_cast<bool>(value));
+        responseCallbackInvoked = true;
+      });
+  callFn(CompletionList{});
+  EXPECT_THAT(getOutput(), HasSubstr("\"method\":\"outgoing-request\""));
+  EXPECT_FALSE(responseCallbackInvoked);
+
+  // A JSON-RPC response carrying the same ID arrives; the message handler
+  // dispatches it to the callback registered above.
+  writeInput("{\"jsonrpc\":\"2.0\",\"id\":82,\"result\":null}\n");
+  runTransport();
+  EXPECT_TRUE(responseCallbackInvoked);
+}
 } // namespace



More information about the Mlir-commits mailing list