kadircet updated this revision to Diff 160636.
kadircet marked 4 inline comments as done.
kadircet added a comment.

- Get rid of getCancellationError.
- Add replyError helper.


Repository:
  rCTE Clang Tools Extra

https://reviews.llvm.org/D50502

Files:
  clangd/CMakeLists.txt
  clangd/Cancellation.cpp
  clangd/Cancellation.h
  clangd/ClangdLSPServer.cpp
  clangd/ClangdLSPServer.h
  clangd/ClangdServer.cpp
  clangd/ClangdServer.h
  clangd/JSONRPCDispatcher.cpp
  clangd/JSONRPCDispatcher.h
  clangd/Protocol.cpp
  clangd/Protocol.h
  clangd/ProtocolHandlers.cpp
  clangd/ProtocolHandlers.h
  unittests/clangd/CMakeLists.txt
  unittests/clangd/CancellationTests.cpp

Index: unittests/clangd/CancellationTests.cpp
===================================================================
--- /dev/null
+++ unittests/clangd/CancellationTests.cpp
@@ -0,0 +1,83 @@
+#include "Cancellation.h"
+#include "Context.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/Support/Error.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+#include <atomic>
+#include <memory>
+#include <system_error>
+
+namespace clang {
+namespace clangd {
+namespace {
+
+// isCancelled() reflects the handle only while its context is current.
+TEST(CancellationTest, CancellationTest) {
+  {
+    TaskHandle TH = TaskHandle::createCancellableTaskHandle();
+    WithContext ContextWithCancellation(
+        CancellationHandler::setCurrentCancellationToken(TH));
+    EXPECT_FALSE(CancellationHandler::isCancelled());
+    TH.cancel();
+    EXPECT_TRUE(CancellationHandler::isCancelled());
+  }
+  EXPECT_FALSE(CancellationHandler::isCancelled());
+}
+
+// The context keeps the flag alive after the handle is destroyed.
+TEST(CancellationTest, TaskHandleTestHandleDiesContextLives) {
+  llvm::Optional<WithContext> ContextWithCancellation;
+  {
+    auto CancellableTaskHandle = TaskHandle::createCancellableTaskHandle();
+    ContextWithCancellation.emplace(
+        CancellationHandler::setCurrentCancellationToken(
+            CancellableTaskHandle));
+    EXPECT_FALSE(CancellationHandler::isCancelled());
+    CancellableTaskHandle.cancel();
+    EXPECT_TRUE(CancellationHandler::isCancelled());
+  }
+  EXPECT_TRUE(CancellationHandler::isCancelled());
+  ContextWithCancellation.reset();
+  EXPECT_FALSE(CancellationHandler::isCancelled());
+}
+
+// Leaving the context hides the cancellation, even while the handle lives.
+TEST(CancellationTest, TaskHandleContextDiesHandleLives) {
+  {
+    auto CancellableTaskHandle = TaskHandle::createCancellableTaskHandle();
+    {
+      WithContext ContextWithCancellation(
+          CancellationHandler::setCurrentCancellationToken(
+              CancellableTaskHandle));
+      EXPECT_FALSE(CancellationHandler::isCancelled());
+      CancellableTaskHandle.cancel();
+      EXPECT_TRUE(CancellationHandler::isCancelled());
+    }
+    // The context is gone but the (cancelled) handle is still alive here.
+    EXPECT_FALSE(CancellationHandler::isCancelled());
+  }
+  EXPECT_FALSE(CancellationHandler::isCancelled());
+}
+
+// A token fetched before cancel() still observes the cancellation.
+TEST(CancellationTest, CancellationToken) {
+  auto CancellableTaskHandle = TaskHandle::createCancellableTaskHandle();
+  WithContext ContextWithCancellation(
+      CancellationHandler::setCurrentCancellationToken(CancellableTaskHandle));
+  auto CT = CancellationHandler::isCancelled();
+  EXPECT_FALSE(CT);
+  CancellableTaskHandle.cancel();
+  EXPECT_TRUE(CT);
+}
+
+// TaskCancelledError maps to std::errc::operation_canceled.
+TEST(CancellationTest, TaskCancelledErrorCode) {
+  EXPECT_EQ(llvm::errorToErrorCode(llvm::make_error<TaskCancelledError>()),
+            std::make_error_code(std::errc::operation_canceled));
+}
+
+} // namespace
+} // namespace clangd
+} // namespace clang
Index: unittests/clangd/CMakeLists.txt
===================================================================
--- unittests/clangd/CMakeLists.txt
+++ unittests/clangd/CMakeLists.txt
@@ -10,6 +10,7 @@
 
 add_extra_unittest(ClangdTests
   Annotations.cpp
+  CancellationTests.cpp
   ClangdTests.cpp
   ClangdUnitTests.cpp
   CodeCompleteTests.cpp
Index: clangd/ProtocolHandlers.h
===================================================================
--- clangd/ProtocolHandlers.h
+++ clangd/ProtocolHandlers.h
@@ -55,6 +55,7 @@
   virtual void onDocumentHighlight(TextDocumentPositionParams &Params) = 0;
   virtual void onHover(TextDocumentPositionParams &Params) = 0;
   virtual void onChangeConfiguration(DidChangeConfigurationParams &Params) = 0;
+  virtual void onCancelRequest(CancelParams &Params) = 0;
 };
 
 void registerCallbackHandlers(JSONRPCDispatcher &Dispatcher,
Index: clangd/ProtocolHandlers.cpp
===================================================================
--- clangd/ProtocolHandlers.cpp
+++ clangd/ProtocolHandlers.cpp
@@ -75,4 +75,5 @@
   Register("workspace/didChangeConfiguration",
            &ProtocolCallbacks::onChangeConfiguration);
   Register("workspace/symbol", &ProtocolCallbacks::onWorkspaceSymbol);
+  Register("$/cancelRequest", &ProtocolCallbacks::onCancelRequest);
 }
