qchateau created this revision.
qchateau added a reviewer: sammccall.
Herald added subscribers: usaxena95, kadircet, arphaman, mgrang.
qchateau requested review of this revision.
Herald added subscribers: cfe-commits, MaskRay, ilya-biryukov.
Herald added a project: clang.

The implementation is very close to the incoming
calls implementation. The results of the outgoing
calls are expected to be the exact mirror image of the
incoming calls.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D93829

Files:
  clang-tools-extra/clangd/ClangdLSPServer.cpp
  clang-tools-extra/clangd/ClangdServer.cpp
  clang-tools-extra/clangd/ClangdServer.h
  clang-tools-extra/clangd/XRefs.cpp
  clang-tools-extra/clangd/XRefs.h
  clang-tools-extra/clangd/index/Index.cpp
  clang-tools-extra/clangd/index/Index.h
  clang-tools-extra/clangd/index/MemIndex.cpp
  clang-tools-extra/clangd/index/MemIndex.h
  clang-tools-extra/clangd/index/Merge.cpp
  clang-tools-extra/clangd/index/Merge.h
  clang-tools-extra/clangd/index/ProjectAware.cpp
  clang-tools-extra/clangd/index/dex/Dex.cpp
  clang-tools-extra/clangd/index/dex/Dex.h
  clang-tools-extra/clangd/unittests/CallHierarchyTests.cpp
  clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp
  clang-tools-extra/clangd/unittests/RenameTests.cpp

Index: clang-tools-extra/clangd/unittests/RenameTests.cpp
===================================================================
--- clang-tools-extra/clangd/unittests/RenameTests.cpp
+++ clang-tools-extra/clangd/unittests/RenameTests.cpp
@@ -1225,6 +1225,12 @@
       return true; // has more references
     }
 
+    bool refersTo(const RefsRequest &Req,
+                  llvm::function_ref<void(const RefersToResult &)> Callback)
+        const override {
+      return false;
+    }
+
     bool fuzzyFind(
         const FuzzyFindRequest &Req,
         llvm::function_ref<void(const Symbol &)> Callback) const override {
@@ -1281,6 +1287,12 @@
       return false;
     }
 
+    bool refersTo(const RefsRequest &Req,
+                  llvm::function_ref<void(const RefersToResult &)> Callback)
+        const override {
+      return false;
+    }
+
     bool fuzzyFind(const FuzzyFindRequest &,
                    llvm::function_ref<void(const Symbol &)>) const override {
       return false;
Index: clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp
===================================================================
--- clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp
+++ clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp
@@ -1345,6 +1345,12 @@
     return false;
   }
 
+  bool
+  refersTo(const RefsRequest &,
+           llvm::function_ref<void(const RefersToResult &)>) const override {
+    return false;
+  }
+
   void relations(const RelationsRequest &,
                  llvm::function_ref<void(const SymbolID &, const Symbol &)>)
       const override {}
Index: clang-tools-extra/clangd/unittests/CallHierarchyTests.cpp
===================================================================
--- clang-tools-extra/clangd/unittests/CallHierarchyTests.cpp
+++ clang-tools-extra/clangd/unittests/CallHierarchyTests.cpp
@@ -37,17 +37,27 @@
 
 // Helpers for matching call hierarchy data structures.
 MATCHER_P(WithName, N, "") { return arg.name == N; }
+MATCHER_P(WithDetail, N, "") { return arg.detail == N; }
 MATCHER_P(WithSelectionRange, R, "") { return arg.selectionRange == R; }
 
 template <class ItemMatcher>
 ::testing::Matcher<CallHierarchyIncomingCall> From(ItemMatcher M) {
   return Field(&CallHierarchyIncomingCall::from, M);
 }
+template <class ItemMatcher>
+::testing::Matcher<CallHierarchyOutgoingCall> To(ItemMatcher M) {
+  return Field(&CallHierarchyOutgoingCall::to, M);
+}
 template <class... RangeMatchers>
