[clang-tools-extra] 3e6e6a2 - [clangd] Call hierarchy (XRefs layer, incoming calls)

Nathan Ridge via cfe-commits cfe-commits at lists.llvm.org
Mon Nov 23 17:44:31 PST 2020


Author: Nathan Ridge
Date: 2020-11-23T20:43:38-05:00
New Revision: 3e6e6a2db674cd85b33c06b75685c6bce5acb154

URL: https://github.com/llvm/llvm-project/commit/3e6e6a2db674cd85b33c06b75685c6bce5acb154
DIFF: https://github.com/llvm/llvm-project/commit/3e6e6a2db674cd85b33c06b75685c6bce5acb154.diff

LOG: [clangd] Call hierarchy (XRefs layer, incoming calls)

Support for outgoing calls is left for a future change.

Differential Revision: https://reviews.llvm.org/D91122

Added: 
    clang-tools-extra/clangd/unittests/CallHierarchyTests.cpp

Modified: 
    clang-tools-extra/clangd/XRefs.cpp
    clang-tools-extra/clangd/XRefs.h
    clang-tools-extra/clangd/unittests/CMakeLists.txt
    clang-tools-extra/clangd/unittests/TestTU.cpp

Removed: 
    


################################################################################
diff  --git a/clang-tools-extra/clangd/XRefs.cpp b/clang-tools-extra/clangd/XRefs.cpp
index 0cd8695da92d..e319636f9076 100644
--- a/clang-tools-extra/clangd/XRefs.cpp
+++ b/clang-tools-extra/clangd/XRefs.cpp
@@ -47,6 +47,7 @@
 #include "clang/Index/USRGeneration.h"
 #include "clang/Tooling/Syntax/Tokens.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/None.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
@@ -1339,9 +1340,9 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const LocatedSymbol &S) {
   return OS;
 }
 