Index: clangd/Protocol.h
===================================================================
--- clangd/Protocol.h
+++ clangd/Protocol.h
@@ -861,6 +861,13 @@
 llvm::json::Value toJSON(const DocumentHighlight &DH);
 llvm::raw_ostream &operator<<(llvm::raw_ostream &, const DocumentHighlight &);
 
+struct CancelParams {
+  std::string ID; // Request id as text; integer ids are stored in decimal.
+};
+llvm::json::Value toJSON(const CancelParams &);
+llvm::raw_ostream &operator<<(llvm::raw_ostream &, const CancelParams &);
+bool fromJSON(const llvm::json::Value &, CancelParams &);
+
 } // namespace clangd
 } // namespace clang
 
Index: clangd/Protocol.cpp
===================================================================
--- clangd/Protocol.cpp
+++ clangd/Protocol.cpp
@@ -615,5 +615,30 @@
          O.map("compilationDatabaseChanges", CCPC.compilationDatabaseChanges);
 }
 
+json::Value toJSON(const CancelParams &CP) {
+  return json::Object{{"id", CP.ID}}; // The id is always emitted as a string.
+}
+
+// Prints the JSON form of CP.
+llvm::raw_ostream &operator<<(llvm::raw_ostream &O, const CancelParams &CP) {
+  return O << toJSON(CP);
+}
+
+bool fromJSON(const json::Value &Params, CancelParams &CP) {
+  json::ObjectMapper O(Params);
+  if (!O)
+    return false;
+  // The request id is either a string or an integer; accept both.
+  if (O.map("id", CP.ID))
+    return true;
+  // itostr (signed) keeps negative ids intact instead of wrapping them.
+  int64_t IDNumber;
+  if (O.map("id", IDNumber)) {
+    CP.ID = itostr(IDNumber);
+    return true;
+  }
+  return false;
+}
+
 } // namespace clangd
 } // namespace clang
Index: clangd/JSONRPCDispatcher.h
===================================================================
--- clangd/JSONRPCDispatcher.h
+++ clangd/JSONRPCDispatcher.h
@@ -64,6 +64,7 @@
 /// Sends an error response to the client, and logs it.
 /// Current context must derive from JSONRPCDispatcher::Handler.
 void replyError(ErrorCode Code, const llvm::StringRef &Message);
+void replyError(llvm::Error E); // Cancelled tasks get RequestCancelled.
 /// Sends a request to the client.
 /// Current context must derive from JSONRPCDispatcher::Handler.
 void call(llvm::StringRef Method, llvm::json::Value &&Params);