-::testing::Matcher<CallHierarchyIncomingCall> FromRanges(RangeMatchers... M) {
+::testing::Matcher<CallHierarchyIncomingCall> IFromRanges(RangeMatchers... M) {
   return Field(&CallHierarchyIncomingCall::fromRanges,
                UnorderedElementsAre(M...));
 }
+template <class... RangeMatchers>
+::testing::Matcher<CallHierarchyOutgoingCall> OFromRanges(RangeMatchers... M) {
+  return Field(&CallHierarchyOutgoingCall::fromRanges,
+               UnorderedElementsAre(M...));
+}
 
 TEST(CallHierarchy, IncomingOneFile) {
   Annotations Source(R"cpp(
@@ -72,22 +82,25 @@
       prepareCallHierarchy(AST, Source.point(), testPath(TU.Filename));
   ASSERT_THAT(Items, ElementsAre(WithName("callee")));
   auto IncomingLevel1 = incomingCalls(Items[0], Index.get());
-  ASSERT_THAT(IncomingLevel1,
-              ElementsAre(AllOf(From(WithName("caller1")),
-                                FromRanges(Source.range("Callee")))));
+  ASSERT_THAT(
+      IncomingLevel1,
+      ElementsAre(AllOf(From(AllOf(WithName("caller1"), WithDetail("caller1"))),
+                        IFromRanges(Source.range("Callee")))));
 
   auto IncomingLevel2 = incomingCalls(IncomingLevel1[0].from, Index.get());
-  ASSERT_THAT(IncomingLevel2,
-              ElementsAre(AllOf(From(WithName("caller2")),
-                                FromRanges(Source.range("Caller1A"),
-                                           Source.range("Caller1B"))),
-                          AllOf(From(WithName("caller3")),
-                                FromRanges(Source.range("Caller1C")))));
+  ASSERT_THAT(
+      IncomingLevel2,
+      ElementsAre(AllOf(From(AllOf(WithName("caller2"), WithDetail("caller2"))),
+                        IFromRanges(Source.range("Caller1A"),
+                                    Source.range("Caller1B"))),
+                  AllOf(From(AllOf(WithName("caller3"), WithDetail("caller3"))),
+                        IFromRanges(Source.range("Caller1C")))));
 
   auto IncomingLevel3 = incomingCalls(IncomingLevel2[0].from, Index.get());
-  ASSERT_THAT(IncomingLevel3,
-              ElementsAre(AllOf(From(WithName("caller3")),
-                                FromRanges(Source.range("Caller2")))));
+  ASSERT_THAT(
+      IncomingLevel3,
+      ElementsAre(AllOf(From(AllOf(WithName("caller3"), WithDetail("caller3"))),
+                        IFromRanges(Source.range("Caller2")))));
 
   auto IncomingLevel4 = incomingCalls(IncomingLevel3[0].from, Index.get());
   EXPECT_THAT(IncomingLevel4, IsEmpty());
@@ -116,14 +129,16 @@
       prepareCallHierarchy(AST, Source.point(), testPath(TU.Filename));
   ASSERT_THAT(Items, ElementsAre(WithName("callee")));
   auto IncomingLevel1 = incomingCalls(Items[0], Index.get());
-  ASSERT_THAT(IncomingLevel1,
-              ElementsAre(AllOf(From(WithName("caller1")),
-                                FromRanges(Source.range("Callee")))));
+  ASSERT_THAT(
+      IncomingLevel1,
+      ElementsAre(AllOf(From(AllOf(WithName("caller1"), WithDetail("caller1"))),
+                        IFromRanges(Source.range("Callee")))));
 
   auto IncomingLevel2 = incomingCalls(IncomingLevel1[0].from, Index.get());
-  EXPECT_THAT(IncomingLevel2,
-              ElementsAre(AllOf(From(WithName("caller2")),
-                                FromRanges(Source.range("Caller1")))));
+  EXPECT_THAT(
+      IncomingLevel2,
+      ElementsAre(AllOf(From(AllOf(WithName("caller2"), WithDetail("caller2"))),
+                        IFromRanges(Source.range("Caller1")))));
 }
 
 TEST(CallHierarchy, IncomingQualified) {
@@ -149,14 +164,72 @@
       prepareCallHierarchy(AST, Source.point(), testPath(TU.Filename));
   ASSERT_THAT(Items, ElementsAre(WithName("Waldo::find")));
   auto Incoming = incomingCalls(Items[0], Index.get());
-  EXPECT_THAT(Incoming,
-              ElementsAre(AllOf(From(WithName("caller1")),
-                                FromRanges(Source.range("Caller1"))),
-                          AllOf(From(WithName("caller2")),
-                                FromRanges(Source.range("Caller2")))));
+  EXPECT_THAT(
+      Incoming,
+      ElementsAre(
+          AllOf(From(AllOf(WithName("caller1"), WithDetail("ns::caller1"))),
+                IFromRanges(Source.range("Caller1"))),
+          AllOf(From(AllOf(WithName("caller2"), WithDetail("ns::caller2"))),
+                IFromRanges(Source.range("Caller2")))));
+}
+
+TEST(CallHierarchy, OutgoingOneFile) {
+  // Test outgoing calls within the main file, with namespaces and methods.
+  Annotations Source(R"cpp(
+    void callee(int);
+    namespace ns {
+      struct Foo {
+        void caller1();
+      };
+      void Foo::caller1() {
+        $Callee[[callee]](42);
+      }
+    }
+    namespace {
+      void caller2(ns::Foo& F) {
+        F.$Caller1A[[caller1]]();
+        F.$Caller1B[[caller1]]();
+      }
+    }
+    void call^er3(ns::Foo& F) {
+      F.$Caller1C[[caller1]]();
+      $Caller2[[caller2]](F);
+    }
+  )cpp");
+  TestTU TU = TestTU::withCode(Source.code());
+  auto AST = TU.build();
+  auto Index = TU.index();
+
+  std::vector<CallHierarchyItem> Items =
+      prepareCallHierarchy(AST, Source.point(), testPath(TU.Filename));
+  ASSERT_THAT(Items, ElementsAre(WithName("caller3")));
+  auto OutgoingLevel1 = outgoingCalls(Items[0], Index.get());
+  ASSERT_THAT(
+      OutgoingLevel1,
+      ElementsAre(
+          AllOf(To(AllOf(WithName("caller1"), WithDetail("ns::Foo::caller1"))),
+                OFromRanges(Source.range("Caller1C"))),
+          AllOf(To(AllOf(WithName("caller2"), WithDetail("caller2"))),
+                OFromRanges(Source.range("Caller2")))));
+
+  auto OutgoingLevel2 = outgoingCalls(OutgoingLevel1[1].to, Index.get());
+  ASSERT_THAT(
+      OutgoingLevel2,
+      ElementsAre(AllOf(
+          To(AllOf(WithName("caller1"), WithDetail("ns::Foo::caller1"))),
+          OFromRanges(Source.range("Caller1A"), Source.range("Caller1B")))));
+
+  auto OutgoingLevel3 = outgoingCalls(OutgoingLevel2[0].to, Index.get());
+  ASSERT_THAT(
+      OutgoingLevel3,
+      ElementsAre(AllOf(To(AllOf(WithName("callee"), WithDetail("callee"))),
+                        OFromRanges(Source.range("Callee")))));
+
+  auto OutgoingLevel4 = outgoingCalls(OutgoingLevel3[0].to, Index.get());
+  EXPECT_THAT(OutgoingLevel4, IsEmpty());
 }
 
-TEST(CallHierarchy, IncomingMultiFile) {
+TEST(CallHierarchy, MultiFile) {
   // The test uses a .hh suffix for header files to get clang
   // to parse them in C++ mode. .h files are parsed in C mode
   // by default, which causes problems because e.g. symbol
@@ -170,32 +243,47 @@
     void calle^e(int) {}
   )cpp");
   Annotations Caller1H(R"cpp(
-    void caller1();
+    namespace nsa {
+      void caller1();
+    }
   )cpp");
   Annotations Caller1C(R"cpp(
     #include "callee.hh"
     #include "caller1.hh"
-    void caller1() {
-      [[calle^e]](42);
+    namespace nsa {
+      void caller1() {
+        [[calle^e]](42);
+      }
     }
   )cpp");
   Annotations Caller2H(R"cpp(
-    void caller2();
+    namespace nsb {
+      void caller2();
+    }
   )cpp");
   Annotations Caller2C(R"cpp(
     #include "caller1.hh"
     #include "caller2.hh"
-    void caller2() {
-      $A[[caller1]]();
-      $B[[caller1]]();
+    namespace nsb {
+      void caller2() {
+        nsa::$A[[caller1]]();
+        nsa::$B[[caller1]]();
+      }
+    }
+  )cpp");
+  Annotations Caller3H(R"cpp(
+    namespace nsa {
+      void call^er3();
     }
   )cpp");
   Annotations Caller3C(R"cpp(
     #include "caller1.hh"
     #include "caller2.hh"
-    void caller3() {
-      $Caller1[[caller1]]();
-      $Caller2[[caller2]]();
+    namespace nsa {
+      void call^er3() {
+        $Caller1[[caller1]]();
+        nsb::$Caller2[[caller2]]();
+      }
     }
   )cpp");
 
@@ -203,6 +291,7 @@
   Workspace.addSource("callee.hh", CalleeH.code());
   Workspace.addSource("caller1.hh", Caller1H.code());
   Workspace.addSource("caller2.hh", Caller2H.code());
+  Workspace.addSource("caller3.hh", Caller3H.code());
   Workspace.addMainFile("callee.cc", CalleeC.code());
   Workspace.addMainFile("caller1.cc", Caller1C.code());
   Workspace.addMainFile("caller2.cc", Caller2C.code());
@@ -210,46 +299,84 @@
 
   auto Index = Workspace.index();
 
-  auto CheckCallHierarchy = [&](ParsedAST &AST, Position Pos, PathRef TUPath) {
+  auto CheckIncomingCalls = [&](ParsedAST &AST, Position Pos, PathRef TUPath) {
     std::vector<CallHierarchyItem> Items =
         prepareCallHierarchy(AST, Pos, TUPath);
     ASSERT_THAT(Items, ElementsAre(WithName("callee")));
     auto IncomingLevel1 = incomingCalls(Items[0], Index.get());
     ASSERT_THAT(IncomingLevel1,
-                ElementsAre(AllOf(From(WithName("caller1")),
-                                  FromRanges(Caller1C.range()))));
+                ElementsAre(AllOf(From(AllOf(WithName("caller1"),
+                                             WithDetail("nsa::caller1"))),
+                                  IFromRanges(Caller1C.range()))));
 
     auto IncomingLevel2 = incomingCalls(IncomingLevel1[0].from, Index.get());
     ASSERT_THAT(
         IncomingLevel2,
-        ElementsAre(AllOf(From(WithName("caller2")),
-                          FromRanges(Caller2C.range("A"), Caller2C.range("B"))),
-                    AllOf(From(WithName("caller3")),
-                          FromRanges(Caller3C.range("Caller1")))));
+        ElementsAre(
+            AllOf(From(AllOf(WithName("caller2"), WithDetail("nsb::caller2"))),
+                  IFromRanges(Caller2C.range("A"), Caller2C.range("B"))),
+            AllOf(From(AllOf(WithName("caller3"), WithDetail("nsa::caller3"))),
+                  IFromRanges(Caller3C.range("Caller1")))));
 
     auto IncomingLevel3 = incomingCalls(IncomingLevel2[0].from, Index.get());
     ASSERT_THAT(IncomingLevel3,
-                ElementsAre(AllOf(From(WithName("caller3")),
-                                  FromRanges(Caller3C.range("Caller2")))));
+                ElementsAre(AllOf(From(AllOf(WithName("caller3"),
+                                             WithDetail("nsa::caller3"))),
+                                  IFromRanges(Caller3C.range("Caller2")))));
 
     auto IncomingLevel4 = incomingCalls(IncomingLevel3[0].from, Index.get());
     EXPECT_THAT(IncomingLevel4, IsEmpty());
   };
 
+  auto CheckOutgoingCalls = [&](ParsedAST &AST, Position Pos, PathRef TUPath) {
+    std::vector<CallHierarchyItem> Items =
+        prepareCallHierarchy(AST, Pos, TUPath);
+    ASSERT_THAT(Items, ElementsAre(WithName("caller3")));
+    auto OutgoingLevel1 = outgoingCalls(Items[0], Index.get());
+    ASSERT_THAT(
+        OutgoingLevel1,
+        ElementsAre(
+            AllOf(To(AllOf(WithName("caller1"), WithDetail("nsa::caller1"))),
+                  OFromRanges(Caller3C.range("Caller1"))),
+            AllOf(To(AllOf(WithName("caller2"), WithDetail("nsb::caller2"))),
+                  OFromRanges(Caller3C.range("Caller2")))));
+
+    auto OutgoingLevel2 = outgoingCalls(OutgoingLevel1[1].to, Index.get());
+    ASSERT_THAT(OutgoingLevel2,
+                ElementsAre(AllOf(
+                    To(AllOf(WithName("caller1"), WithDetail("nsa::caller1"))),
+                    OFromRanges(Caller2C.range("A"), Caller2C.range("B")))));
+
+    auto OutgoingLevel3 = outgoingCalls(OutgoingLevel2[0].to, Index.get());
+    ASSERT_THAT(
+        OutgoingLevel3,
+        ElementsAre(AllOf(To(AllOf(WithName("callee"), WithDetail("callee"))),
+                          OFromRanges(Caller1C.range()))));
+
+    auto OutgoingLevel4 = outgoingCalls(OutgoingLevel3[0].to, Index.get());
+    EXPECT_THAT(OutgoingLevel4, IsEmpty());
+  };
+
   // Check that invoking from a call site works.
   auto AST = Workspace.openFile("caller1.cc");
   ASSERT_TRUE(bool(AST));
-  CheckCallHierarchy(*AST, Caller1C.point(), testPath("caller1.cc"));
+  CheckIncomingCalls(*AST, Caller1C.point(), testPath("caller1.cc"));
 
   // Check that invoking from the declaration site works.
   AST = Workspace.openFile("callee.hh");
   ASSERT_TRUE(bool(AST));
-  CheckCallHierarchy(*AST, CalleeH.point(), testPath("callee.hh"));
+  CheckIncomingCalls(*AST, CalleeH.point(), testPath("callee.hh"));
+  AST = Workspace.openFile("caller3.hh");
+  ASSERT_TRUE(bool(AST));
+  CheckOutgoingCalls(*AST, Caller3H.point(), testPath("caller3.hh"));
 
   // Check that invoking from the definition site works.
   AST = Workspace.openFile("callee.cc");
   ASSERT_TRUE(bool(AST));
-  CheckCallHierarchy(*AST, CalleeC.point(), testPath("callee.cc"));
+  CheckIncomingCalls(*AST, CalleeC.point(), testPath("callee.cc"));
+  AST = Workspace.openFile("caller3.cc");
+  ASSERT_TRUE(bool(AST));
+  CheckOutgoingCalls(*AST, Caller3C.point(), testPath("caller3.cc"));
 }
 
 } // namespace
Index: clang-tools-extra/clangd/index/dex/Dex.h
===================================================================
--- clang-tools-extra/clangd/index/dex/Dex.h
+++ clang-tools-extra/clangd/index/dex/Dex.h
@@ -90,6 +90,10 @@
   bool refs(const RefsRequest &Req,
             llvm::function_ref<void(const Ref &)> Callback) const override;
 
+  bool refersTo(
+      const RefsRequest &Req,
+      llvm::function_ref<void(const RefersToResult &)> Callback) const override;
+
   void relations(const RelationsRequest &Req,
                  llvm::function_ref<void(const SymbolID &, const Symbol &)>
                      Callback) const override;
@@ -121,6 +125,7 @@
   llvm::DenseMap<Token, PostingList> InvertedIndex;
   dex::Corpus Corpus;
   llvm::DenseMap<SymbolID, llvm::ArrayRef<Ref>> Refs;
+  llvm::DenseMap<SymbolID, std::vector<RefersToResult>> RevRefs;
   static_assert(sizeof(RelationKind) == sizeof(uint8_t),
                 "RelationKind should be of same size as a uint8_t");
   llvm::DenseMap<std::pair<SymbolID, uint8_t>, std::vector<SymbolID>> Relations;
Index: clang-tools-extra/clangd/index/dex/Dex.cpp
===================================================================
--- clang-tools-extra/clangd/index/dex/Dex.cpp
+++ clang-tools-extra/clangd/index/dex/Dex.cpp
@@ -123,6 +123,18 @@
   for (DocID SymbolRank = 0; SymbolRank < Symbols.size(); ++SymbolRank)
     Builder.add(*Symbols[SymbolRank], SymbolRank);
   InvertedIndex = Builder.build();
+
+  // Build RevRefs by inverting Refs: each ref is filed under its container
+  // (Ref::Container) so that refersTo() can look it up by container ID.
+  for (const auto &Pair : Refs) {
+    for (const auto &R : Pair.second) {
+      auto It = RevRefs.try_emplace(R.Container).first;
+      It->second.push_back({
+          R.Location, R.Kind,
+          Pair.first, // The referenced symbol (RefersToResult::Symbol).
+      });
+    }
+  }
 }
 
 std::unique_ptr<Iterator> Dex::iterator(const Token &Tok) const {
@@ -291,6 +303,30 @@
   return false; // We reported all refs.
 }
 
+bool Dex::refersTo(
+    const RefsRequest &Req,
+    llvm::function_ref<void(const RefersToResult &)> Callback) const {
+  trace::Span Tracer("Dex reversed refs");
+  uint32_t Remaining =
+      Req.Limit.getValueOr(std::numeric_limits<uint32_t>::max());
+  for (const auto &ID : Req.IDs) {
+    // Use find() rather than lookup(): DenseMap::lookup() returns the value
+    // by copy, which would duplicate this container's whole ref vector.
+    auto It = RevRefs.find(ID);
+    if (It == RevRefs.end())
+      continue;
+    for (const auto &Ref : It->second) {
+      if (!static_cast<int>(Req.Filter & Ref.Kind))
+        continue;
+      if (Remaining == 0)
+        return true; // More refs were available.
+      --Remaining;
+      Callback(Ref);
+    }
+  }
+  return false; // We reported all refs.
+}
+
 void Dex::relations(
     const RelationsRequest &Req,
     llvm::function_ref<void(const SymbolID &, const Symbol &)> Callback) const {
@@ -332,6 +368,11 @@
   for (const auto &TokenToPostingList : InvertedIndex)
     Bytes += TokenToPostingList.second.bytes();
   Bytes += Refs.getMemorySize();
+  Bytes += RevRefs.getMemorySize();
+  // getMemorySize() only counts RevRefs' bucket array; the per-container
+  // vectors built above own separate heap storage, so count it too.
+  for (const auto &Entry : RevRefs)
+    Bytes += Entry.second.capacity() * sizeof(RefersToResult);
   Bytes += Relations.getMemorySize();
   return Bytes + BackingDataSize;
 }
Index: clang-tools-extra/clangd/index/ProjectAware.cpp
===================================================================
--- clang-tools-extra/clangd/index/ProjectAware.cpp
+++ clang-tools-extra/clangd/index/ProjectAware.cpp
@@ -42,6 +42,10 @@
   /// Query all indexes while prioritizing the associated one (if any).
   bool refs(const RefsRequest &Req,
             llvm::function_ref<void(const Ref &)> Callback) const override;
+  /// Queries only the associated index (if any); returns false if none.
+  bool refersTo(
+      const RefsRequest &Req,
+      llvm::function_ref<void(const RefersToResult &)> Callback) const override;
 
   /// Queries only the associates index when Req.RestrictForCodeCompletion is
   /// set, otherwise queries all.
@@ -98,6 +102,15 @@
   return false;
 }
 
