[clang-tools-extra] [WIP][clangd] Resolve the dependent type from its single instantiation. Take 1 (PR #71279)

Younan Zhang via cfe-commits cfe-commits at lists.llvm.org
Sun Dec 24 02:50:40 PST 2023


https://github.com/zyn0217 updated https://github.com/llvm/llvm-project/pull/71279

>From c0703d7d9549e82434b37f9d5658b566a480d752 Mon Sep 17 00:00:00 2001
From: Younan Zhang <zyn7109 at gmail.com>
Date: Sat, 4 Nov 2023 18:43:58 +0800
Subject: [PATCH 1/4] [clangd] Resolve the dependent type from its single
 instantiation. Take 1

This enhances HeuristicResolver so that it tries to extract the
deduced type from the single instantiation of a template. This
partially addresses point #1 of
https://github.com/clangd/clangd/issues/1768.

This patch doesn't tackle CXXUnresolvedConstructExpr or similar
expressions, since I feel they are more involved; I'd prefer to leave
them for future work.
---
 .../clangd/HeuristicResolver.cpp              | 101 ++++++++++++++++++
 .../clangd/unittests/XRefsTests.cpp           |  48 +++++++++
 2 files changed, 149 insertions(+)

diff --git a/clang-tools-extra/clangd/HeuristicResolver.cpp b/clang-tools-extra/clangd/HeuristicResolver.cpp
index 3c147b6b582bf0..d3dced9b325367 100644
--- a/clang-tools-extra/clangd/HeuristicResolver.cpp
+++ b/clang-tools-extra/clangd/HeuristicResolver.cpp
@@ -7,10 +7,14 @@
 //===----------------------------------------------------------------------===//
 
 #include "HeuristicResolver.h"
+#include "AST.h"
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/CXXInheritance.h"
+#include "clang/AST/Decl.h"
+#include "clang/AST/DeclCXX.h"
 #include "clang/AST/DeclTemplate.h"
 #include "clang/AST/ExprCXX.h"
+#include "clang/AST/RecursiveASTVisitor.h"
 #include "clang/AST/Type.h"
 
 namespace clang {
@@ -46,6 +50,98 @@ const Type *resolveDeclsToType(const std::vector<const NamedDecl *> &Decls,
   return nullptr;
 }
 
+// Visitor that extracts the deduced type from instantiated entities. For each
+// FieldDecl/VarDecl in a non-dependent context, it compares the TypeLoc range
+// against that of DependentDecl until it finds a Decl matching the
+// dependent one. Its associated type will then be extracted.
+struct InstantiatedDeclVisitor : RecursiveASTVisitor<InstantiatedDeclVisitor> {
+
+  InstantiatedDeclVisitor(NamedDecl *DependentDecl) : DependentDecl(DependentDecl) {}
+
+  bool shouldVisitTemplateInstantiations() const { return true; }
+
+  bool shouldVisitLambdaBody() const { return true; }
+
+  bool shouldVisitImplicitCode() const { return true; }
+
+  template <typename D>
+  bool onDeclVisited(D *MaybeInstantiated) {
+    if (MaybeInstantiated->getDeclContext()->isDependentContext())
+      return true;
+    auto *Dependent = dyn_cast<D>(DependentDecl);
+    if (!Dependent) // DependentDecl is a different kind of Decl.
+      return true;
+    auto LHS = MaybeInstantiated->getTypeSourceInfo(),
+         RHS = Dependent->getTypeSourceInfo();
+    if (!LHS || !RHS)
+      return true;
+    if (LHS->getTypeLoc().getSourceRange() !=
+        RHS->getTypeLoc().getSourceRange())
+      return true;
+    DeducedType = MaybeInstantiated->getType(); // Match; traversal stops.
+    return false;
+  }
+
+  bool VisitFieldDecl(FieldDecl *FD) {
+    return onDeclVisited(FD);
+  }
+
+  bool VisitVarDecl(VarDecl *VD) {
+    return onDeclVisited(VD);
+  }
+
+  NamedDecl *DependentDecl;
+  QualType DeducedType;
+};
+
+/// Attempt to resolve the dependent type from the surrounding context for which
+/// a single instantiation is available.
+const Type *
+resolveTypeFromInstantiatedTemplate(const CXXDependentScopeMemberExpr *Expr) {
+  if (Expr->isImplicitAccess())
+    return nullptr;
+
+  auto *Base = Expr->getBase();
+  NamedDecl *ND = nullptr; // The Decl that the base expression refers to.
+  if (auto *CXXMember = dyn_cast<MemberExpr>(Base))
+    ND = CXXMember->getMemberDecl();
+
+  if (auto *DRExpr = dyn_cast<DeclRefExpr>(Base))
+    ND = DRExpr->getFoundDecl();
+
+  // FIXME: Handle CXXUnresolvedConstructExpr. Such a base has no Decl to
+  // match against, which prevents the current heuristic from resolving
+  // expressions such as `T().fo^o()`, where T is a template parameter with
+  // a single instantiation.
+  if (!ND)
+    return nullptr;
+
+  NamedDecl *Instantiation = nullptr;
+
+  // Find a single instantiation to start from. The enclosing context of the
+  // current Decl might not be a templated entity (e.g. a member function
+  // inside a class template), so walk up the decl contexts until one with a
+  // single instantiation is found.
+  for (auto *EnclosingContext = ND->getDeclContext(); EnclosingContext;
+       EnclosingContext = EnclosingContext->getParent()) {
+    if (auto *ND = dyn_cast<NamedDecl>(EnclosingContext)) { // Shadows ND.
+      Instantiation = getOnlyInstantiation(ND);
+      if (Instantiation)
+        break;
+    }
+  }
+
+  if (!Instantiation)
+    return nullptr;
+
+  // Traverse the instantiation, visiting each Decl, and extract the deduced
+  // type of the dependent Decl `ND` (the outer one, not the loop variable).
+  InstantiatedDeclVisitor Visitor(ND);
+  Visitor.TraverseDecl(Instantiation);
+  // Null if no Decl in the instantiation matched.
+  return Visitor.DeducedType.getTypePtrOrNull();
+}
+
 } // namespace
 
 // Helper function for HeuristicResolver::resolveDependentMember()
@@ -150,6 +246,11 @@ std::vector<const NamedDecl *> HeuristicResolver::resolveMemberExpr(
   if (ME->isArrow()) {
     BaseType = getPointeeType(BaseType);
   }
+
+  if (BaseType->isDependentType())
+    if (auto *MaybeResolved = resolveTypeFromInstantiatedTemplate(ME))
+      BaseType = MaybeResolved;
+
   if (!BaseType)
     return {};
   if (const auto *BT = BaseType->getAs<BuiltinType>()) {
diff --git a/clang-tools-extra/clangd/unittests/XRefsTests.cpp b/clang-tools-extra/clangd/unittests/XRefsTests.cpp
index f53cbf01b7992c..ead24dec575de0 100644
--- a/clang-tools-extra/clangd/unittests/XRefsTests.cpp
+++ b/clang-tools-extra/clangd/unittests/XRefsTests.cpp
@@ -1222,6 +1222,54 @@ TEST(LocateSymbol, TextualSmoke) {
                         hasID(getSymbolID(&findDecl(AST, "MyClass"))))));
 }
 