@@ -111,6 +112,7 @@
                            JSONStreamStyle InputStyle,
                            JSONRPCDispatcher &Dispatcher, bool &IsDone);
 
+const llvm::json::Value *GetRequestId(); // Null if the context lacks one.
 } // namespace clangd
 } // namespace clang
 
Index: clangd/JSONRPCDispatcher.cpp
===================================================================
--- clangd/JSONRPCDispatcher.cpp
+++ clangd/JSONRPCDispatcher.cpp
@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "JSONRPCDispatcher.h"
+#include "Cancellation.h"
 #include "ProtocolHandlers.h"
 #include "Trace.h"
 #include "llvm/ADT/SmallString.h"
@@ -129,6 +130,14 @@
   }
 }
 
+void clangd::replyError(Error E) {
+  // TaskCancelledError maps to RequestCancelled, other errors to InvalidParams.
+  if (Error Rest = handleErrors(std::move(E), [](const TaskCancelledError &C) {
+        replyError(ErrorCode::RequestCancelled, C.message());
+      }))
+    replyError(ErrorCode::InvalidParams, llvm::toString(std::move(Rest)));
+}
+
 void clangd::call(StringRef Method, json::Value &&Params) {
   RequestSpan::attach([&](json::Object &Args) {
     Args["Call"] = json::Object{{"method", Method.str()}, {"params", Params}};
@@ -366,3 +375,7 @@
     }
   }
 }
+
+const json::Value *clangd::GetRequestId() {
+  return Context::current().get(RequestID); // Null if unset in the context.
+}
Index: clangd/ClangdServer.h
===================================================================
--- clangd/ClangdServer.h
+++ clangd/ClangdServer.h
@@ -10,6 +10,7 @@
 #ifndef LLVM_CLANG_TOOLS_EXTRA_CLANGD_CLANGDSERVER_H
 #define LLVM_CLANG_TOOLS_EXTRA_CLANGD_CLANGDSERVER_H
 
+#include "Cancellation.h"
 #include "ClangdUnit.h"
 #include "CodeComplete.h"
 #include "FSProvider.h"
@@ -122,9 +123,9 @@
   /// while returned future is not yet ready.
   /// A version of `codeComplete` that runs \p Callback on the processing thread
   /// when codeComplete results become available.
-  void codeComplete(PathRef File, Position Pos,
-                    const clangd::CodeCompleteOptions &Opts,
-                    Callback<CodeCompleteResult> CB);
+  TaskHandle codeComplete(PathRef File, Position Pos,
+                          const clangd::CodeCompleteOptions &Opts,
+                          Callback<CodeCompleteResult> CB);
 
   /// Provide signature help for \p File at \p Pos.  This method should only be
   /// called for tracked files.
Index: clangd/ClangdServer.cpp
===================================================================
--- clangd/ClangdServer.cpp
+++ clangd/ClangdServer.cpp
@@ -8,6 +8,7 @@
 //===-------------------------------------------------------------------===//
 
 #include "ClangdServer.h"
+#include "Cancellation.h"
 #include "CodeComplete.h"
 #include "FindSymbols.h"
 #include "Headers.h"
@@ -140,25 +141,32 @@
   WorkScheduler.remove(File);
 }
 