+bool ProjectAwareIndex::refersTo(
+    const RefsRequest &Req,
+    llvm::function_ref<void(const RefersToResult &)> Callback) const {
+  trace::Span Tracer("ProjectAwareIndex::refersTo");
+  if (auto *Idx = getIndex())
+    return Idx->refersTo(Req, Callback);
+  return false;
+}
+
 bool ProjectAwareIndex::fuzzyFind(
     const FuzzyFindRequest &Req,
     llvm::function_ref<void(const Symbol &)> Callback) const {
Index: clang-tools-extra/clangd/index/Merge.h
===================================================================
--- clang-tools-extra/clangd/index/Merge.h
+++ clang-tools-extra/clangd/index/Merge.h
@@ -42,6 +42,9 @@
               llvm::function_ref<void(const Symbol &)>) const override;
   bool refs(const RefsRequest &,
             llvm::function_ref<void(const Ref &)>) const override;
+  bool
+  refersTo(const RefsRequest &,
+           llvm::function_ref<void(const RefersToResult &)>) const override;
   void relations(const RelationsRequest &,
                  llvm::function_ref<void(const SymbolID &, const Symbol &)>)
       const override;
Index: clang-tools-extra/clangd/index/Merge.cpp
===================================================================
--- clang-tools-extra/clangd/index/Merge.cpp
+++ clang-tools-extra/clangd/index/Merge.cpp
@@ -128,6 +128,40 @@
   return More || StaticHadMore;
 }
 
