[Mlir-commits] [mlir] [mlir-lsp] Parse outgoing request callback JSON (PR #90693)

Brian Gesiak llvmlistbot at llvm.org
Tue Apr 30 18:35:43 PDT 2024


https://github.com/modocache created https://github.com/llvm/llvm-project/pull/90693

Rather than forcing callbacks for outgoing requests to parse the result JSON themselves (of type `llvm::Expected<llvm::json::Value>`), allow users to specify the result type;
`MessageHandler::outgoingRequest` then parses the result JSON into that type for them. This eliminates boilerplate for users sending outgoing requests.

>From 5187904eb8eee8e2cc9fdbc47161c49e6d43a0ce Mon Sep 17 00:00:00 2001
From: Brian Gesiak <brian at modocache.io>
Date: Tue, 30 Apr 2024 21:32:19 -0400
Subject: [PATCH] [mlir-lsp] Parse outgoing request callback JSON

Rather than forcing callbacks for outgoing requests to parse the result
JSON themselves (of type `llvm::Expected<llvm::json::Value>`), allow
users to specify the result type; `MessageHandler::outgoingRequest`
then parses the result JSON into that type for them. This eliminates
boilerplate for users sending outgoing requests.
---
 .../mlir/Tools/lsp-server-support/Transport.h | 42 ++++++++++----
 .../Tools/lsp-server-support/Transport.cpp    | 58 ++++++++++++++-----
 2 files changed, 76 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Tools/lsp-server-support/Transport.h b/mlir/include/mlir/Tools/lsp-server-support/Transport.h
index 047d174234df8d..b2979be60eacc8 100644
--- a/mlir/include/mlir/Tools/lsp-server-support/Transport.h
+++ b/mlir/include/mlir/Tools/lsp-server-support/Transport.h
@@ -109,9 +109,10 @@ using OutgoingRequest =

 /// An `OutgoingRequestCallback` is invoked when an outgoing request to the
 /// client receives a response in turn. It is passed the original request's ID,
-/// as well as the result JSON.
+/// as well as the response result of type `T`, or the error that occurred.
+template <typename T>
 using OutgoingRequestCallback =
-    std::function<void(llvm::json::Value, llvm::Expected<llvm::json::Value>)>;
+    std::function<void(llvm::json::Value, llvm::Expected<T>)>;

 /// A handler used to process the incoming transport messages.
 class MessageHandler {
@@ -185,21 +186,37 @@ class MessageHandler {

   /// Create an OutgoingRequest function that, when called, sends a request with
   /// the given method via the transport. Should the outgoing request be
-  /// met with a response, the response callback is invoked to handle that
-  /// response.
-  template <typename T>
-  OutgoingRequest<T> outgoingRequest(llvm::StringLiteral method,
-                                     OutgoingRequestCallback callback) {
-    return [&, method, callback](const T &params, llvm::json::Value id) {
+  /// met with a response, its result JSON is parsed as `Result` and the
+  /// response callback is invoked with the parsed value or the error.
+  template <typename Param, typename Result>
+  OutgoingRequest<Param>
+  outgoingRequest(llvm::StringLiteral method,
+                  OutgoingRequestCallback<Result> callback) {
+    return [&, method, callback](const Param &param, llvm::json::Value id) {
+      auto callbackWrapper = [method, callback](
+                                 llvm::json::Value id,
+                                 llvm::Expected<llvm::json::Value> value) {
+        if (!value)
+          return callback(std::move(id), value.takeError());
+        // Decode the result, labeling errors with the reply method and ID.
+        std::string responseName = llvm::formatv("reply:{0}({1})", method, id);
+        llvm::Expected<Result> result =
+            parse<Result>(*value, responseName, "response");
+        if (!result)
+          return callback(std::move(id), result.takeError());
+
+        return callback(std::move(id), std::move(*result));
+      };
+
       {
         std::lock_guard<std::mutex> lock(responseHandlersMutex);
         responseHandlers.insert(
-            {debugString(id), std::make_pair(method.str(), callback)});
+            {debugString(id), std::make_pair(method.str(), callbackWrapper)});
       }

       std::lock_guard<std::mutex> transportLock(transportOutputMutex);
       Logger::info("--> {0}({1})", method, id);
-      transport.call(method, llvm::json::Value(params), id);
+      transport.call(method, llvm::json::Value(param), id);
     };
   }

@@ -213,7 +230,10 @@ class MessageHandler {

   /// A pair of (1) the original request's method name, and (2) the callback
   /// function to be invoked for responses.
-  using ResponseHandlerTy = std::pair<std::string, OutgoingRequestCallback>;
+  /// The callback receives the unparsed result JSON (or an error); typed
+  /// callbacks are wrapped by `outgoingRequest`, which parses the result.
+  using ResponseHandlerTy =
+      std::pair<std::string, OutgoingRequestCallback<llvm::json::Value>>;
   /// A mapping from request/response ID to response handler.
   llvm::StringMap<ResponseHandlerTy> responseHandlers;
   /// Mutex to guard insertion into the response handler map.
diff --git a/mlir/unittests/Tools/lsp-server-support/Transport.cpp b/mlir/unittests/Tools/lsp-server-support/Transport.cpp
index fee21840595232..0303c1cba8bc87 100644
--- a/mlir/unittests/Tools/lsp-server-support/Transport.cpp
+++ b/mlir/unittests/Tools/lsp-server-support/Transport.cpp
@@ -144,17 +144,17 @@ TEST_F(TransportInputTest, ResponseHandlerNotFound) {
 TEST_F(TransportInputTest, OutgoingRequest) {
   // Make some outgoing requests.
   int responseCallbackInvoked = 0;
-  auto callFn = getMessageHandler().outgoingRequest<CompletionList>(
-      "outgoing-request",
-      [&responseCallbackInvoked](llvm::json::Value id,
-                                 llvm::Expected<llvm::json::Value> value) {
-        // Make expectations on the expected response.
-        EXPECT_EQ(id, 83);
-        ASSERT_TRUE((bool)value);
-        EXPECT_EQ(debugString(*value), "{\"foo\":6}");
-        responseCallbackInvoked += 1;
-        llvm::outs() << "here!!!\n";
-      });
+  auto callFn =
+      getMessageHandler().outgoingRequest<CompletionList, CompletionContext>(
+          "outgoing-request",
+          [&responseCallbackInvoked](llvm::json::Value id,
+                                     llvm::Expected<CompletionContext> result) {
+            // Only the first response for ID 83 should reach this callback.
+            EXPECT_EQ(id, 83);
+            ASSERT_TRUE((bool)result);
+            EXPECT_EQ(result->triggerKind, CompletionTriggerKind::Invoked);
+            responseCallbackInvoked += 1;
+          });
   callFn({}, 82);
   callFn({}, 83);
   callFn({}, 84);
@@ -164,9 +164,41 @@ TEST_F(TransportInputTest, OutgoingRequest) {
   // One of the requests receives a response. The message handler handles this
   // response by invoking the callback from above. Subsequent responses with the
   // same ID are ignored.
-  writeInput("{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"foo\":6}}\n"
+  writeInput(
+      "{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"triggerKind\":1}}\n"
+      "// -----\n"
+      "{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"triggerKind\":3}}\n");
+  runTransport();
+  EXPECT_EQ(responseCallbackInvoked, 1);
+}
+
+TEST_F(TransportInputTest, OutgoingRequestJSONParseFailure) {
+  // Make an outgoing request whose response fails to parse as a Position.
+  int responseCallbackInvoked = 0;
+  auto callFn = getMessageHandler().outgoingRequest<CompletionList, Position>(
+      "outgoing-request-json-parse-failure",
+      [&responseCallbackInvoked](llvm::json::Value id,
+                                 llvm::Expected<Position> result) {
+        llvm::Error err = result.takeError();
+        EXPECT_EQ(id, 109);
+        ASSERT_TRUE((bool)err);
+        EXPECT_THAT(debugString(err),
+                    HasSubstr("failed to decode "
+                              "reply:outgoing-request-json-parse-failure(109) "
+                              "response: missing value at (root).character"));
+        llvm::consumeError(std::move(err));
+        responseCallbackInvoked += 1;
+      });
+  callFn({}, 109);
+  EXPECT_EQ(responseCallbackInvoked, 0);
+
+  // The request receives multiple responses, but only the first one triggers
+  // the response callback. The first response has erroneous JSON that causes a
+  // parse failure.
+  writeInput("{\"jsonrpc\":\"2.0\",\"id\":109,\"result\":{\"line\":7}}\n"
              "// -----\n"
-             "{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"bar\":8}}\n");
+             "{\"jsonrpc\":\"2.0\",\"id\":109,\"result\":{\"line\":3,"
+             "\"character\":2}}\n");
   runTransport();
   EXPECT_EQ(responseCallbackInvoked, 1);
 }



More information about the Mlir-commits mailing list