-void ClangdServer::codeComplete(PathRef File, Position Pos,
-                                const clangd::CodeCompleteOptions &Opts,
-                                Callback<CodeCompleteResult> CB) {
+TaskHandle ClangdServer::codeComplete(PathRef File, Position Pos,
+                                      const clangd::CodeCompleteOptions &Opts,
+                                      Callback<CodeCompleteResult> CB) {
   // Copy completion options for passing them to async task handler.
   auto CodeCompleteOpts = Opts;
   if (!CodeCompleteOpts.Index) // Respect overridden index.
     CodeCompleteOpts.Index = Index;
 
+  auto CancellableTaskHandle = TaskHandle::createCancellableTaskHandle();
   // Copy PCHs to avoid accessing this->PCHs concurrently
   std::shared_ptr<PCHContainerOperations> PCHs = this->PCHs;
   auto FS = FSProvider.getFileSystem();
-  auto Task = [PCHs, Pos, FS,
-               CodeCompleteOpts](Path File, Callback<CodeCompleteResult> CB,
-                                 llvm::Expected<InputsAndPreamble> IP) {
+  auto Task = [PCHs, Pos, FS, CodeCompleteOpts, CancellableTaskHandle](
+                  Path File, Callback<CodeCompleteResult> CB,
+                  llvm::Expected<InputsAndPreamble> IP) {
     if (!IP)
       return CB(IP.takeError());
 
     auto PreambleData = IP->Preamble;
 
+    WithContext ContextWithCancellation(
+        CancellationHandler::setCurrentCancellationToken(
+            std::move(CancellableTaskHandle)));
+    if (CancellationHandler::isCancelled()) {
+      return CB(llvm::make_error<TaskCancelledError>());
+    }
     // FIXME(ibiryukov): even if Preamble is non-null, we may want to check
     // both the old and the new version in case only one of them matches.
     CodeCompleteResult Result = clangd::codeComplete(
@@ -170,6 +178,7 @@
 
   WorkScheduler.runWithPreamble("CodeComplete", File,
                                 Bind(Task, File.str(), std::move(CB)));
+  return CancellableTaskHandle;
 }
 
 void ClangdServer::signatureHelp(PathRef File, Position Pos,
Index: clangd/ClangdLSPServer.h
===================================================================
--- clangd/ClangdLSPServer.h
+++ clangd/ClangdLSPServer.h
@@ -75,6 +75,7 @@
   void onRename(RenameParams &Parames) override;
   void onHover(TextDocumentPositionParams &Params) override;
   void onChangeConfiguration(DidChangeConfigurationParams &Params) override;
+  void onCancelRequest(CancelParams &Params) override;
 
   std::vector<Fix> getFixes(StringRef File, const clangd::Diagnostic &D);
 
@@ -167,8 +168,17 @@
   // the worker thread that may otherwise run an async callback on partially
   // destructed instance of ClangdLSPServer.
   ClangdServer Server;
-};
 
+  // Holds task handles for running requests, keyed by NormalizeRequestID() of
+  // the request id. Guarded by TaskHandlesMutex.
+  llvm::StringMap<TaskHandle> TaskHandles;
+  std::mutex TaskHandlesMutex;
+
+  // The next two functions are context-aware: they store or erase the handle
+  // for the request id in the current context, and do nothing if it has none.
+  void CleanupTaskHandle();
+  void StoreTaskHandle(TaskHandle TH);
+};
 } // namespace clangd
 } // namespace clang
 
Index: clangd/ClangdLSPServer.cpp
===================================================================
--- clangd/ClangdLSPServer.cpp
+++ clangd/ClangdLSPServer.cpp
@@ -8,10 +8,12 @@
 //===---------------------------------------------------------------------===//
 
 #include "ClangdLSPServer.h"
+#include "Cancellation.h"
 #include "Diagnostics.h"
 #include "JSONRPCDispatcher.h"
 #include "SourceCode.h"
 #include "URI.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/Errc.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/Path.h"
@@ -69,6 +71,12 @@
   return Defaults;
 }
 