+TEST(LocateSymbol, DeduceDependentTypeFromSingleInstantiation) {
+  Annotations T(R"cpp(
+    struct WildCat {
+      void $wild_meow[[meow]]();
+    };
+
+    struct DomesticCat {
+      void $domestic_meow[[meow]]();
+    };
+
+    template <typename Ours>
+    struct Human {
+      template <typename Others>
+      void feed(Others O) {
+        O.me$1^ow();
+        Others Child;
+        Child.me$2^ow();
+        // FIXME: Others().me^ow();
+        Ours Baby;
+        Baby.me$3^ow();
+        // struct Inner {
+        //   Ours Pet;
+        // };
+        // Inner().Pet.me^ow();
+        auto Lambda = [](auto C) {
+          C.me$4^ow();
+        };
+        Lambda(Others());
+      }
+    };
+
+    void foo() {
+      Human<DomesticCat>().feed(WildCat());
+    }
+  )cpp");
+  // The only instantiation is Human<DomesticCat>::feed<WildCat>.
+  auto TU = TestTU::withCode(T.code());
+  auto AST = TU.build();
+  EXPECT_THAT(locateSymbolAt(AST, T.point("1")),
+              ElementsAre(sym("meow", T.range("wild_meow"), std::nullopt)));
+  EXPECT_THAT(locateSymbolAt(AST, T.point("2")),
+              ElementsAre(sym("meow", T.range("wild_meow"), std::nullopt)));
+  EXPECT_THAT(locateSymbolAt(AST, T.point("3")),
+              ElementsAre(sym("meow", T.range("domestic_meow"), std::nullopt)));
+  EXPECT_THAT(locateSymbolAt(AST, T.point("4")),
+              ElementsAre(sym("meow", T.range("wild_meow"), std::nullopt)));
+}
+
 TEST(LocateSymbol, Textual) {
   const char *Tests[] = {
       R"cpp(// Comment

>From 2f8f89f599565d0fc3e4d6bd15d85791dec2a5f5 Mon Sep 17 00:00:00 2001
From: Younan Zhang <zyn7109 at gmail.com>
Date: Sat, 11 Nov 2023 11:46:58 +0800
Subject: [PATCH 2/4] Don't resolve for null base type

---
 clang-tools-extra/clangd/HeuristicResolver.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/clang-tools-extra/clangd/HeuristicResolver.cpp b/clang-tools-extra/clangd/HeuristicResolver.cpp
index d3dced9b325367..a3a9bcd60ad256 100644
--- a/clang-tools-extra/clangd/HeuristicResolver.cpp
+++ b/clang-tools-extra/clangd/HeuristicResolver.cpp
@@ -247,12 +247,13 @@ std::vector<const NamedDecl *> HeuristicResolver::resolveMemberExpr(
     BaseType = getPointeeType(BaseType);
   }
 
+  if (!BaseType)
+    return {};
+
   if (BaseType->isDependentType())
     if (auto *MaybeResolved = resolveTypeFromInstantiatedTemplate(ME))
       BaseType = MaybeResolved;
 
-  if (!BaseType)
-    return {};
   if (const auto *BT = BaseType->getAs<BuiltinType>()) {
     // If BaseType is the type of a dependent expression, it's just
     // represented as BuiltinType::Dependent which gives us no information. We

>From 4adc600b60538e6743cbcf45141f421bf458f607 Mon Sep 17 00:00:00 2001
From: Younan Zhang <zyn7109 at gmail.com>
Date: Sat, 11 Nov 2023 12:07:58 +0800
Subject: [PATCH 3/4] Format

---
 clang-tools-extra/clangd/HeuristicResolver.cpp | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/clang-tools-extra/clangd/HeuristicResolver.cpp b/clang-tools-extra/clangd/HeuristicResolver.cpp
index a3a9bcd60ad256..90de811b10330d 100644
--- a/clang-tools-extra/clangd/HeuristicResolver.cpp
+++ b/clang-tools-extra/clangd/HeuristicResolver.cpp
@@ -56,7 +56,8 @@ const Type *resolveDeclsToType(const std::vector<const NamedDecl *> &Decls,
 // dependent one. Its associated type will then be extracted.
 struct InstantiatedDeclVisitor : RecursiveASTVisitor<InstantiatedDeclVisitor> {
 
-  InstantiatedDeclVisitor(NamedDecl *DependentDecl) : DependentDecl(DependentDecl) {}
+  InstantiatedDeclVisitor(NamedDecl *DependentDecl)
+      : DependentDecl(DependentDecl) {}
 
   bool shouldVisitTemplateInstantiations() const { return true; }
 
@@ -64,8 +65,7 @@ struct InstantiatedDeclVisitor : RecursiveASTVisitor<InstantiatedDeclVisitor> {
 
   bool shouldVisitImplicitCode() const { return true; }
 
-  template <typename D>
-  bool onDeclVisited(D *MaybeInstantiated) {
+  template <typename D> bool onDeclVisited(D *MaybeInstantiated) {
     if (MaybeInstantiated->getDeclContext()->isDependentContext())
       return true;
     auto *Dependent = dyn_cast<D>(DependentDecl);
@@ -82,13 +82,9 @@ struct InstantiatedDeclVisitor : RecursiveASTVisitor<InstantiatedDeclVisitor> {
     return false;
   }
 
-  bool VisitFieldDecl(FieldDecl *FD) {
-    return onDeclVisited(FD);
-  }
+  bool VisitFieldDecl(FieldDecl *FD) { return onDeclVisited(FD); }
 
-  bool VisitVarDecl(VarDecl *VD) {
-    return onDeclVisited(VD);
-  }
+  bool VisitVarDecl(VarDecl *VD) { return onDeclVisited(VD); }
 
   NamedDecl *DependentDecl;
   QualType DeducedType;

>From 09f0acb3bdfe64cbc09116cd318f56cef242cc7c Mon Sep 17 00:00:00 2001
From: Younan Zhang <zyn7109 at gmail.com>
Date: Sun, 24 Dec 2023 18:07:34 +0800
Subject: [PATCH 4/4] wip

---
 clang-tools-extra/clangd/AST.cpp              | 195 +++++++++++++++++-
 clang-tools-extra/clangd/AST.h                |   9 +-
 clang-tools-extra/clangd/FindTarget.cpp       |  77 +++----
 clang-tools-extra/clangd/FindTarget.h         |  12 +-
 .../clangd/HeuristicResolver.cpp              |  19 +-
 clang-tools-extra/clangd/HeuristicResolver.h  |   6 +-
 clang-tools-extra/clangd/Hover.cpp            |   5 +-
 clang-tools-extra/clangd/InlayHints.cpp       |   4 +-
 clang-tools-extra/clangd/ParsedAST.cpp        |   1 -
 clang-tools-extra/clangd/ParsedAST.h          |   7 +-
 .../clangd/SemanticHighlighting.cpp           |  20 +-
 clang-tools-extra/clangd/XRefs.cpp            |  33 +--
 clang-tools-extra/clangd/refactor/Rename.cpp  |   2 +-
 .../clangd/refactor/tweaks/DefineInline.cpp   |   4 +-
 .../clangd/refactor/tweaks/DefineOutline.cpp  |   2 +-
 .../refactor/tweaks/ExtractFunction.cpp       |   2 +-
 .../clangd/unittests/ASTTests.cpp             |  85 ++++++++
 .../clangd/unittests/FindTargetTests.cpp      |   4 +-
 .../clangd/unittests/XRefsTests.cpp           |  64 +++---
 19 files changed, 416 insertions(+), 135 deletions(-)

diff --git a/clang-tools-extra/clangd/AST.cpp b/clang-tools-extra/clangd/AST.cpp
index 5b81ec213ff984..d42dbf316364a6 100644
--- a/clang-tools-extra/clangd/AST.cpp
+++ b/clang-tools-extra/clangd/AST.cpp
@@ -16,12 +16,14 @@
 #include "clang/AST/DeclCXX.h"
 #include "clang/AST/DeclObjC.h"
 #include "clang/AST/DeclTemplate.h"
+#include "clang/AST/DeclVisitor.h"
 #include "clang/AST/DeclarationName.h"
 #include "clang/AST/ExprCXX.h"
 #include "clang/AST/NestedNameSpecifier.h"
 #include "clang/AST/PrettyPrinter.h"
 #include "clang/AST/RecursiveASTVisitor.h"
 #include "clang/AST/Stmt.h"
+#include "clang/AST/StmtVisitor.h"
 #include "clang/AST/TemplateBase.h"
 #include "clang/AST/TypeLoc.h"
 #include "clang/Basic/Builtins.h"
@@ -636,7 +638,7 @@ static NamedDecl *getOnlyInstantiationImpl(TemplateDeclTy *TD) {
   return Only;
 }
 
-NamedDecl *getOnlyInstantiation(NamedDecl *TemplatedDecl) {
+NamedDecl *getOnlyInstantiation(const NamedDecl *TemplatedDecl) {
   if (TemplateDecl *TD = TemplatedDecl->getDescribedTemplate()) {
     if (auto *CTD = llvm::dyn_cast<ClassTemplateDecl>(TD))
       return getOnlyInstantiationImpl(CTD);
@@ -648,6 +650,197 @@ NamedDecl *getOnlyInstantiation(NamedDecl *TemplatedDecl) {
   return nullptr;
 }
 
+NamedDecl *getOnlyInstantiatedDecls(const NamedDecl *DependentDecl) {
+  if (auto *Instantiation = getOnlyInstantiation(DependentDecl))
+    return Instantiation;
+  NamedDecl *OuterTemplate = nullptr; // First singly-instantiated outer class.
+  for (auto *DC = DependentDecl->getDeclContext(); isa<CXXRecordDecl>(DC);
+       DC = DC->getParent()) {
+    auto *RD = cast<CXXRecordDecl>(DC);
+    if (auto *I = getOnlyInstantiation(RD)) {
+      OuterTemplate = I;
+      break;
+    }
+  }
+  // Give up unless some enclosing class has a single instantiation.
+  if (!OuterTemplate)
+    return nullptr;
+  // Search that instantiation for the counterpart of DependentDecl.
+  struct Visitor : DeclVisitor<Visitor, NamedDecl *> {
+    const NamedDecl *TemplatedDecl;
+    Visitor(const NamedDecl *TemplatedDecl) : TemplatedDecl(TemplatedDecl) {}
+
+    NamedDecl *VisitCXXRecordDecl(CXXRecordDecl *RD) {
+      if (RD->getTemplateInstantiationPattern() == TemplatedDecl)
+        return RD;
+      for (auto *F : RD->decls()) { // Otherwise recurse into the members.
+        if (auto *Injected = llvm::dyn_cast<CXXRecordDecl>(F);
+            Injected && Injected->isInjectedClassName())
+          continue; // Skip the injected class name.
+        if (NamedDecl *ND = Visit(F))
+          return ND;
+      }
+      return nullptr;
+    }
+
+    NamedDecl *VisitClassTemplateDecl(ClassTemplateDecl *CTD) {
+      unsigned Size = llvm::range_size(CTD->specializations());
+      if (Size != 1) // Only follow a unique specialization.
+        return nullptr;
+      return Visit(*CTD->spec_begin());
+    }
+
+    NamedDecl *VisitFunctionTemplateDecl(FunctionTemplateDecl *FTD) {
+      unsigned Size = llvm::range_size(FTD->specializations());
+      if (Size != 1) // Only follow a unique specialization.
+        return nullptr;
+      return Visit(*FTD->spec_begin());
+    }
+
+    NamedDecl *VisitFunctionDecl(FunctionDecl *FD) {
+      if (FD->getTemplateInstantiationPattern() == TemplatedDecl)
+        return FD;
+      return nullptr;
+    }
+
+    NamedDecl *VisitVarDecl(VarDecl *VD) { // Matched by source range.
+      if (VD->getCanonicalDecl()->getSourceRange() ==
+          TemplatedDecl->getCanonicalDecl()->getSourceRange())
+        return VD;
+      return nullptr;
+    }
+
+    NamedDecl *VisitFieldDecl(FieldDecl *FD) { // Matched by source range.
+      if (FD->getCanonicalDecl()->getSourceRange() ==
+          TemplatedDecl->getCanonicalDecl()->getSourceRange())
+        return FD;
+      return nullptr;
+    }
+  };
+  return Visitor(DependentDecl).Visit(OuterTemplate);
+}
+
+std::optional<DynTypedNode>
+getOnlyInstantiatedNode(const DeclContext *StartingPoint,
+                        const DynTypedNode &DependentNode) {
+  if (auto *CTD = DependentNode.get<ClassTemplateDecl>())
+    return getOnlyInstantiatedNode(
+        StartingPoint, DynTypedNode::create(*CTD->getTemplatedDecl()));
+  if (auto *FTD = DependentNode.get<FunctionTemplateDecl>())
+    return getOnlyInstantiatedNode(
+        StartingPoint, DynTypedNode::create(*FTD->getTemplatedDecl()));
+  // Function and record Decls are looked up via getOnlyInstantiatedDecls.
+  if (auto *FD = DependentNode.get<FunctionDecl>()) {
+    auto *ID = getOnlyInstantiatedDecls(FD);
+    if (!ID)
+      return std::nullopt;
+    return DynTypedNode::create(*ID);
+  }
+  if (auto *RD = DependentNode.get<CXXRecordDecl>()) {
+    auto *ID = getOnlyInstantiatedDecls(RD);
+    if (!ID)
+      return std::nullopt;
+    return DynTypedNode::create(*ID);
+  }
+  // Otherwise, walk out from StartingPoint to a singly-instantiated Decl.
+  NamedDecl *InstantiatedEnclosingDecl = nullptr;
+  for (auto *DC = StartingPoint; DC;
+       DC = DC->getParent()) {
+    auto *ND = llvm::dyn_cast<NamedDecl>(DC);
+    if (!ND)
+      continue;
+    InstantiatedEnclosingDecl = getOnlyInstantiatedDecls(ND);
+    if (InstantiatedEnclosingDecl)
+      break;
+  }
+
+  if (!InstantiatedEnclosingDecl)
+    return std::nullopt;
+  // Only statements inside an instantiated function body are searched.
+  auto *InstantiatedFunctionDecl =
+      llvm::dyn_cast<FunctionDecl>(InstantiatedEnclosingDecl);
+  if (!InstantiatedFunctionDecl)
+    return std::nullopt;
+  // Finds the first Stmt, in traversal order, with DependentNode's range.
+  struct FullExprVisitor : RecursiveASTVisitor<FullExprVisitor> {
+    const DynTypedNode &DependentNode;
+    Stmt *Result;
+    FullExprVisitor(const DynTypedNode &DependentNode)
+        : DependentNode(DependentNode), Result(nullptr) {}
+
+    bool shouldVisitTemplateInstantiations() const { return true; }
+
+    bool shouldVisitImplicitCode() const { return true; }
+
+    bool VisitStmt(Stmt *S) {
+      if (S->getSourceRange() == DependentNode.getSourceRange()) {
+        Result = S;
+        return false;
+      }
+      return true;
+    }
+  };
+  // NOTE(review): implicit nodes sharing that range may be matched first.
+  FullExprVisitor Visitor(DependentNode);
+  Visitor.TraverseFunctionDecl(InstantiatedFunctionDecl);
+  if (Visitor.Result)
+    return DynTypedNode::create(*Visitor.Result);
+  return std::nullopt;
+}
+
+NamedDecl *
+getOnlyInstantiationForMemberFunction(const CXXMethodDecl *TemplatedDecl) {
+  if (auto *MemberInstantiation = getOnlyInstantiation(TemplatedDecl))
+    return MemberInstantiation;
+  NamedDecl *OuterTemplate = nullptr; // First singly-instantiated outer class.
+  for (auto *DC = TemplatedDecl->getDeclContext(); isa<CXXRecordDecl>(DC);
+       DC = DC->getParent()) {
+    auto *RD = cast<CXXRecordDecl>(DC);
+    if (auto *I = getOnlyInstantiation(RD)) {
+      OuterTemplate = I;
+      break;
+    }
+  }
+  if (!OuterTemplate)
+    return nullptr;
+  struct Visitor : DeclVisitor<Visitor, NamedDecl *> {
+    const CXXMethodDecl *TD;
+    Visitor(const CXXMethodDecl *TemplatedDecl) : TD(TemplatedDecl) {}
+    NamedDecl *VisitCXXRecordDecl(CXXRecordDecl *RD) { // Recurse into members.
+      for (auto *F : RD->decls()) {
+        if (!isa<NamedDecl>(F))
+          continue;
+        if (NamedDecl *ND = Visit(F))
+          return ND;
+      }
+      return nullptr;
+    }
+    // Only follow a template that has exactly one specialization.
+    NamedDecl *VisitClassTemplateDecl(ClassTemplateDecl *CTD) {
+      unsigned Size = llvm::range_size(CTD->specializations());
+      if (Size != 1)
+        return nullptr;
+      return Visit(*CTD->spec_begin());
+    }
+
+    NamedDecl *VisitFunctionTemplateDecl(FunctionTemplateDecl *FTD) {
+      unsigned Size = llvm::range_size(FTD->specializations());
+      if (Size != 1)
+        return nullptr;
+      return Visit(*FTD->spec_begin());
+    }
+
+    NamedDecl *VisitCXXMethodDecl(CXXMethodDecl *MD) { // Match by pattern.
+      auto *Pattern = MD->getTemplateInstantiationPattern();
+      if (Pattern == TD)
+        return MD;
+      return nullptr;
+    }
+    // NOTE(review): largely mirrors the visitor in getOnlyInstantiatedDecls.
+  };
+  return Visitor(TemplatedDecl).Visit(OuterTemplate);
+}
+
 std::vector<const Attr *> getAttributes(const DynTypedNode &N) {
   std::vector<const Attr *> Result;
   if (const auto *TL = N.get<TypeLoc>()) {
diff --git a/clang-tools-extra/clangd/AST.h b/clang-tools-extra/clangd/AST.h
index fb0722d697cd06..74a43fdce26903 100644
--- a/clang-tools-extra/clangd/AST.h
+++ b/clang-tools-extra/clangd/AST.h
@@ -177,7 +177,14 @@ TemplateTypeParmTypeLoc getContainedAutoParamType(TypeLoc TL);
 
 // If TemplatedDecl is the generic body of a template, and the template has
 // exactly one visible instantiation, return the instantiated body.
-NamedDecl *getOnlyInstantiation(NamedDecl *TemplatedDecl);
+NamedDecl *getOnlyInstantiation(const NamedDecl *TemplatedDecl);
+// Like getOnlyInstantiation, but also searches an instantiated outer class.
+NamedDecl *
+getOnlyInstantiationForMemberFunction(const CXXMethodDecl *TemplatedDecl);
+// Finds DependentNode's counterpart in a single instantiation, or nullopt.
+std::optional<DynTypedNode>
+getOnlyInstantiatedNode(const DeclContext *StartingPoint,
+                        const DynTypedNode &DependentNode);
 
 /// Return attributes attached directly to a node.
 std::vector<const Attr *> getAttributes(const DynTypedNode &);
diff --git a/clang-tools-extra/clangd/FindTarget.cpp b/clang-tools-extra/clangd/FindTarget.cpp
index 839cf6332fe8b0..0045c4e64c6f7a 100644
--- a/clang-tools-extra/clangd/FindTarget.cpp
+++ b/clang-tools-extra/clangd/FindTarget.cpp
@@ -133,7 +133,7 @@ struct TargetFinder {
   using Rel = DeclRelation;
 
 private:
-  const HeuristicResolver *Resolver;
+  const HeuristicResolver &Resolver;
   llvm::SmallDenseMap<const NamedDecl *,
                       std::pair<RelSet, /*InsertionOrder*/ size_t>>
       Decls;
@@ -153,7 +153,7 @@ struct TargetFinder {
   }
 
 public:
-  TargetFinder(const HeuristicResolver *Resolver) : Resolver(Resolver) {}
+  TargetFinder(const HeuristicResolver &Resolver) : Resolver(Resolver) {}
 
   llvm::SmallVector<std::pair<const NamedDecl *, RelSet>, 1> takeDecls() const {
     using ValTy = std::pair<const NamedDecl *, RelSet>;
@@ -197,10 +197,8 @@ struct TargetFinder {
       Flags |= Rel::Alias; // continue with the alias
     } else if (const UnresolvedUsingValueDecl *UUVD =
                    dyn_cast<UnresolvedUsingValueDecl>(D)) {
-      if (Resolver) {
-        for (const NamedDecl *Target : Resolver->resolveUsingValueDecl(UUVD)) {
-          add(Target, Flags); // no Underlying as this is a non-renaming alias
-        }
+      for (const NamedDecl *Target : Resolver.resolveUsingValueDecl(UUVD)) {
+        add(Target, Flags); // no Underlying as this is a non-renaming alias
       }
       Flags |= Rel::Alias; // continue with the alias
     } else if (isa<UnresolvedUsingTypenameDecl>(D)) {
@@ -304,17 +302,13 @@ struct TargetFinder {
       }
       void
       VisitCXXDependentScopeMemberExpr(const CXXDependentScopeMemberExpr *E) {
-        if (Outer.Resolver) {
-          for (const NamedDecl *D : Outer.Resolver->resolveMemberExpr(E)) {
-            Outer.add(D, Flags);
-          }
+        for (const NamedDecl *D : Outer.Resolver.resolveMemberExpr(E)) {
+          Outer.add(D, Flags);
         }
       }
       void VisitDependentScopeDeclRefExpr(const DependentScopeDeclRefExpr *E) {
-        if (Outer.Resolver) {
-          for (const NamedDecl *D : Outer.Resolver->resolveDeclRefExpr(E)) {
-            Outer.add(D, Flags);
-          }
+        for (const NamedDecl *D : Outer.Resolver.resolveDeclRefExpr(E)) {
+          Outer.add(D, Flags);
         }
       }
       void VisitObjCIvarRefExpr(const ObjCIvarRefExpr *OIRE) {
@@ -407,20 +401,16 @@ struct TargetFinder {
           Outer.add(TD->getTemplatedDecl(), Flags | Rel::TemplatePattern);
       }
       void VisitDependentNameType(const DependentNameType *DNT) {
-        if (Outer.Resolver) {
-          for (const NamedDecl *ND :
-               Outer.Resolver->resolveDependentNameType(DNT)) {
-            Outer.add(ND, Flags);
-          }
+        for (const NamedDecl *ND :
+             Outer.Resolver.resolveDependentNameType(DNT)) {
+          Outer.add(ND, Flags);
         }
       }
       void VisitDependentTemplateSpecializationType(
           const DependentTemplateSpecializationType *DTST) {
-        if (Outer.Resolver) {
-          for (const NamedDecl *ND :
-               Outer.Resolver->resolveTemplateSpecializationType(DTST)) {
-            Outer.add(ND, Flags);
-          }
+        for (const NamedDecl *ND :
+             Outer.Resolver.resolveTemplateSpecializationType(DTST)) {
+          Outer.add(ND, Flags);
         }
       }
       void VisitTypedefType(const TypedefType *TT) {
@@ -488,10 +478,7 @@ struct TargetFinder {
       add(NNS->getAsNamespaceAlias(), Flags);
       return;
     case NestedNameSpecifier::Identifier:
-      if (Resolver) {
-        add(QualType(Resolver->resolveNestedNameSpecifierToType(NNS), 0),
-            Flags);
-      }
+      add(QualType(Resolver.resolveNestedNameSpecifierToType(NNS), 0), Flags);
       return;
     case NestedNameSpecifier::TypeSpec:
     case NestedNameSpecifier::TypeSpecWithTemplate:
@@ -542,7 +529,7 @@ struct TargetFinder {
 } // namespace
 
 llvm::SmallVector<std::pair<const NamedDecl *, DeclRelationSet>, 1>
-allTargetDecls(const DynTypedNode &N, const HeuristicResolver *Resolver) {
+allTargetDecls(const DynTypedNode &N, const HeuristicResolver &Resolver) {
   dlog("allTargetDecls({0})", nodeToString(N));
   TargetFinder Finder(Resolver);
   DeclRelationSet Flags;
@@ -573,7 +560,7 @@ allTargetDecls(const DynTypedNode &N, const HeuristicResolver *Resolver) {
 
 llvm::SmallVector<const NamedDecl *, 1>
 targetDecl(const DynTypedNode &N, DeclRelationSet Mask,
-           const HeuristicResolver *Resolver) {
+           const HeuristicResolver &Resolver) {
   llvm::SmallVector<const NamedDecl *, 1> Result;
   for (const auto &Entry : allTargetDecls(N, Resolver)) {
     if (!(Entry.second & ~Mask))
@@ -584,7 +571,7 @@ targetDecl(const DynTypedNode &N, DeclRelationSet Mask,
 
 llvm::SmallVector<const NamedDecl *, 1>
 explicitReferenceTargets(DynTypedNode N, DeclRelationSet Mask,
-                         const HeuristicResolver *Resolver) {
+                         const HeuristicResolver &Resolver) {
   assert(!(Mask & (DeclRelation::TemplatePattern |
                    DeclRelation::TemplateInstantiation)) &&
          "explicitReferenceTargets handles templates on its own");
@@ -616,11 +603,11 @@ explicitReferenceTargets(DynTypedNode N, DeclRelationSet Mask,
 
 namespace {
 llvm::SmallVector<ReferenceLoc> refInDecl(const Decl *D,
-                                          const HeuristicResolver *Resolver) {
+                                          const HeuristicResolver &Resolver) {
   struct Visitor : ConstDeclVisitor<Visitor> {
-    Visitor(const HeuristicResolver *Resolver) : Resolver(Resolver) {}
+    Visitor(const HeuristicResolver &Resolver) : Resolver(Resolver) {}
 
-    const HeuristicResolver *Resolver;
+    const HeuristicResolver &Resolver;
     llvm::SmallVector<ReferenceLoc> Refs;
 
     void VisitUsingDirectiveDecl(const UsingDirectiveDecl *D) {
@@ -741,11 +728,11 @@ llvm::SmallVector<ReferenceLoc> refInDecl(const Decl *D,
 }
 
 llvm::SmallVector<ReferenceLoc> refInStmt(const Stmt *S,
-                                          const HeuristicResolver *Resolver) {
+                                          const HeuristicResolver &Resolver) {
   struct Visitor : ConstStmtVisitor<Visitor> {
-    Visitor(const HeuristicResolver *Resolver) : Resolver(Resolver) {}
+    Visitor(const HeuristicResolver &Resolver) : Resolver(Resolver) {}
 
-    const HeuristicResolver *Resolver;
+    const HeuristicResolver &Resolver;
     // FIXME: handle more complicated cases: more ObjC, designated initializers.
     llvm::SmallVector<ReferenceLoc> Refs;
 
@@ -852,11 +839,11 @@ llvm::SmallVector<ReferenceLoc> refInStmt(const Stmt *S,
 }
 
 llvm::SmallVector<ReferenceLoc>
-refInTypeLoc(TypeLoc L, const HeuristicResolver *Resolver) {
+refInTypeLoc(TypeLoc L, const HeuristicResolver &Resolver) {
   struct Visitor : TypeLocVisitor<Visitor> {
-    Visitor(const HeuristicResolver *Resolver) : Resolver(Resolver) {}
+    Visitor(const HeuristicResolver &Resolver) : Resolver(Resolver) {}
 
-    const HeuristicResolver *Resolver;
+    const HeuristicResolver &Resolver;
     llvm::SmallVector<ReferenceLoc> Refs;
 
     void VisitElaboratedTypeLoc(ElaboratedTypeLoc L) {
@@ -966,7 +953,7 @@ class ExplicitReferenceCollector
     : public RecursiveASTVisitor<ExplicitReferenceCollector> {
 public:
   ExplicitReferenceCollector(llvm::function_ref<void(ReferenceLoc)> Out,
-                             const HeuristicResolver *Resolver)
+                             const HeuristicResolver &Resolver)
       : Out(Out), Resolver(Resolver) {
     assert(Out);
   }
@@ -1144,7 +1131,7 @@ class ExplicitReferenceCollector
   }
 
   llvm::function_ref<void(ReferenceLoc)> Out;
-  const HeuristicResolver *Resolver;
+  const HeuristicResolver &Resolver;
   /// TypeLocs starting at these locations must be skipped, see
   /// TraverseElaboratedTypeSpecifierLoc for details.
   llvm::DenseSet<SourceLocation> TypeLocsToSkip;
@@ -1153,19 +1140,19 @@ class ExplicitReferenceCollector
 
 void findExplicitReferences(const Stmt *S,
                             llvm::function_ref<void(ReferenceLoc)> Out,
-                            const HeuristicResolver *Resolver) {
+                            const HeuristicResolver &Resolver) {
   assert(S);
   ExplicitReferenceCollector(Out, Resolver).TraverseStmt(const_cast<Stmt *>(S));
 }
 void findExplicitReferences(const Decl *D,
                             llvm::function_ref<void(ReferenceLoc)> Out,
-                            const HeuristicResolver *Resolver) {
+                            const HeuristicResolver &Resolver) {
   assert(D);
   ExplicitReferenceCollector(Out, Resolver).TraverseDecl(const_cast<Decl *>(D));
 }
 void findExplicitReferences(const ASTContext &AST,
                             llvm::function_ref<void(ReferenceLoc)> Out,
-                            const HeuristicResolver *Resolver) {
+                            const HeuristicResolver &Resolver) {
   ExplicitReferenceCollector(Out, Resolver)
       .TraverseAST(const_cast<ASTContext &>(AST));
 }
diff --git a/clang-tools-extra/clangd/FindTarget.h b/clang-tools-extra/clangd/FindTarget.h
index b41c5470951001..b5d10ea8be189e 100644
--- a/clang-tools-extra/clangd/FindTarget.h
+++ b/clang-tools-extra/clangd/FindTarget.h
@@ -81,14 +81,14 @@ class DeclRelationSet;
 /// FIXME: some AST nodes cannot be DynTypedNodes, these cannot be specified.
 llvm::SmallVector<const NamedDecl *, 1>
 targetDecl(const DynTypedNode &, DeclRelationSet Mask,
-           const HeuristicResolver *Resolver);
+           const HeuristicResolver &Resolver);
 
 /// Similar to targetDecl(), however instead of applying a filter, all possible
 /// decls are returned along with their DeclRelationSets.
 /// This is suitable for indexing, where everything is recorded and filtering
 /// is applied later.
 llvm::SmallVector<std::pair<const NamedDecl *, DeclRelationSet>, 1>
-allTargetDecls(const DynTypedNode &, const HeuristicResolver *);
+allTargetDecls(const DynTypedNode &, const HeuristicResolver &);
 
 enum class DeclRelation : unsigned {
   // Template options apply when the declaration is an instantiated template.
@@ -147,13 +147,13 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, ReferenceLoc R);
 /// FIXME: extend to report location information about declaration names too.
 void findExplicitReferences(const Stmt *S,
                             llvm::function_ref<void(ReferenceLoc)> Out,
-                            const HeuristicResolver *Resolver);
+                            const HeuristicResolver &Resolver);
 void findExplicitReferences(const Decl *D,
                             llvm::function_ref<void(ReferenceLoc)> Out,
-                            const HeuristicResolver *Resolver);
+                            const HeuristicResolver &Resolver);
 void findExplicitReferences(const ASTContext &AST,
                             llvm::function_ref<void(ReferenceLoc)> Out,
-                            const HeuristicResolver *Resolver);
+                            const HeuristicResolver &Resolver);
 
 /// Find declarations explicitly referenced in the source code defined by \p N.
 /// For templates, will prefer to return a template instantiation whenever
@@ -166,7 +166,7 @@ void findExplicitReferences(const ASTContext &AST,
 /// \p Mask should not contain TemplatePattern or TemplateInstantiation.
 llvm::SmallVector<const NamedDecl *, 1>
 explicitReferenceTargets(DynTypedNode N, DeclRelationSet Mask,
-                         const HeuristicResolver *Resolver);
+                         const HeuristicResolver &Resolver);
 
 // Boring implementation details of bitfield.
 
diff --git a/clang-tools-extra/clangd/HeuristicResolver.cpp b/clang-tools-extra/clangd/HeuristicResolver.cpp
index 90de811b10330d..072ff4cf13d105 100644
--- a/clang-tools-extra/clangd/HeuristicResolver.cpp
+++ b/clang-tools-extra/clangd/HeuristicResolver.cpp
@@ -9,6 +9,7 @@
 #include "HeuristicResolver.h"
 #include "AST.h"
 #include "clang/AST/ASTContext.h"
+#include "clang/AST/ASTTypeTraits.h"
 #include "clang/AST/CXXInheritance.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/DeclCXX.h"
@@ -38,7 +39,7 @@ const auto TemplateFilter = [](const NamedDecl *D) {
 namespace {
 
 const Type *resolveDeclsToType(const std::vector<const NamedDecl *> &Decls,
-                               ASTContext &Ctx) {
+                               const ASTContext &Ctx) {
   if (Decls.size() != 1) // Names an overload set -- just bail.
     return nullptr;
   if (const auto *TD = dyn_cast<TypeDecl>(Decls[0])) {
@@ -93,7 +94,19 @@ struct InstantiatedDeclVisitor : RecursiveASTVisitor<InstantiatedDeclVisitor> {
 /// Attempt to resolve the dependent type from the surrounding context for which
 /// a single instantiation is available.
 const Type *
-resolveTypeFromInstantiatedTemplate(const CXXDependentScopeMemberExpr *Expr) {
+resolveTypeFromInstantiatedTemplate(const DeclContext *DC,
+                                    const CXXDependentScopeMemberExpr *Expr) {
+  // Prefer the type recovered from the single instantiation, if any. DC may
+  // be null (HeuristicResolver's EnclosingDecl defaults to nullptr), so
+  // guard it before searching for the instantiated node.
+  if (DC) {
+    if (std::optional<DynTypedNode> Node =
+            getOnlyInstantiatedNode(DC, DynTypedNode::create(*Expr)))
+      if (const auto *ME = Node->get<MemberExpr>())
+        return ME->getBase()->getType().getTypePtrOrNull();
+  }
+  // Otherwise fall through to the heuristic below.
+
   if (Expr->isImplicitAccess())
     return nullptr;
 
@@ -247,7 +260,7 @@ std::vector<const NamedDecl *> HeuristicResolver::resolveMemberExpr(
     return {};
 
   if (BaseType->isDependentType())
-    if (auto *MaybeResolved = resolveTypeFromInstantiatedTemplate(ME))
+    if (auto *MaybeResolved = resolveTypeFromInstantiatedTemplate(EnclosingDecl, ME))
       BaseType = MaybeResolved;
 
   if (const auto *BT = BaseType->getAs<BuiltinType>()) {
diff --git a/clang-tools-extra/clangd/HeuristicResolver.h b/clang-tools-extra/clangd/HeuristicResolver.h
index dc04123d37593c..16650e503a1d0b 100644
--- a/clang-tools-extra/clangd/HeuristicResolver.h
+++ b/clang-tools-extra/clangd/HeuristicResolver.h
@@ -45,7 +45,8 @@ namespace clangd {
 // not a specialization. More advanced heuristics may be added in the future.
 class HeuristicResolver {
 public:
-  HeuristicResolver(ASTContext &Ctx) : Ctx(Ctx) {}
+  HeuristicResolver(const ASTContext &Ctx, const DeclContext *EnclosingDecl = nullptr)
+      : Ctx(Ctx), EnclosingDecl(EnclosingDecl) {}
 
   // Try to heuristically resolve certain types of expressions, declarations, or
   // types to one or more likely-referenced declarations.
@@ -76,7 +77,8 @@ class HeuristicResolver {
   const Type *getPointeeType(const Type *T) const;
 
 private:
-  ASTContext &Ctx;
+  const ASTContext &Ctx;
+  const DeclContext *EnclosingDecl;
 
   // Given a tag-decl type and a member name, heuristically resolve the
   // name to one or more declarations.
diff --git a/clang-tools-extra/clangd/Hover.cpp b/clang-tools-extra/clangd/Hover.cpp
index a868d3bb4e3fa1..32b34d1703c311 100644
--- a/clang-tools-extra/clangd/Hover.cpp
+++ b/clang-tools-extra/clangd/Hover.cpp
@@ -1358,8 +1358,9 @@ std::optional<HoverInfo> getHover(ParsedAST &AST, Position Pos,
         SelectionTree::createRight(AST.getASTContext(), TB, Offset, Offset);
     if (const SelectionTree::Node *N = ST.commonAncestor()) {
       // FIXME: Fill in HighlightRange with range coming from N->ASTNode.
-      auto Decls = explicitReferenceTargets(N->ASTNode, DeclRelation::Alias,
-                                            AST.getHeuristicResolver());
+      auto Decls = explicitReferenceTargets(
+          N->ASTNode, DeclRelation::Alias,
+          AST.getHeuristicResolver(&N->getDeclContext()));
       if (const auto *DeclToUse = pickDeclToUse(Decls)) {
         HoverCountMetric.record(1, "decl");
         HI = getHoverContents(DeclToUse, PP, Index, TB);
diff --git a/clang-tools-extra/clangd/InlayHints.cpp b/clang-tools-extra/clangd/InlayHints.cpp
index b540c273cbd596..fc08d29cd3776c 100644
--- a/clang-tools-extra/clangd/InlayHints.cpp
+++ b/clang-tools-extra/clangd/InlayHints.cpp
@@ -627,7 +627,7 @@ class InlayHintVisitor : public RecursiveASTVisitor<InlayHintVisitor> {
         isa<UserDefinedLiteral>(E))
       return true;
 
-    auto CalleeDecls = Resolver->resolveCalleeOfCallExpr(E);
+    auto CalleeDecls = Resolver.resolveCalleeOfCallExpr(E);
     if (CalleeDecls.size() != 1)
       return true;
 
@@ -1274,7 +1274,7 @@ class InlayHintVisitor : public RecursiveASTVisitor<InlayHintVisitor> {
   std::optional<Range> RestrictRange;
   FileID MainFileID;
   StringRef MainFileBuf;
-  const HeuristicResolver *Resolver;
+  HeuristicResolver Resolver;
   PrintingPolicy TypeHintPolicy;
 };
 
diff --git a/clang-tools-extra/clangd/ParsedAST.cpp b/clang-tools-extra/clangd/ParsedAST.cpp
index edd0f77b1031ef..741e9e96f3437b 100644
--- a/clang-tools-extra/clangd/ParsedAST.cpp
+++ b/clang-tools-extra/clangd/ParsedAST.cpp
@@ -835,7 +835,6 @@ ParsedAST::ParsedAST(PathRef TUPath, llvm::StringRef Version,
       Marks(std::move(Marks)), Diags(std::move(Diags)),
       LocalTopLevelDecls(std::move(LocalTopLevelDecls)),
       Includes(std::move(Includes)) {
-  Resolver = std::make_unique<HeuristicResolver>(getASTContext());
   assert(this->Clang);
   assert(this->Action);
 }
diff --git a/clang-tools-extra/clangd/ParsedAST.h b/clang-tools-extra/clangd/ParsedAST.h
index c68fdba6bd26cd..1f1425a9950053 100644
--- a/clang-tools-extra/clangd/ParsedAST.h
+++ b/clang-tools-extra/clangd/ParsedAST.h
@@ -30,6 +30,7 @@
 #include "clang/Frontend/FrontendAction.h"
 #include "clang/Lex/Preprocessor.h"
 #include "clang/Tooling/Syntax/Tokens.h"
+#include "HeuristicResolver.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/StringRef.h"
 #include <memory>
@@ -118,8 +119,9 @@ class ParsedAST {
   /// AST. Might be std::nullopt if no Preamble is used.
   std::optional<llvm::StringRef> preambleVersion() const;
 
-  const HeuristicResolver *getHeuristicResolver() const {
-    return Resolver.get();
+  HeuristicResolver
+  getHeuristicResolver(const DeclContext *EnclosingDecl = nullptr) const {
+    return HeuristicResolver(getASTContext(), EnclosingDecl);
   }
 
 private:
@@ -159,7 +161,6 @@ class ParsedAST {
   // top-level decls from the preamble.
   std::vector<Decl *> LocalTopLevelDecls;
   IncludeStructure Includes;
-  std::unique_ptr<HeuristicResolver> Resolver;
 };
 
 } // namespace clangd
diff --git a/clang-tools-extra/clangd/SemanticHighlighting.cpp b/clang-tools-extra/clangd/SemanticHighlighting.cpp
index 49e479abf45621..d5cd8a3c089a89 100644
--- a/clang-tools-extra/clangd/SemanticHighlighting.cpp
+++ b/clang-tools-extra/clangd/SemanticHighlighting.cpp
@@ -95,9 +95,9 @@ bool isUniqueDefinition(const NamedDecl *Decl) {
 }
 
 std::optional<HighlightingKind> kindForType(const Type *TP,
-                                            const HeuristicResolver *Resolver);
+                                            const HeuristicResolver &Resolver);
 std::optional<HighlightingKind> kindForDecl(const NamedDecl *D,
-                                            const HeuristicResolver *Resolver) {
+                                            const HeuristicResolver &Resolver) {
   if (auto *USD = dyn_cast<UsingShadowDecl>(D)) {
     if (auto *Target = USD->getTargetDecl())
       D = Target;
@@ -169,7 +169,7 @@ std::optional<HighlightingKind> kindForDecl(const NamedDecl *D,
   if (isa<LabelDecl>(D))
     return HighlightingKind::Label;
   if (const auto *UUVD = dyn_cast<UnresolvedUsingValueDecl>(D)) {
-    auto Targets = Resolver->resolveUsingValueDecl(UUVD);
+    auto Targets = Resolver.resolveUsingValueDecl(UUVD);
     if (!Targets.empty() && Targets[0] != UUVD) {
       return kindForDecl(Targets[0], Resolver);
     }
@@ -178,7 +178,7 @@ std::optional<HighlightingKind> kindForDecl(const NamedDecl *D,
   return std::nullopt;
 }
 std::optional<HighlightingKind> kindForType(const Type *TP,
-                                            const HeuristicResolver *Resolver) {
+                                            const HeuristicResolver &Resolver) {
   if (!TP)
     return std::nullopt;
   if (TP->isBuiltinType()) // Builtins are special, they do not have decls.
@@ -416,9 +416,11 @@ class HighlightingFilter {
 /// Consumes source locations and maps them to text ranges for highlightings.
 class HighlightingsBuilder {
 public:
-  HighlightingsBuilder(const ParsedAST &AST, const HighlightingFilter &Filter)
+  HighlightingsBuilder(const ParsedAST &AST, const HighlightingFilter &Filter,
+                       const HeuristicResolver &Resolver)
       : TB(AST.getTokens()), SourceMgr(AST.getSourceManager()),
-        LangOpts(AST.getLangOpts()), Filter(Filter) {}
+        LangOpts(AST.getLangOpts()), Filter(Filter),
+        Resolver(Resolver) {}
 
   HighlightingToken &addToken(SourceLocation Loc, HighlightingKind Kind) {
     auto Range = getRangeForSourceLocation(Loc);
@@ -567,7 +569,7 @@ class HighlightingsBuilder {
     return WithInactiveLines;
   }
 
-  const HeuristicResolver *getResolver() const { return Resolver; }
+  const HeuristicResolver &getResolver() const { return Resolver; }
 
 private:
   std::optional<Range> getRangeForSourceLocation(SourceLocation Loc) {
@@ -589,7 +591,7 @@ class HighlightingsBuilder {
   HighlightingFilter Filter;
   std::vector<HighlightingToken> Tokens;
   std::map<Range, llvm::SmallVector<HighlightingModifier, 1>> ExtraModifiers;
-  const HeuristicResolver *Resolver = nullptr;
+  HeuristicResolver Resolver; // Held by value: callers may pass a temporary.
   // returned from addToken(InvalidLoc)
   HighlightingToken InvalidHighlightingToken;
 };
@@ -1150,7 +1152,7 @@ getSemanticHighlightings(ParsedAST &AST, bool IncludeInactiveRegionTokens) {
   if (!IncludeInactiveRegionTokens)
     Filter.disableKind(HighlightingKind::InactiveCode);
   // Add highlightings for AST nodes.
-  HighlightingsBuilder Builder(AST, Filter);
+  HighlightingsBuilder Builder(AST, Filter, AST.getHeuristicResolver());
   // Highlight 'decltype' and 'auto' as their underlying types.
   CollectExtraHighlightings(Builder).TraverseAST(C);
   // Highlight all decls and references coming from the AST.
diff --git a/clang-tools-extra/clangd/XRefs.cpp b/clang-tools-extra/clangd/XRefs.cpp
index f55d2a56239563..835cc9292fd922 100644
--- a/clang-tools-extra/clangd/XRefs.cpp
+++ b/clang-tools-extra/clangd/XRefs.cpp
@@ -188,7 +188,8 @@ getDeclAtPositionWithRelations(ParsedAST &AST, SourceLocation Pos,
       // This makes the `override` hack work.
       if (N->ASTNode.get<Attr>() && N->Parent)
         N = N->Parent;
-      llvm::copy_if(allTargetDecls(N->ASTNode, AST.getHeuristicResolver()),
+      llvm::copy_if(allTargetDecls(N->ASTNode, AST.getHeuristicResolver(
+                                                   &N->getDeclContext())),
                     std::back_inserter(Result),
                     [&](auto &Entry) { return !(Entry.second & ~Relations); });
     }
@@ -1243,7 +1244,8 @@ std::vector<DocumentHighlight> findDocumentHighlights(ParsedAST &AST,
       DeclRelationSet Relations =
           DeclRelation::TemplatePattern | DeclRelation::Alias;
       auto TargetDecls =
-          targetDecl(N->ASTNode, Relations, AST.getHeuristicResolver());
+          targetDecl(N->ASTNode, Relations,
+                     AST.getHeuristicResolver(&N->getDeclContext()));
       if (!TargetDecls.empty()) {
         // FIXME: we may get multiple DocumentHighlights with the same location
         // and different kinds, deduplicate them.
@@ -2014,8 +2016,8 @@ static QualType typeForNode(const SelectionTree::Node *N) {
 
 // Given a type targeted by the cursor, return one or more types that are more interesting
 // to target.
-static void unwrapFindType(
-    QualType T, const HeuristicResolver* H, llvm::SmallVector<QualType>& Out) {
+static void unwrapFindType(QualType T, const HeuristicResolver &H,
+                           llvm::SmallVector<QualType> &Out) {
   if (T.isNull())
     return;
 
@@ -2043,18 +2045,18 @@ static void unwrapFindType(
   }
 
   // For smart pointer types, add the underlying type
-  if (H)
-    if (const auto* PointeeType = H->getPointeeType(T.getNonReferenceType().getTypePtr())) {
-        unwrapFindType(QualType(PointeeType, 0), H, Out);
-        return Out.push_back(T);
-    }
+  if (const auto *PointeeType =
+          H.getPointeeType(T.getNonReferenceType().getTypePtr())) {
+    unwrapFindType(QualType(PointeeType, 0), H, Out);
+    return Out.push_back(T);
+  }
 
   return Out.push_back(T);
 }
 
 // Convenience overload, to allow calling this without the out-parameter
-static llvm::SmallVector<QualType> unwrapFindType(
-    QualType T, const HeuristicResolver* H) {
+static llvm::SmallVector<QualType> unwrapFindType(QualType T,
+                                                  const HeuristicResolver &H) {
   llvm::SmallVector<QualType> Result;
   unwrapFindType(T, H, Result);
   return Result;
@@ -2076,10 +2078,11 @@ std::vector<LocatedSymbol> findType(ParsedAST &AST, Position Pos,
     std::vector<LocatedSymbol> LocatedSymbols;
 
     // NOTE: unwrapFindType might return duplicates for something like
-    // unique_ptr<unique_ptr<T>>. Let's *not* remove them, because it gives you some
-    // information about the type you may have not known before
-    // (since unique_ptr<unique_ptr<T>> != unique_ptr<T>).
-    for (const QualType& Type : unwrapFindType(typeForNode(N), AST.getHeuristicResolver()))
+    // unique_ptr<unique_ptr<T>>. Let's *not* remove them, because it gives you
+    // some information about the type you may have not known before (since
+    // unique_ptr<unique_ptr<T>> != unique_ptr<T>).
+    for (const QualType &Type : unwrapFindType(
+             typeForNode(N), AST.getHeuristicResolver(&N->getDeclContext())))
       llvm::copy(locateSymbolForType(AST, Type, Index),
                  std::back_inserter(LocatedSymbols));
 
diff --git a/clang-tools-extra/clangd/refactor/Rename.cpp b/clang-tools-extra/clangd/refactor/Rename.cpp
index 11f9e4627af760..9f1e04ce9af84a 100644
--- a/clang-tools-extra/clangd/refactor/Rename.cpp
+++ b/clang-tools-extra/clangd/refactor/Rename.cpp
@@ -168,7 +168,7 @@ llvm::DenseSet<const NamedDecl *> locateDeclAt(ParsedAST &AST,
   for (const NamedDecl *D :
        targetDecl(SelectedNode->ASTNode,
                   DeclRelation::Alias | DeclRelation::TemplatePattern,
-                  AST.getHeuristicResolver())) {
+                  AST.getHeuristicResolver(&SelectedNode->getDeclContext()))) {
     D = pickInterestingTarget(D);
     Result.insert(canonicalRenameDecl(D));
   }
diff --git a/clang-tools-extra/clangd/refactor/tweaks/DefineInline.cpp b/clang-tools-extra/clangd/refactor/tweaks/DefineInline.cpp
index d6556bba14725c..7dc4053c401ab2 100644
--- a/clang-tools-extra/clangd/refactor/tweaks/DefineInline.cpp
+++ b/clang-tools-extra/clangd/refactor/tweaks/DefineInline.cpp
@@ -123,7 +123,7 @@ bool checkDeclsAreVisible(const llvm::DenseSet<const Decl *> &DeclRefs,
 // still valid in context of Target.
 llvm::Expected<std::string> qualifyAllDecls(const FunctionDecl *FD,
                                             const FunctionDecl *Target,
-                                            const HeuristicResolver *Resolver) {
+                                            const HeuristicResolver &Resolver) {
   // There are three types of spellings that needs to be qualified in a function
   // body:
   // - Types:       Foo                 -> ns::Foo
@@ -221,7 +221,7 @@ llvm::Expected<std::string> qualifyAllDecls(const FunctionDecl *FD,
 /// \p Dest to be the same as in \p Source.
 llvm::Expected<tooling::Replacements>
 renameParameters(const FunctionDecl *Dest, const FunctionDecl *Source,
-                 const HeuristicResolver *Resolver) {
+                 const HeuristicResolver &Resolver) {
   llvm::DenseMap<const Decl *, std::string> ParamToNewName;
   llvm::DenseMap<const NamedDecl *, std::vector<SourceLocation>> RefLocs;
   auto HandleParam = [&](const NamedDecl *DestParam,
diff --git a/clang-tools-extra/clangd/refactor/tweaks/DefineOutline.cpp b/clang-tools-extra/clangd/refactor/tweaks/DefineOutline.cpp
index b84ae04072f2c1..d44f0ef13c7729 100644
--- a/clang-tools-extra/clangd/refactor/tweaks/DefineOutline.cpp
+++ b/clang-tools-extra/clangd/refactor/tweaks/DefineOutline.cpp
@@ -181,7 +181,7 @@ deleteTokensWithKind(const syntax::TokenBuffer &TokBuf, tok::TokenKind Kind,
 llvm::Expected<std::string>
 getFunctionSourceCode(const FunctionDecl *FD, llvm::StringRef TargetNamespace,
                       const syntax::TokenBuffer &TokBuf,
-                      const HeuristicResolver *Resolver) {
+                      const HeuristicResolver &Resolver) {
   auto &AST = FD->getASTContext();
   auto &SM = AST.getSourceManager();
   auto TargetContext = findContextForNS(TargetNamespace, FD->getDeclContext());
diff --git a/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp b/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
index 0302839c58252e..28c19e097ec988 100644
--- a/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
+++ b/clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
@@ -176,7 +176,7 @@ struct ExtractionZone {
   // This performs a partial AST traversal proportional to the size of the
   // enclosing function, so it is possibly expensive.
   bool requiresHoisting(const SourceManager &SM,
-                        const HeuristicResolver *Resolver) const {
+                        const HeuristicResolver &Resolver) const {
     // First find all the declarations that happened inside extraction zone.
     llvm::SmallSet<const Decl *, 1> DeclsInExtZone;
     for (auto *RootStmt : RootStmts) {
diff --git a/clang-tools-extra/clangd/unittests/ASTTests.cpp b/clang-tools-extra/clangd/unittests/ASTTests.cpp
index 3101bf34acd717..229c7aa6681c4e 100644
--- a/clang-tools-extra/clangd/unittests/ASTTests.cpp
+++ b/clang-tools-extra/clangd/unittests/ASTTests.cpp
@@ -16,6 +16,7 @@
 #include "clang/AST/Attr.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/DeclBase.h"
+#include "clang/AST/DeclCXX.h"
 #include "clang/Basic/AttrKinds.h"
 #include "clang/Basic/SourceManager.h"
 #include "llvm/ADT/StringRef.h"
@@ -320,6 +321,90 @@ TEST(ClangdAST, GetOnlyInstantiation) {
   }
 }
 
+TEST(ClangdAST, DISABLED_GetOnlyInstantiationForMemberFunction) {
+  struct {
+    const char *Code;
+    const char *MemberTemplate;
+    const char *ExpectedInstantiation;
+  } Cases[] = {
+      {
+          R"cpp(
+      template <class T> struct Foo {
+        template <class U> struct Bar {
+          U Baz(T);
+        };
+      };
+      double X = Foo<int>::Bar<double>().Baz(3);
+      )cpp",
+          "Baz",
+          "double Baz(int)",
+      },
+      {
+          R"cpp(
+      template <class T> struct Foo {
+        struct Bar {
+          template <class U> U Baz(T);
+        };
+      };
+      double X = Foo<int>::Bar().Baz<double>(3);
+      )cpp",
+          "Baz",
+          "template<> double Baz<double>(int)",
+      },
+      {
+          R"cpp(
+      struct Foo {
+        struct Bar {
+          template <class T> T Baz(T);
+        };
+      };
+      int X = Foo::Bar().Baz(3);
+      )cpp",
+          "Baz",
+          "template<> int Baz<int>(int)",
+      },
+      {
+          R"cpp(
+      struct Foo {
+        template <class T> struct Bar {
+          template <class U> static U Baz(T);
+        };
+      };
+      int X = Foo::Bar<int>::Baz<double>(3);
+      )cpp",
+          "Baz",
+          "template<> static double Baz<double>(int)",
+      },
+
+  };
+  for (const auto &Case : Cases) {
+    SCOPED_TRACE(Case.Code);
+    auto TU = TestTU::withCode(Case.Code);
+    TU.ExtraArgs.push_back("-std=c++20");
+    auto AST = TU.build();
+    PrintingPolicy PP = AST.getASTContext().getPrintingPolicy();
+    PP.TerseOutput = true;
+    std::string Name;
+    if (auto *Result = getOnlyInstantiationForMemberFunction(
+            dyn_cast_if_present<CXXMethodDecl>(
+                &findDecl(AST, [&](const NamedDecl &D) {
+                  auto *MD = dyn_cast<CXXMethodDecl>(&D);
+                  IdentifierInfo *Id = D.getIdentifier();
+                  if (!MD || !Id)
+                    return false;
+                  return MD->isDependentContext() &&
+                         Id->getName() == Case.MemberTemplate;
+                })))) {
+      llvm::raw_string_ostream OS(Name);
+      Result->print(OS, PP);
+    }
+    if (Case.ExpectedInstantiation)
+      EXPECT_EQ(Case.ExpectedInstantiation, Name);
+    else
+      EXPECT_THAT(Name, IsEmpty());
+  }
+}
+
 TEST(ClangdAST, GetContainedAutoParamType) {
   auto TU = TestTU::withCode(R"cpp(
     int withAuto(
diff --git a/clang-tools-extra/clangd/unittests/FindTargetTests.cpp b/clang-tools-extra/clangd/unittests/FindTargetTests.cpp
index fbd10c4a47a793..e4bb157b7b3d0a 100644
--- a/clang-tools-extra/clangd/unittests/FindTargetTests.cpp
+++ b/clang-tools-extra/clangd/unittests/FindTargetTests.cpp
@@ -86,8 +86,8 @@ class TargetDeclTest : public ::testing::Test {
     EXPECT_EQ(N->kind(), NodeType) << Selection;
 
     std::vector<PrintedDecl> ActualDecls;
-    for (const auto &Entry :
-         allTargetDecls(N->ASTNode, AST.getHeuristicResolver()))
+    for (const auto &Entry : allTargetDecls(
+             N->ASTNode, AST.getHeuristicResolver(&N->getDeclContext())))
       ActualDecls.emplace_back(Entry.first, Entry.second);
     return ActualDecls;
   }
diff --git a/clang-tools-extra/clangd/unittests/XRefsTests.cpp b/clang-tools-extra/clangd/unittests/XRefsTests.cpp
index ead24dec575de0..1c800ea7916137 100644
--- a/clang-tools-extra/clangd/unittests/XRefsTests.cpp
+++ b/clang-tools-extra/clangd/unittests/XRefsTests.cpp
@@ -1222,52 +1222,40 @@ TEST(LocateSymbol, TextualSmoke) {
                         hasID(getSymbolID(&findDecl(AST, "MyClass"))))));
 }
 
-TEST(LocateSymbol, DeduceDependentTypeFromSingleInstantiation) {
+TEST(LocateSymbol, DISABLED_DeduceDependentTypeFromSingleInstantiation) {
   Annotations T(R"cpp(
-    struct WildCat {
-      void $wild_meow[[meow]]();
+    struct Widget {
+      int $range_1[[method]](int);
     };
-
-    struct DomesticCat {
-      void $domestic_meow[[meow]]();
-    };
-
-    template <typename Ours>
-    struct Human {
-      template <typename Others>
-      void feed(Others O) {
-        O.me$1^ow();
-        Others Child;
-        Child.me$2^ow();
-        // FIXME: Others().me^ow();
-        Ours Baby;
-        Baby.me$3^ow();
-        // struct Inner {
-        //   Ours Pet;
-        // };
-        // Inner().Pet.me^ow();
-        auto Lambda = [](auto C) {
-          C.me$4^ow();
-        };
-        Lambda(Others());
-      }
+    template <class T> struct A {
+      template <class U> struct B {
+        template <class V> T foo(U, V arg) {
+          V copy;
+          int not_used = copy.$point_1^method(T{});
+          not_used = V().$point_2^method(T{});
+          auto lambda = [](auto w) {
+            return w.$point_3^method(T{});
+          };
+          lambda(copy);
+          return arg.$point_4^method(T{});
+        }
+      };
     };
-
-    void foo() {
-      Human<DomesticCat>().feed(WildCat());
+    int main() {
+      int X = A<int>::B<double>().foo(3.14, Widget{});
     }
   )cpp");
 
   auto TU = TestTU::withCode(T.code());
   auto AST = TU.build();
-  EXPECT_THAT(locateSymbolAt(AST, T.point("1")),
-              ElementsAre(sym("meow", T.range("wild_meow"), std::nullopt)));
-  EXPECT_THAT(locateSymbolAt(AST, T.point("2")),
-              ElementsAre(sym("meow", T.range("wild_meow"), std::nullopt)));
-  EXPECT_THAT(locateSymbolAt(AST, T.point("3")),
-              ElementsAre(sym("meow", T.range("domestic_meow"), std::nullopt)));
-  EXPECT_THAT(locateSymbolAt(AST, T.point("4")),
-              ElementsAre(sym("meow", T.range("wild_meow"), std::nullopt)));
+  EXPECT_THAT(locateSymbolAt(AST, T.point("point_1")),
+              ElementsAre(sym("method", T.range("range_1"), std::nullopt)));
+  EXPECT_THAT(locateSymbolAt(AST, T.point("point_2")),
+              ElementsAre(sym("method", T.range("range_1"), std::nullopt)));
+  EXPECT_THAT(locateSymbolAt(AST, T.point("point_3")),
+              ElementsAre(sym("method", T.range("range_1"), std::nullopt)));
+  EXPECT_THAT(locateSymbolAt(AST, T.point("point_4")),
+              ElementsAre(sym("method", T.range("range_1"), std::nullopt)));
 }
 
 TEST(LocateSymbol, Textual) {



More information about the cfe-commits mailing list