-// FIXME(nridge): Reduce duplication between this function and declToSym().
-static llvm::Optional<TypeHierarchyItem>
-declToTypeHierarchyItem(ASTContext &Ctx, const NamedDecl &ND) {
+// Fills in the fields that TypeHierarchyItem and CallHierarchyItem have in
+// common (name, kind, range, selectionRange, uri, data) from the declaration
+// \p ND. Kind-specific fields, such as how deprecation is recorded, are left
+// for the wrappers below to set.
+template <typename HierarchyItem>
+static llvm::Optional<HierarchyItem> declToHierarchyItem(const NamedDecl &ND) {
+  ASTContext &Ctx = ND.getASTContext();
   auto &SM = Ctx.getSourceManager();
   SourceLocation NameLoc = nameLocation(ND, Ctx.getSourceManager());
   SourceLocation BeginLoc = SM.getSpellingLoc(SM.getFileLoc(ND.getBeginLoc()));
@@ -1365,54 +1366,84 @@ declToTypeHierarchyItem(ASTContext &Ctx, const NamedDecl &ND) {
   // correctly.
   SymbolKind SK = indexSymbolKindToSymbolKind(SymInfo.Kind);
 
-  TypeHierarchyItem THI;
-  THI.name = printName(Ctx, ND);
-  THI.kind = SK;
-  THI.deprecated = ND.isDeprecated();
-  THI.range = Range{sourceLocToPosition(SM, DeclRange->getBegin()),
-                    sourceLocToPosition(SM, DeclRange->getEnd())};
-  THI.selectionRange = Range{NameBegin, NameEnd};
-  if (!THI.range.contains(THI.selectionRange)) {
+  HierarchyItem HI;
+  HI.name = printName(Ctx, ND);
+  HI.kind = SK;
+  HI.range = Range{sourceLocToPosition(SM, DeclRange->getBegin()),
+                   sourceLocToPosition(SM, DeclRange->getEnd())};
+  HI.selectionRange = Range{NameBegin, NameEnd};
+  if (!HI.range.contains(HI.selectionRange)) {
     // 'selectionRange' must be contained in 'range', so in cases where clang
     // reports unrelated ranges we need to reconcile somehow.
-    THI.range = THI.selectionRange;
+    HI.range = HI.selectionRange;
   }
 
-  THI.uri = URIForFile::canonicalize(*FilePath, *TUPath);
+  HI.uri = URIForFile::canonicalize(*FilePath, *TUPath);
 
   // Compute the SymbolID and store it in the 'data' field.
   // This allows typeHierarchy/resolve to be used to
   // resolve children of items returned in a previous request
   // for parents.
   if (auto ID = getSymbolID(&ND))
-    THI.data = ID.str();
+    HI.data = ID.str();
+
+  return HI;
+}
 
-  return THI;
+// Type-hierarchy wrapper: TypeHierarchyItem records deprecation in its
+// boolean 'deprecated' field.
+static llvm::Optional<TypeHierarchyItem>
+declToTypeHierarchyItem(const NamedDecl &ND) {
+  auto Result = declToHierarchyItem<TypeHierarchyItem>(ND);
+  if (Result)
+    Result->deprecated = ND.isDeprecated();
+  return Result;
 }
 
-static Optional<TypeHierarchyItem>
-symbolToTypeHierarchyItem(const Symbol &S, const SymbolIndex *Index,
-                          PathRef TUPath) {
+// Call-hierarchy wrapper: CallHierarchyItem records deprecation as a
+// SymbolTag::Deprecated entry in its 'tags' list.
+static llvm::Optional<CallHierarchyItem>
+declToCallHierarchyItem(const NamedDecl &ND) {
+  auto Result = declToHierarchyItem<CallHierarchyItem>(ND);
+  if (Result && ND.isDeprecated())
+    Result->tags.push_back(SymbolTag::Deprecated);
+  return Result;
+}
+
+// Fills in the fields that TypeHierarchyItem and CallHierarchyItem have in
+// common from an index Symbol. Logs an error and returns None if the symbol's
+// location cannot be converted relative to \p TUPath.
+template <typename HierarchyItem>
+static llvm::Optional<HierarchyItem> symbolToHierarchyItem(const Symbol &S,
+                                                           PathRef TUPath) {
   auto Loc = symbolToLocation(S, TUPath);
   if (!Loc) {
-    log("Type hierarchy: {0}", Loc.takeError());
+    elog("Failed to convert symbol to hierarchy item: {0}", Loc.takeError());
     return llvm::None;
   }
-  TypeHierarchyItem THI;
-  THI.name = std::string(S.Name);
-  THI.kind = indexSymbolKindToSymbolKind(S.SymInfo.Kind);
-  THI.deprecated = (S.Flags & Symbol::Deprecated);
-  THI.selectionRange = Loc->range;
+  HierarchyItem HI;
+  HI.name = std::string(S.Name);
+  HI.kind = indexSymbolKindToSymbolKind(S.SymInfo.Kind);
+  HI.selectionRange = Loc->range;
   // FIXME: Populate 'range' correctly
   // (https://github.com/clangd/clangd/issues/59).
-  THI.range = THI.selectionRange;
-  THI.uri = Loc->uri;
+  HI.range = HI.selectionRange;
+  HI.uri = Loc->uri;
   // Store the SymbolID in the 'data' field. The client will
-  // send this back in typeHierarchy/resolve, allowing us to
-  // continue resolving additional levels of the type hierarchy.
-  THI.data = S.ID.str();
+  // send this back in requests to resolve additional levels
+  // of the hierarchy.
+  HI.data = S.ID.str();
+
+  return HI;
+}
 
-  return std::move(THI);
+// Type-hierarchy wrapper: copies the index's deprecation flag into the
+// boolean 'deprecated' field.
+static llvm::Optional<TypeHierarchyItem>
+symbolToTypeHierarchyItem(const Symbol &S, PathRef TUPath) {
+  auto Result = symbolToHierarchyItem<TypeHierarchyItem>(S, TUPath);
+  if (Result)
+    Result->deprecated = (S.Flags & Symbol::Deprecated);
+  return Result;
+}
+
+// Call-hierarchy wrapper: a deprecated symbol gets a SymbolTag::Deprecated
+// entry in its 'tags' list.
+static llvm::Optional<CallHierarchyItem>
+symbolToCallHierarchyItem(const Symbol &S, PathRef TUPath) {
+  auto Result = symbolToHierarchyItem<CallHierarchyItem>(S, TUPath);
+  if (Result && (S.Flags & Symbol::Deprecated))
+    Result->tags.push_back(SymbolTag::Deprecated);
+  return Result;
 }
 
 static void fillSubTypes(const SymbolID &ID,
@@ -1423,7 +1454,7 @@ static void fillSubTypes(const SymbolID &ID,
   Req.Predicate = RelationKind::BaseOf;
   Index->relations(Req, [&](const SymbolID &Subject, const Symbol &Object) {
     if (Optional<TypeHierarchyItem> ChildSym =
-            symbolToTypeHierarchyItem(Object, Index, TUPath)) {
+            symbolToTypeHierarchyItem(Object, TUPath)) {
       if (Levels > 1) {
         ChildSym->children.emplace();
         fillSubTypes(Object.ID, *ChildSym->children, Index, Levels - 1, TUPath);
@@ -1452,7 +1483,7 @@ static void fillSuperTypes(const CXXRecordDecl &CXXRD, ASTContext &ASTCtx,
 
   for (const CXXRecordDecl *ParentDecl : typeParents(&CXXRD)) {
     if (Optional<TypeHierarchyItem> ParentSym =
-            declToTypeHierarchyItem(ASTCtx, *ParentDecl)) {
+            declToTypeHierarchyItem(*ParentDecl)) {
       ParentSym->parents.emplace();
       fillSuperTypes(*ParentDecl, ASTCtx, *ParentSym->parents, RPSet);
       SuperTypes.emplace_back(std::move(*ParentSym));
@@ -1574,8 +1605,7 @@ getTypeHierarchy(ParsedAST &AST, Position Pos, int ResolveLevels,
       CXXRD = CTSD->getTemplateInstantiationPattern();
   }
 
-  Optional<TypeHierarchyItem> Result =
-      declToTypeHierarchyItem(AST.getASTContext(), *CXXRD);
+  Optional<TypeHierarchyItem> Result = declToTypeHierarchyItem(*CXXRD);
   if (!Result)
     return Result;
 
@@ -1617,6 +1647,78 @@ void resolveTypeHierarchy(TypeHierarchyItem &Item, int ResolveLevels,
   }
 }
 
+/// Returns one CallHierarchyItem for every function or function template
+/// found at \p Pos in the main file of \p AST. Returns an empty list, after
+/// logging, if \p Pos cannot be mapped to a source location.
+/// NOTE: \p TUPath is not referenced by this function body.
+std::vector<CallHierarchyItem>
+prepareCallHierarchy(ParsedAST &AST, Position Pos, PathRef TUPath) {
+  std::vector<CallHierarchyItem> Items;
+  auto CursorLoc = sourceLocationInMainFile(AST.getSourceManager(), Pos);
+  if (!CursorLoc) {
+    elog("prepareCallHierarchy failed to convert position to source location: "
+         "{0}",
+         CursorLoc.takeError());
+    return Items;
+  }
+  for (const NamedDecl *Candidate : getDeclAtPosition(AST, *CursorLoc, {})) {
+    // Only functions and function templates take part in call hierarchy.
+    if (Candidate->isFunctionOrFunctionTemplate()) {
+      if (auto Item = declToCallHierarchyItem(*Candidate))
+        Items.emplace_back(std::move(*Item));
+    }
+  }
+  return Items;
+}
+
+/// Finds references to \p Item in \p Index and groups them by the symbol that
+/// contains each reference (the caller). Non-call references, such as taking
+/// the function's address, are included. Results are sorted by caller name,
+/// then file, then selection range, so the output order is deterministic.
+/// Returns an empty list if \p Index is null, if \p Item has no symbol ID in
+/// its 'data' field, or (after logging) if that ID cannot be parsed.
+std::vector<CallHierarchyIncomingCall>
+incomingCalls(const CallHierarchyItem &Item, const SymbolIndex *Index) {
+  std::vector<CallHierarchyIncomingCall> Results;
+  if (!Index || Item.data.empty())
+    return Results;
+  auto ID = SymbolID::fromStr(Item.data);
+  if (!ID) {
+    elog("incomingCalls failed to find symbol: {0}", ID.takeError());
+    return Results;
+  }
+  // In this function, we find incoming calls based on the index only.
+  // In principle, the AST could have more up-to-date information about
+  // occurrences within the current file. However, going from a SymbolID
+  // to an AST node isn't cheap, particularly when the declaration isn't
+  // in the main file.
+  // FIXME: Consider also using AST information when feasible.
+  RefsRequest Request;
+  Request.IDs.insert(*ID);
+  // We could restrict more specifically to calls by introducing a new RefKind,
+  // but non-call references (such as address-of-function) can still be
+  // interesting as they can indicate indirect calls.
+  Request.Filter = RefKind::Reference;
+  // Initially store the ranges in a map keyed by SymbolID of the caller.
+  // This allows us to group different calls with the same caller
+  // into the same CallHierarchyIncomingCall.
+  llvm::DenseMap<SymbolID, std::vector<Range>> CallsIn;
+  // We can populate the ranges based on a refs request only. As we do so, we
+  // also accumulate the container IDs into a lookup request.
+  LookupRequest ContainerLookup;
+  Index->refs(Request, [&](const Ref &R) {
+    auto Loc = indexToLSPLocation(R.Location, Item.uri.file());
+    if (!Loc) {
+      elog("incomingCalls failed to convert location: {0}", Loc.takeError());
+      return;
+    }
+    auto It = CallsIn.try_emplace(R.Container, std::vector<Range>{}).first;
+    It->second.push_back(Loc->range);
+
+    ContainerLookup.IDs.insert(R.Container);
+  });
+  // Perform the lookup request and combine its results with CallsIn to
+  // get complete CallHierarchyIncomingCall objects.
+  Index->lookup(ContainerLookup, [&](const Symbol &Caller) {
+    auto It = CallsIn.find(Caller.ID);
+    assert(It != CallsIn.end());
+    if (auto CHI = symbolToCallHierarchyItem(Caller, Item.uri.file()))
+      Results.push_back(
+          CallHierarchyIncomingCall{std::move(*CHI), std::move(It->second)});
+  });
+  // Sort the results so the output order does not depend on the order in
+  // which the index reports callers. File and position break ties between
+  // callers that share a name (e.g. overloads).
+  llvm::sort(Results, [](const CallHierarchyIncomingCall &A,
+                         const CallHierarchyIncomingCall &B) {
+    if (A.from.name != B.from.name)
+      return A.from.name < B.from.name;
+    if (A.from.uri.file() != B.from.uri.file())
+      return A.from.uri.file() < B.from.uri.file();
+    return A.from.selectionRange < B.from.selectionRange;
+  });
+  return Results;
+}
+
 llvm::DenseSet<const Decl *> getNonLocalDeclRefs(ParsedAST &AST,
                                                  const FunctionDecl *FD) {
   if (!FD->hasBody())

diff  --git a/clang-tools-extra/clangd/XRefs.h b/clang-tools-extra/clangd/XRefs.h
index fac1a992a12f..eca174f59096 100644
--- a/clang-tools-extra/clangd/XRefs.h
+++ b/clang-tools-extra/clangd/XRefs.h
@@ -110,6 +110,13 @@ void resolveTypeHierarchy(TypeHierarchyItem &Item, int ResolveLevels,
                           TypeHierarchyDirection Direction,
                           const SymbolIndex *Index);
 
+/// Get call hierarchy information at \p Pos: one item for each function or
+/// function template found at that position.
+std::vector<CallHierarchyItem>
+prepareCallHierarchy(ParsedAST &AST, Position Pos, PathRef TUPath);
+
+/// Find references to \p Item using the index only, grouping the reference
+/// ranges by the symbol that contains them (the caller). Non-call references
+/// such as taking the function's address are included. Returns an empty list
+/// if \p Index is null or \p Item has no symbol ID in its 'data' field.
+std::vector<CallHierarchyIncomingCall>
+incomingCalls(const CallHierarchyItem &Item, const SymbolIndex *Index);
+
 /// Returns all decls that are referenced in the \p FD except local symbols.
 llvm::DenseSet<const Decl *> getNonLocalDeclRefs(ParsedAST &AST,
                                                  const FunctionDecl *FD);

diff  --git a/clang-tools-extra/clangd/unittests/CMakeLists.txt b/clang-tools-extra/clangd/unittests/CMakeLists.txt
index 5d87fff5c8af..e7baf880e504 100644
--- a/clang-tools-extra/clangd/unittests/CMakeLists.txt
+++ b/clang-tools-extra/clangd/unittests/CMakeLists.txt
@@ -36,6 +36,7 @@ add_unittest(ClangdUnitTests ClangdTests
   Annotations.cpp
   ASTTests.cpp
   BackgroundIndexTests.cpp
+  CallHierarchyTests.cpp
   CanonicalIncludesTests.cpp
   ClangdTests.cpp
   ClangdLSPServerTests.cpp

diff  --git a/clang-tools-extra/clangd/unittests/CallHierarchyTests.cpp b/clang-tools-extra/clangd/unittests/CallHierarchyTests.cpp
new file mode 100644
index 000000000000..ce192466b442
--- /dev/null
+++ b/clang-tools-extra/clangd/unittests/CallHierarchyTests.cpp
@@ -0,0 +1,256 @@
+//===-- CallHierarchyTests.cpp  ---------------------------*- C++ -*-------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "Annotations.h"
+#include "Compiler.h"
+#include "Matchers.h"
+#include "ParsedAST.h"
+#include "SyncAPI.h"
+#include "TestFS.h"
+#include "TestTU.h"
+#include "TestWorkspace.h"
+#include "XRefs.h"
+#include "index/FileIndex.h"
+#include "index/SymbolCollector.h"
+#include "clang/AST/DeclCXX.h"
+#include "clang/AST/DeclTemplate.h"
+#include "clang/Index/IndexingAction.h"
+#include "llvm/Support/Path.h"
+#include "llvm/Support/ScopedPrinter.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace clang {
+namespace clangd {
+namespace {
+
+using ::testing::AllOf;
+using ::testing::ElementsAre;
+using ::testing::Field;
+using ::testing::Matcher;
+using ::testing::UnorderedElementsAre;
+
+// gMock matchers for CallHierarchyItem and CallHierarchyIncomingCall values.
+MATCHER_P(WithName, ExpectedName, "") { return arg.name == ExpectedName; }
+MATCHER_P(WithSelectionRange, ExpectedRange, "") {
+  return arg.selectionRange == ExpectedRange;
+}
+
+/// Matches an incoming call whose 'from' item satisfies \p Inner.
+template <typename FromMatcher>
+::testing::Matcher<CallHierarchyIncomingCall> From(FromMatcher Inner) {
+  return Field(&CallHierarchyIncomingCall::from, Inner);
+}
+/// Matches an incoming call whose 'fromRanges' are matched one-to-one, in any
+/// order, by \p Ranges.
+template <typename... RangeMatcherTs>
+::testing::Matcher<CallHierarchyIncomingCall>
+FromRanges(RangeMatcherTs... Ranges) {
+  return Field(&CallHierarchyIncomingCall::fromRanges,
+               UnorderedElementsAre(Ranges...));
+}
+
+TEST(CallHierarchy, IncomingOneFile) {
+  Annotations Source(R"cpp(
+    void call^ee(int);
+    void caller1() {
+      $Callee[[callee]](42);
+    }
+    void caller2() {
+      $Caller1A[[caller1]]();
+      $Caller1B[[caller1]]();
+    }
+    void caller3() {
+      $Caller1C[[caller1]]();
+      $Caller2[[caller2]]();
+    }
+  )cpp");
+  TestTU TU = TestTU::withCode(Source.code());
+  auto AST = TU.build();
+  auto Index = TU.index();
+
+  std::vector<CallHierarchyItem> Items =
+      prepareCallHierarchy(AST, Source.point(), testPath(TU.Filename));
+  // ASSERT (not EXPECT) wherever a later line indexes into the result, so a
+  // mismatch stops the test instead of reading past the end of the vector.
+  ASSERT_THAT(Items, ElementsAre(WithName("callee")));
+  auto IncomingLevel1 = incomingCalls(Items[0], Index.get());
+  ASSERT_THAT(IncomingLevel1,
+              ElementsAre(AllOf(From(WithName("caller1")),
+                                FromRanges(Source.range("Callee")))));
+
+  auto IncomingLevel2 = incomingCalls(IncomingLevel1[0].from, Index.get());
+  ASSERT_THAT(IncomingLevel2, UnorderedElementsAre(
+                                  AllOf(From(WithName("caller2")),
+                                        FromRanges(Source.range("Caller1A"),
+                                                   Source.range("Caller1B"))),
+                                  AllOf(From(WithName("caller3")),
+                                        FromRanges(Source.range("Caller1C")))));
+
+  // The check above is order-insensitive, so pick caller2 by name rather
+  // than assuming it is the first result.
+  const CallHierarchyItem &Caller2 = IncomingLevel2[0].from.name == "caller2"
+                                         ? IncomingLevel2[0].from
+                                         : IncomingLevel2[1].from;
+  auto IncomingLevel3 = incomingCalls(Caller2, Index.get());
+  ASSERT_THAT(IncomingLevel3,
+              ElementsAre(AllOf(From(WithName("caller3")),
+                                FromRanges(Source.range("Caller2")))));
+
+  auto IncomingLevel4 = incomingCalls(IncomingLevel3[0].from, Index.get());
+  EXPECT_THAT(IncomingLevel4, ElementsAre());
+}
+
+TEST(CallHierarchy, MainFileOnlyRef) {
+  // In addition to testing that we store refs to main-file only symbols,
+  // this tests that anonymous namespaces do not interfere with re-identifying
+  // the symbol from the SymbolID stored in an item's 'data' field.
+  Annotations Source(R"cpp(
+    void call^ee(int);
+    namespace {
+      void caller1() {
+        $Callee[[callee]](42);
+      }
+    }
+    void caller2() {
+      $Caller1[[caller1]]();
+    }
+  )cpp");
+  TestTU TU = TestTU::withCode(Source.code());
+  auto AST = TU.build();
+  auto Index = TU.index();
+
+  std::vector<CallHierarchyItem> Items =
+      prepareCallHierarchy(AST, Source.point(), testPath(TU.Filename));
+  // ASSERT (not EXPECT) before indexing into a result vector.
+  ASSERT_THAT(Items, ElementsAre(WithName("callee")));
+  auto IncomingLevel1 = incomingCalls(Items[0], Index.get());
+  ASSERT_THAT(IncomingLevel1,
+              ElementsAre(AllOf(From(WithName("caller1")),
+                                FromRanges(Source.range("Callee")))));
+
+  auto IncomingLevel2 = incomingCalls(IncomingLevel1[0].from, Index.get());
+  EXPECT_THAT(IncomingLevel2,
+              UnorderedElementsAre(AllOf(From(WithName("caller2")),
+                                         FromRanges(Source.range("Caller1")))));
+}
+
+TEST(CallHierarchy, IncomingQualified) {
+  Annotations Source(R"cpp(
+    namespace ns {
+    struct Waldo {
+      void find();
+    };
+    void Waldo::find() {}
+    void caller1(Waldo &W) {
+      W.$Caller1[[f^ind]]();
+    }
+    void caller2(Waldo &W) {
+      W.$Caller2[[find]]();
+    }
+    }
+  )cpp");
+  TestTU TU = TestTU::withCode(Source.code());
+  auto AST = TU.build();
+  auto Index = TU.index();
+
+  std::vector<CallHierarchyItem> Items =
+      prepareCallHierarchy(AST, Source.point(), testPath(TU.Filename));
+  // ASSERT (not EXPECT) because Items[0] is read on the next line.
+  ASSERT_THAT(Items, ElementsAre(WithName("Waldo::find")));
+  auto Incoming = incomingCalls(Items[0], Index.get());
+  EXPECT_THAT(Incoming,
+              UnorderedElementsAre(AllOf(From(WithName("caller1")),
+                                         FromRanges(Source.range("Caller1"))),
+                                   AllOf(From(WithName("caller2")),
+                                         FromRanges(Source.range("Caller2")))));
+}
+
+TEST(CallHierarchy, IncomingMultiFile) {
+  // The test uses a .hh suffix for header files to get clang
+  // to parse them in C++ mode. .h files are parsed in C mode
+  // by default, which causes problems because e.g. symbol
+  // USRs are different in C mode (do not include function signatures).
+
+  Annotations CalleeH(R"cpp(
+    void calle^e(int);
+  )cpp");
+  Annotations CalleeC(R"cpp(
+    #include "callee.hh"
+    void calle^e(int) {}
+  )cpp");
+  Annotations Caller1H(R"cpp(
+    void caller1();
+  )cpp");
+  Annotations Caller1C(R"cpp(
+    #include "callee.hh"
+    #include "caller1.hh"
+    void caller1() {
+      [[calle^e]](42);
+    }
+  )cpp");
+  Annotations Caller2H(R"cpp(
+    void caller2();
+  )cpp");
+  Annotations Caller2C(R"cpp(
+    #include "caller1.hh"
+    #include "caller2.hh"
+    void caller2() {
+      $A[[caller1]]();
+      $B[[caller1]]();
+    }
+  )cpp");
+  Annotations Caller3C(R"cpp(
+    #include "caller1.hh"
+    #include "caller2.hh"
+    void caller3() {
+      $Caller1[[caller1]]();
+      $Caller2[[caller2]]();
+    }
+  )cpp");
+
+  TestWorkspace Workspace;
+  Workspace.addSource("callee.hh", CalleeH.code());
+  Workspace.addSource("caller1.hh", Caller1H.code());
+  Workspace.addSource("caller2.hh", Caller2H.code());
+  Workspace.addMainFile("callee.cc", CalleeC.code());
+  Workspace.addMainFile("caller1.cc", Caller1C.code());
+  Workspace.addMainFile("caller2.cc", Caller2C.code());
+  Workspace.addMainFile("caller3.cc", Caller3C.code());
+
+  auto Index = Workspace.index();
+
+  auto CheckCallHierarchy = [&](ParsedAST &AST, Position Pos, PathRef TUPath) {
+    std::vector<CallHierarchyItem> Items =
+        prepareCallHierarchy(AST, Pos, TUPath);
+    // ASSERT_* inside this void lambda returns from the lambda on failure,
+    // which keeps the indexing below in bounds.
+    ASSERT_THAT(Items, ElementsAre(WithName("callee")));
+    auto IncomingLevel1 = incomingCalls(Items[0], Index.get());
+    ASSERT_THAT(IncomingLevel1,
+                ElementsAre(AllOf(From(WithName("caller1")),
+                                  FromRanges(Caller1C.range()))));
+
+    auto IncomingLevel2 = incomingCalls(IncomingLevel1[0].from, Index.get());
+    ASSERT_THAT(IncomingLevel2,
+                UnorderedElementsAre(
+                    AllOf(From(WithName("caller2")),
+                          FromRanges(Caller2C.range("A"), Caller2C.range("B"))),
+                    AllOf(From(WithName("caller3")),
+                          FromRanges(Caller3C.range("Caller1")))));
+
+    // The check above is order-insensitive, so pick caller2 by name rather
+    // than assuming it is the first result.
+    const CallHierarchyItem &Caller2 =
+        IncomingLevel2[0].from.name == "caller2" ? IncomingLevel2[0].from
+                                                 : IncomingLevel2[1].from;
+    auto IncomingLevel3 = incomingCalls(Caller2, Index.get());
+    ASSERT_THAT(IncomingLevel3,
+                ElementsAre(AllOf(From(WithName("caller3")),
+                                  FromRanges(Caller3C.range("Caller2")))));
+
+    auto IncomingLevel4 = incomingCalls(IncomingLevel3[0].from, Index.get());
+    EXPECT_THAT(IncomingLevel4, ElementsAre());
+  };
+
+  // Check that invoking from a call site works.
+  auto AST = Workspace.openFile("caller1.cc");
+  ASSERT_TRUE(bool(AST));
+  CheckCallHierarchy(*AST, Caller1C.point(), testPath("caller1.cc"));
+
+  // Check that invoking from the declaration site works.
+  AST = Workspace.openFile("callee.hh");
+  ASSERT_TRUE(bool(AST));
+  CheckCallHierarchy(*AST, CalleeH.point(), testPath("callee.hh"));
+
+  // Check that invoking from the definition site works.
+  AST = Workspace.openFile("callee.cc");
+  ASSERT_TRUE(bool(AST));
+  CheckCallHierarchy(*AST, CalleeC.point(), testPath("callee.cc"));
+}
+
+} // namespace
+} // namespace clangd
+} // namespace clang

diff  --git a/clang-tools-extra/clangd/unittests/TestTU.cpp b/clang-tools-extra/clangd/unittests/TestTU.cpp
index d0f011ef5649..ad0501c1d6a3 100644
--- a/clang-tools-extra/clangd/unittests/TestTU.cpp
+++ b/clang-tools-extra/clangd/unittests/TestTU.cpp
@@ -156,7 +156,8 @@ RefSlab TestTU::headerRefs() const {
 
 std::unique_ptr<SymbolIndex> TestTU::index() const {
   auto AST = build();
-  auto Idx = std::make_unique<FileIndex>(/*UseDex=*/true);
+  auto Idx = std::make_unique<FileIndex>(/*UseDex=*/true,
+                                         /*CollectMainFileRefs=*/true);
   Idx->updatePreamble(testPath(Filename), /*Version=*/"null",
                       AST.getASTContext(), AST.getPreprocessorPtr(),
                       AST.getCanonicalIncludes());


        


More information about the cfe-commits mailing list