+bool MergedIndex::refersTo(
+    const RefsRequest &Req,
+    llvm::function_ref<void(const RefersToResult &)> Callback) const {
+  trace::Span Tracer("MergedIndex refersTo");
+  bool More = false;
+  uint32_t Remaining =
+      Req.Limit.getValueOr(std::numeric_limits<uint32_t>::max());
+  // We don't want duplicated refs from the static/dynamic indexes,
+  // and we can't reliably deduplicate them because offsets may differ slightly.
+  // We consider the dynamic index authoritative and report all its refs,
+  // and only report static index refs from other files.
+  More |= Dynamic->refersTo(Req, [&](const auto &O) {
+    Callback(O);
+    assert(Remaining != 0);
+    --Remaining;
+  });
+  if (Remaining == 0 && More)
+    return More;
+  auto DynamicContainsFile = Dynamic->indexedFiles();
+  // We return less than Req.Limit if static index returns more refs for dirty
+  // files.
+  bool StaticHadMore = Static->refersTo(Req, [&](const auto &O) {
+    if (DynamicContainsFile(O.Location.FileURI))
+      return; // ignore refs that have been seen from dynamic index.
+    if (Remaining == 0) {
+      More = true;
+      return;
+    }
+    --Remaining;
+    Callback(O);
+  });
+  return More || StaticHadMore;
+}
+
 llvm::unique_function<bool(llvm::StringRef) const>
 MergedIndex::indexedFiles() const {
   return [DynamicContainsFile{Dynamic->indexedFiles()},
Index: clang-tools-extra/clangd/index/MemIndex.h
===================================================================
--- clang-tools-extra/clangd/index/MemIndex.h
+++ clang-tools-extra/clangd/index/MemIndex.h
@@ -70,6 +70,10 @@
   bool refs(const RefsRequest &Req,
             llvm::function_ref<void(const Ref &)> Callback) const override;
 
+  bool refersTo(
+      const RefsRequest &Req,
+      llvm::function_ref<void(const RefersToResult &)> Callback) const override;
+
   void relations(const RelationsRequest &Req,
                  llvm::function_ref<void(const SymbolID &, const Symbol &)>
                      Callback) const override;
Index: clang-tools-extra/clangd/index/MemIndex.cpp
===================================================================
--- clang-tools-extra/clangd/index/MemIndex.cpp
+++ clang-tools-extra/clangd/index/MemIndex.cpp
@@ -88,6 +88,26 @@
   return false; // We reported all refs.
 }
 
+bool MemIndex::refersTo(
+    const RefsRequest &Req,
+    llvm::function_ref<void(const RefersToResult &)> Callback) const {
+  trace::Span Tracer("MemIndex refersTo");
+  uint32_t Remaining =
+      Req.Limit.getValueOr(std::numeric_limits<uint32_t>::max());
+  for (const auto &Pair : Refs) {
+    for (const auto &R : Pair.second) {
+      if (!static_cast<int>(Req.Filter & R.Kind) ||
+          !Req.IDs.contains(R.Container))
+        continue;
+      if (Remaining == 0)
+        return true; // More refs were available.
+      --Remaining;
+      Callback({R.Location, R.Kind, Pair.first});
+    }
+  }
+  return false; // We reported all refs.
+}
+
 void MemIndex::relations(
     const RelationsRequest &Req,
     llvm::function_ref<void(const SymbolID &, const Symbol &)> Callback) const {
Index: clang-tools-extra/clangd/index/Index.h
===================================================================
--- clang-tools-extra/clangd/index/Index.h
+++ clang-tools-extra/clangd/index/Index.h
@@ -82,6 +82,14 @@
   llvm::Optional<uint32_t> Limit;
 };
 
+struct RefersToResult {
+  /// The source location of the reference, i.e. where Symbol is named.
+  SymbolLocation Location;
+  RefKind Kind = RefKind::Unknown; ///< Kind of the reference.
+  /// The ID of the symbol being referred to (the referee, not the container).
+  SymbolID Symbol;
+};
+
 /// Interface for symbol indexes that can be used for searching or
 /// matching symbols among a set of symbols based on names or unique IDs.
 class SymbolIndex {
@@ -114,6 +122,17 @@
   virtual bool refs(const RefsRequest &Req,
                     llvm::function_ref<void(const Ref &)> Callback) const = 0;
 
+  /// Finds all refs matching Req.Filter whose container is one of Req.IDs
+  /// (i.e. refs made from within those symbols) and applies \p Callback.
+  ///
+  /// Results should be returned in arbitrary order.
+  /// The returned result must be deep-copied if it's used outside Callback.
+  ///
+  /// Returns true if there will be more results (limited by Req.Limit).
+  virtual bool
+  refersTo(const RefsRequest &Req,
+           llvm::function_ref<void(const RefersToResult &)> Callback) const = 0;
+
   /// Finds all relations (S, P, O) stored in the index such that S is among
   /// Req.Subjects and P is Req.Predicate, and invokes \p Callback for (S, O) in
   /// each.
@@ -147,6 +166,9 @@
               llvm::function_ref<void(const Symbol &)>) const override;
   bool refs(const RefsRequest &,
             llvm::function_ref<void(const Ref &)>) const override;
+  bool
+  refersTo(const RefsRequest &,
+           llvm::function_ref<void(const RefersToResult &)>) const override;
   void relations(const RelationsRequest &,
                  llvm::function_ref<void(const SymbolID &, const Symbol &)>)
       const override;
Index: clang-tools-extra/clangd/index/Index.cpp
===================================================================
--- clang-tools-extra/clangd/index/Index.cpp
+++ clang-tools-extra/clangd/index/Index.cpp
@@ -70,6 +70,11 @@
                      llvm::function_ref<void(const Ref &)> CB) const {
   return snapshot()->refs(R, CB);
 }