+std::string NormalizeRequestID(const json::Value &ID) {
+  // Return string ids unquoted so they match CancelParams::ID.
+  if (llvm::Optional<llvm::StringRef> S = ID.getAsString())
+    return S->str();
+  return llvm::formatv("{0}", ID).str();
+}
 } // namespace
 
 void ClangdLSPServer::onInitialize(InitializeParams &Params) {
@@ -337,17 +345,21 @@
 }
 
 void ClangdLSPServer::onCompletion(TextDocumentPositionParams &Params) {
-  Server.codeComplete(Params.textDocument.uri.file(), Params.position, CCOpts,
-                      [this](llvm::Expected<CodeCompleteResult> List) {
-                        if (!List)
-                          return replyError(ErrorCode::InvalidParams,
-                                            llvm::toString(List.takeError()));
-                        CompletionList LSPList;
-                        LSPList.isIncomplete = List->HasMore;
-                        for (const auto &R : List->Completions)
-                          LSPList.items.push_back(R.render(CCOpts));
-                        reply(std::move(LSPList));
-                      });
+  TaskHandle TH = Server.codeComplete(
+      Params.textDocument.uri.file(), Params.position, CCOpts,
+      [this](llvm::Expected<CodeCompleteResult> List) {
+        // Drop this request's handle on every exit path of the callback.
+        auto CleanupOnExit = llvm::make_scope_exit([this] { CleanupTaskHandle(); });
+        if (!List)
+          return replyError(List.takeError());
+        CompletionList LSPList;
+        LSPList.isIncomplete = List->HasMore;
+        for (const auto &R : List->Completions)
+          LSPList.items.push_back(R.render(CCOpts));
+        reply(std::move(LSPList));
+      });
+  // FIXME: the callback may already have run (and cleaned up) by this point.
+  StoreTaskHandle(std::move(TH));
 }
 
 void ClangdLSPServer::onSignatureHelp(TextDocumentPositionParams &Params) {
@@ -362,14 +374,14 @@
 }
 
 void ClangdLSPServer::onGoToDefinition(TextDocumentPositionParams &Params) {
-  Server.findDefinitions(
-      Params.textDocument.uri.file(), Params.position,
-      [](llvm::Expected<std::vector<Location>> Items) {
-        if (!Items)
-          return replyError(ErrorCode::InvalidParams,
-                            llvm::toString(Items.takeError()));
-        reply(json::Array(*Items));
-      });
+  Server.findDefinitions(
+      Params.textDocument.uri.file(), Params.position,
+      [](llvm::Expected<std::vector<Location>> Items) {
+        if (!Items)
+          return replyError(ErrorCode::InvalidParams,
+                            llvm::toString(Items.takeError()));
+        reply(json::Array(*Items));
+      });
 }
 
 void ClangdLSPServer::onSwitchSourceHeader(TextDocumentIdentifier &Params) {
@@ -602,3 +614,34 @@
     return *CachingCDB;
   return *CDB;
 }