+bool SwapIndex::refersTo(
+    const RefsRequest &R,
+    llvm::function_ref<void(const RefersToResult &)> CB) const {
+  return snapshot()->refersTo(R, CB);
+}
 void SwapIndex::relations(
     const RelationsRequest &R,
     llvm::function_ref<void(const SymbolID &, const Symbol &)> CB) const {
Index: clang-tools-extra/clangd/XRefs.h
===================================================================
--- clang-tools-extra/clangd/XRefs.h
+++ clang-tools-extra/clangd/XRefs.h
@@ -117,6 +117,9 @@
 std::vector<CallHierarchyIncomingCall>
 incomingCalls(const CallHierarchyItem &Item, const SymbolIndex *Index);
 
+std::vector<CallHierarchyOutgoingCall>
+outgoingCalls(const CallHierarchyItem &Item, const SymbolIndex *Index);
+
 /// Returns all decls that are referenced in the \p FD except local symbols.
 llvm::DenseSet<const Decl *> getNonLocalDeclRefs(ParsedAST &AST,
                                                  const FunctionDecl *FD);
Index: clang-tools-extra/clangd/XRefs.cpp
===================================================================
--- clang-tools-extra/clangd/XRefs.cpp
+++ clang-tools-extra/clangd/XRefs.cpp
@@ -1486,6 +1486,7 @@
   }
   HierarchyItem HI;
   HI.name = std::string(S.Name);