+
+void ClangdLSPServer::onCancelRequest(CancelParams &Params) {
+  std::lock_guard<std::mutex> Lock(TaskHandlesMutex);
+  auto It = TaskHandles.find(Params.ID);
+  if (It == TaskHandles.end())
+    return; // No handle stored under this id: nothing to cancel.
+  It->second.cancel();
+  TaskHandles.erase(It);
+}
+
+// Removes the TaskHandle stored for the current request, if there is one.
+void ClangdLSPServer::CleanupTaskHandle() {
+  const json::Value *ID = GetRequestId();
+  if (ID == nullptr)
+    return;
+  // Compute the key before taking the lock; only the map needs guarding.
+  const std::string Key = NormalizeRequestID(*ID);
+  std::lock_guard<std::mutex> Lock(TaskHandlesMutex);
+  TaskHandles.erase(Key);
+}
+
+// Stores TH keyed by the current request's id so onCancelRequest can find it.
+void ClangdLSPServer::StoreTaskHandle(TaskHandle TH) {
+  const json::Value *ID = GetRequestId();
+  if (ID == nullptr)
+    return;
+  const std::string Key = NormalizeRequestID(*ID);
+  std::lock_guard<std::mutex> Lock(TaskHandlesMutex);
+  // insert() leaves an existing entry for the same id untouched.
+  TaskHandles.insert({Key, std::move(TH)});
+}
Index: clangd/Cancellation.h
===================================================================
--- /dev/null
+++ clangd/Cancellation.h
@@ -0,0 +1,103 @@
+//===--- Cancellation.h -------------------------------------------*-C++-*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+// Cancellation mechanism for async tasks. The caller creates a TaskHandle
+// for a cancellable task and binds it to the context the task runs in; code
+// running in that context can then check whether the task was cancelled.
+// Later on, the client can call cancel() on the handle to tell the async task
+// that it has been cancelled. Example use case:
+//
+// void Caller() {
+//   // Store this handle if you want to cancel the task later on.
+//   TaskHandle TH = StartAsyncTask(Task);
+//   // To cancel the task:
+//   TH.cancel();
+// }
+//
+// TaskHandle StartAsyncTask(Task T) {
+//   // Create the TaskHandle before starting the thread, so that it can be
+//   // copied into the thread.
+//   auto TH = TaskHandle::createCancellableTaskHandle();
+//   auto run = [TH]() {
+//     WithContext C(CancellationHandler::setCurrentCancellationToken(TH));
+//     T();
+//   };
+//   // Start run() in a new thread.
+//   return TH;
+// }
+//
+// void Task() {
+//    // You can either store the read-only token by calling isCancelled once
+//    // and use that variable every time you want to check for cancellation,
+//    // or call isCancelled every time. The former is more efficient if you
+//    // are going to have multiple checks.
+//    const auto CT = CancellationHandler::isCancelled();
+//    // DO SOMETHING...
+//    if (CT) {
+//      // Task has been cancelled, let's get out.
+//      return;
+//    }
+//    // DO SOME MORE THINGS...
+// }
+
+#ifndef LLVM_CLANG_TOOLS_EXTRA_CLANGD_CANCELLATION_H
+#define LLVM_CLANG_TOOLS_EXTRA_CLANGD_CANCELLATION_H
+
+#include "Context.h"
+#include "llvm/Support/Error.h"
+#include <atomic>
+#include <memory>
+#include <system_error>
+
+namespace clang {
+namespace clangd {
+
+/// Read-only view of a task's cancellation flag; a null flag is never set.
+class CancellationToken {
+  std::shared_ptr<const std::atomic<bool>> Token;
+
+public:
+  bool isCancelled() const { return Token && Token->load(); }
+  operator bool() const { return isCancelled(); }
+  explicit CancellationToken(std::shared_ptr<const std::atomic<bool>> Token)
+      : Token(std::move(Token)) {}
+};
+
+/// Handle used to cancel a task; copies share the same cancellation flag.
+class TaskHandle {
+  friend class CancellationHandler;
+  TaskHandle() : CT(std::make_shared<std::atomic<bool>>()) {}
+  std::shared_ptr<std::atomic<bool>> CT;
+
+public:
+  static TaskHandle createCancellableTaskHandle();
+  void cancel();
+};
+
+class CancellationHandler {
+public:
+  static CancellationToken isCancelled(); // Token from Context::current().
+  LLVM_NODISCARD static Context setCurrentCancellationToken(TaskHandle TH);
+};
+
+class TaskCancelledError : public llvm::ErrorInfo<TaskCancelledError> {
+public:
+  static char ID;
+
+  void log(llvm::raw_ostream &OS) const override {
+    OS << "Task got cancelled.";
+  }
+  std::error_code convertToErrorCode() const override {
+    return std::make_error_code(std::errc::operation_canceled);
+  }
+};
+
+} // namespace clangd
+} // namespace clang
+
+#endif
Index: clangd/Cancellation.cpp
===================================================================
--- /dev/null
+++ clangd/Cancellation.cpp
@@ -0,0 +1,38 @@
+//===--- Cancellation.cpp -----------------------------------------*-C++-*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Cancellation.h"
+#include <atomic>
+
+namespace clang {
+namespace clangd {
+namespace {
+// Context slot holding the flag shared with the TaskHandle that created it.
+Key<std::shared_ptr<std::atomic<bool>>> CancellationTokenKey;
+} // namespace
+
+char TaskCancelledError::ID = 0;
+
+CancellationToken CancellationHandler::isCancelled() {
+  const auto *CT = Context::current().get(CancellationTokenKey);
+  return CancellationToken(CT ? *CT : nullptr); // No flag: never cancelled.
+}
+
+Context CancellationHandler::setCurrentCancellationToken(TaskHandle TH) {
+  return Context::current().derive(CancellationTokenKey, std::move(TH.CT));
+}
+
+void TaskHandle::cancel() {
+  if (CT) // A moved-from handle has no flag; cancelling it does nothing.
+    *CT = true;
+}
+
+TaskHandle TaskHandle::createCancellableTaskHandle() { return TaskHandle(); }
+} // namespace clangd
+} // namespace clang
Index: clangd/CMakeLists.txt
===================================================================
--- clangd/CMakeLists.txt
+++ clangd/CMakeLists.txt
@@ -9,6 +9,7 @@
 
 add_clang_library(clangDaemon
   AST.cpp
+  Cancellation.cpp
   ClangdLSPServer.cpp
   ClangdServer.cpp
   ClangdUnit.cpp
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to