+  HI.detail = (S.Scope + S.Name).str();
   HI.kind = indexSymbolKindToSymbolKind(S.SymInfo.Kind);
   HI.selectionRange = Loc->range;
   // FIXME: Populate 'range' correctly
@@ -1798,6 +1799,67 @@
   return Results;
 }
 
+std::vector<CallHierarchyOutgoingCall>
+outgoingCalls(const CallHierarchyItem &Item, const SymbolIndex *Index) {
+  std::vector<CallHierarchyOutgoingCall> Results;
+  if (!Index || Item.data.empty())
+    return Results;
+  auto ID = SymbolID::fromStr(Item.data);
+  if (!ID) {
+    elog("outgoingCalls failed to find symbol: {0}", ID.takeError());
+    return Results;
+  }
+  // In this function, we find outgoing calls based on the index only.
+  RefsRequest Request;
+  Request.IDs.insert(*ID);
+  // We could restrict more specifically to calls by introducing a new RefKind,
+  // but non-call references (such as address-of-function) can still be
+  // interesting as they can indicate indirect calls.
+  Request.Filter = RefKind::Reference;
+  // Initially store the ranges in a map keyed by SymbolID of the callee.
+  // This allows us to group different calls to the same function
+  // into the same CallHierarchyOutgoingCall.
+  llvm::DenseMap<SymbolID, std::vector<Range>> CallsOut;
+  // We can populate the ranges based on a refs request only. As we do so, we
+  // also accumulate the callee IDs into a lookup request.
+  LookupRequest CallsOutLookup;
+  Index->refersTo(Request, [&](const auto &R) {
+    auto Loc = indexToLSPLocation(R.Location, Item.uri.file());
+    if (!Loc) {
+      elog("outgoingCalls failed to convert location: {0}", Loc.takeError());
+      return;
+    }
+    auto It = CallsOut.try_emplace(R.Symbol, std::vector<Range>{}).first;
+    It->second.push_back(Loc->range);
+
+    CallsOutLookup.IDs.insert(R.Symbol);
+  });
+  // Perform the lookup request and combine its results with CallsOut to
+  // get complete CallHierarchyOutgoingCall objects.
+  Index->lookup(CallsOutLookup, [&](const Symbol &Callee) {
+    // Filter references to only keep function calls
+    using SK = index::SymbolKind;
+    auto Kind = Callee.SymInfo.Kind;
+    if (Kind != SK::Function && Kind != SK::InstanceMethod &&
+        Kind != SK::ClassMethod && Kind != SK::StaticMethod &&
+        Kind != SK::Constructor && Kind != SK::Destructor &&
+        Kind != SK::ConversionFunction)
+      return;
+
+    auto It = CallsOut.find(Callee.ID);
+    assert(It != CallsOut.end());
+    if (auto CHI = symbolToCallHierarchyItem(Callee, Item.uri.file()))
+      Results.push_back(
+          CallHierarchyOutgoingCall{std::move(*CHI), std::move(It->second)});
+  });
+  // Sort results by name of the callee.
+  llvm::sort(Results, [](const CallHierarchyOutgoingCall &A,
+                         const CallHierarchyOutgoingCall &B) {
+    return A.to.name < B.to.name;
+  });
+  return Results;
+}
+
 llvm::DenseSet<const Decl *> getNonLocalDeclRefs(ParsedAST &AST,
                                                  const FunctionDecl *FD) {
   if (!FD->hasBody())
@@ -1812,5 +1874,6 @@
   });
   return DeclRefs;
 }
+
 } // namespace clangd
 } // namespace clang
Index: clang-tools-extra/clangd/ClangdServer.h
===================================================================
--- clang-tools-extra/clangd/ClangdServer.h
+++ clang-tools-extra/clangd/ClangdServer.h
@@ -240,6 +240,10 @@
   void incomingCalls(const CallHierarchyItem &Item,
                      Callback<std::vector<CallHierarchyIncomingCall>>);
 
+  /// Resolve outgoing calls for a given call hierarchy item.
+  void outgoingCalls(const CallHierarchyItem &Item,
+                     Callback<std::vector<CallHierarchyOutgoingCall>>);
+
   /// Retrieve the top symbols from the workspace matching a query.
   void workspaceSymbols(StringRef Query, int Limit,
                         Callback<std::vector<SymbolInformation>> CB);
Index: clang-tools-extra/clangd/ClangdServer.cpp
===================================================================
--- clang-tools-extra/clangd/ClangdServer.cpp
+++ clang-tools-extra/clangd/ClangdServer.cpp
@@ -654,6 +654,15 @@
                     });
 }
 
+void ClangdServer::outgoingCalls(
+    const CallHierarchyItem &Item,
+    Callback<std::vector<CallHierarchyOutgoingCall>> CB) {
+  WorkScheduler.run("Outgoing Calls", "",
+                    [CB = std::move(CB), Item, this]() mutable {
+                      CB(clangd::outgoingCalls(Item, Index));
+                    });
+}
+
 void ClangdServer::onFileEvent(const DidChangeWatchedFilesParams &Params) {
   // FIXME: Do nothing for now. This will be used for indexing and potentially
   // invalidating other caches.
Index: clang-tools-extra/clangd/ClangdLSPServer.cpp
===================================================================
--- clang-tools-extra/clangd/ClangdLSPServer.cpp
+++ clang-tools-extra/clangd/ClangdLSPServer.cpp
@@ -1249,8 +1249,7 @@
 void ClangdLSPServer::onCallHierarchyOutgoingCalls(
     const CallHierarchyOutgoingCallsParams &Params,
     Callback<std::vector<CallHierarchyOutgoingCall>> Reply) {
-  // FIXME: To be implemented.
-  Reply(std::vector<CallHierarchyOutgoingCall>{});
+  Server->outgoingCalls(Params.item, std::move(Reply));
 }
 
 void ClangdLSPServer::applyConfiguration(
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to