[clang-tools-extra] [clangd] [HeuristicResolver] Protect against infinite recursion on DependentNameTypes (PR #83542)

Nathan Ridge via cfe-commits cfe-commits at lists.llvm.org
Sun Mar 3 20:56:12 PST 2024


https://github.com/HighCommander4 updated https://github.com/llvm/llvm-project/pull/83542

>From 1495cd55329d9d9d6ef7f61a9dc3f3cc48a72dac Mon Sep 17 00:00:00 2001
From: Nathan Ridge <zeratul976 at hotmail.com>
Date: Fri, 1 Mar 2024 03:27:51 -0500
Subject: [PATCH] [clangd] [HeuristicResolver] Protect against infinite
 recursion on DependentNameTypes

When resolving names inside templates that implement recursive
compile-time functions (e.g. waldo<N>::type is defined in terms
of waldo<N-1>::type), HeuristicResolver could get into an infinite
recursion, specifically one where resolveDependentNameType() can
be called recursively with the same DependentNameType*.

To guard against this, HeuristicResolver tracks, for each external
call into a HeuristicResolver function, the set of DependentNameTypes
that it has seen, and bails if it sees the same DependentNameType again.

To implement this, a helper class HeuristicResolverImpl is introduced
to store state that persists for the duration of an external call into
HeuristicResolver (but does not persist between such calls).

Fixes https://github.com/clangd/clangd/issues/1951
---
 .../clangd/HeuristicResolver.cpp              | 170 ++++++++++++++----
 clang-tools-extra/clangd/HeuristicResolver.h  |  37 ----
 .../clangd/unittests/FindTargetTests.cpp      |  27 +++
 3 files changed, 165 insertions(+), 69 deletions(-)

diff --git a/clang-tools-extra/clangd/HeuristicResolver.cpp b/clang-tools-extra/clangd/HeuristicResolver.cpp
index 3c147b6b582bf0..26d54200eeffd2 100644
--- a/clang-tools-extra/clangd/HeuristicResolver.cpp
+++ b/clang-tools-extra/clangd/HeuristicResolver.cpp
@@ -16,6 +16,80 @@
 namespace clang {
 namespace clangd {
 
+namespace {
+
+// Helper class for implementing HeuristicResolver.
+// Unlike HeuristicResolver which is a long-lived class,
+// a new instance of this class is created for every external
+// call into a HeuristicResolver operation. That allows this
+// class to store state that's local to such a top-level call,
+// particularly "recursion protection sets" that keep track of
+// nodes that have already been seen to avoid infinite recursion.
+class HeuristicResolverImpl {
+public:
+  HeuristicResolverImpl(ASTContext &Ctx) : Ctx(Ctx) {}
+
+  // These functions match the public interface of HeuristicResolver
+  // (but aren't const since they may modify the recursion protection sets).
+  std::vector<const NamedDecl *>
+  resolveMemberExpr(const CXXDependentScopeMemberExpr *ME);
+  std::vector<const NamedDecl *>
+  resolveDeclRefExpr(const DependentScopeDeclRefExpr *RE);
+  std::vector<const NamedDecl *> resolveTypeOfCallExpr(const CallExpr *CE);
+  std::vector<const NamedDecl *> resolveCalleeOfCallExpr(const CallExpr *CE);
+  std::vector<const NamedDecl *>
+  resolveUsingValueDecl(const UnresolvedUsingValueDecl *UUVD);
+  std::vector<const NamedDecl *>
+  resolveDependentNameType(const DependentNameType *DNT);
+  std::vector<const NamedDecl *> resolveTemplateSpecializationType(
+      const DependentTemplateSpecializationType *DTST);
+  const Type *resolveNestedNameSpecifierToType(const NestedNameSpecifier *NNS);
+  const Type *getPointeeType(const Type *T);
+
+private:
+  ASTContext &Ctx;
+
+  // Recursion protection sets
+  llvm::SmallSet<const DependentNameType *, 4> SeenDependentNameTypes;
+
+  // Given a tag-decl type and a member name, heuristically resolve the
+  // name to one or more declarations.
+  // The current heuristic is simply to look up the name in the primary
+  // template. This is a heuristic because the template could potentially
+  // have specializations that declare different members.
+  // Multiple declarations could be returned if the name is overloaded
+  // (e.g. an overloaded method in the primary template).
+  // This heuristic will give the desired answer in many cases, e.g.
+  // for a call to vector<T>::size().
+  std::vector<const NamedDecl *>
+  resolveDependentMember(const Type *T, DeclarationName Name,
+                         llvm::function_ref<bool(const NamedDecl *ND)> Filter);
+
+  // Try to heuristically resolve the type of a possibly-dependent expression
+  // `E`.
+  const Type *resolveExprToType(const Expr *E);
+  std::vector<const NamedDecl *> resolveExprToDecls(const Expr *E);
+
+  // Helper function for HeuristicResolverImpl::resolveDependentMember()
+  // which takes a possibly-dependent type `T` and heuristically
+  // resolves it to a CXXRecordDecl in which we can try name lookup.
+  CXXRecordDecl *resolveTypeToRecordDecl(const Type *T);
+
+  // This is a reimplementation of CXXRecordDecl::lookupDependentName()
+  // so that the implementation can call into other HeuristicResolver helpers.
+  // FIXME: Once HeuristicResolver is upstreamed to the clang libraries
+  // (https://github.com/clangd/clangd/discussions/1662),
+  // CXXRecordDecl::lookupDependentName() can be removed, and its call sites
+  // can be modified to benefit from the more comprehensive heuristics offered
+  // by HeuristicResolver instead.
+  std::vector<const NamedDecl *>
+  lookupDependentName(CXXRecordDecl *RD, DeclarationName Name,
+                      llvm::function_ref<bool(const NamedDecl *ND)> Filter);
+  bool findOrdinaryMemberInDependentClasses(const CXXBaseSpecifier *Specifier,
+                                            CXXBasePath &Path,
+                                            DeclarationName Name);
+};
+
 // Convenience lambdas for use as the 'Filter' parameter of
 // HeuristicResolver::resolveDependentMember().
 const auto NoFilter = [](const NamedDecl *D) { return true; };
@@ -31,8 +105,6 @@ const auto TemplateFilter = [](const NamedDecl *D) {
   return isa<TemplateDecl>(D);
 };
 
-namespace {
-
 const Type *resolveDeclsToType(const std::vector<const NamedDecl *> &Decls,
                                ASTContext &Ctx) {
   if (Decls.size() != 1) // Names an overload set -- just bail.
@@ -46,12 +118,10 @@ const Type *resolveDeclsToType(const std::vector<const NamedDecl *> &Decls,
   return nullptr;
 }
 
-} // namespace
-
 // Helper function for HeuristicResolver::resolveDependentMember()
 // which takes a possibly-dependent type `T` and heuristically
 // resolves it to a CXXRecordDecl in which we can try name lookup.
-CXXRecordDecl *HeuristicResolver::resolveTypeToRecordDecl(const Type *T) const {
+CXXRecordDecl *HeuristicResolverImpl::resolveTypeToRecordDecl(const Type *T) {
   assert(T);
 
   // Unwrap type sugar such as type aliases.
@@ -84,7 +154,7 @@ CXXRecordDecl *HeuristicResolver::resolveTypeToRecordDecl(const Type *T) const {
   return TD->getTemplatedDecl();
 }
 
-const Type *HeuristicResolver::getPointeeType(const Type *T) const {
+const Type *HeuristicResolverImpl::getPointeeType(const Type *T) {
   if (!T)
     return nullptr;
 
@@ -117,8 +187,8 @@ const Type *HeuristicResolver::getPointeeType(const Type *T) const {
   return FirstArg.getAsType().getTypePtrOrNull();
 }
 
-std::vector<const NamedDecl *> HeuristicResolver::resolveMemberExpr(
-    const CXXDependentScopeMemberExpr *ME) const {
+std::vector<const NamedDecl *> HeuristicResolverImpl::resolveMemberExpr(
+    const CXXDependentScopeMemberExpr *ME) {
   // If the expression has a qualifier, try resolving the member inside the
   // qualifier's type.
   // Note that we cannot use a NonStaticFilter in either case, for a couple
@@ -164,14 +234,14 @@ std::vector<const NamedDecl *> HeuristicResolver::resolveMemberExpr(
   return resolveDependentMember(BaseType, ME->getMember(), NoFilter);
 }
 
-std::vector<const NamedDecl *> HeuristicResolver::resolveDeclRefExpr(
-    const DependentScopeDeclRefExpr *RE) const {
+std::vector<const NamedDecl *>
+HeuristicResolverImpl::resolveDeclRefExpr(const DependentScopeDeclRefExpr *RE) {
   return resolveDependentMember(RE->getQualifier()->getAsType(),
                                 RE->getDeclName(), StaticFilter);
 }
 
 std::vector<const NamedDecl *>
-HeuristicResolver::resolveTypeOfCallExpr(const CallExpr *CE) const {
+HeuristicResolverImpl::resolveTypeOfCallExpr(const CallExpr *CE) {
   const auto *CalleeType = resolveExprToType(CE->getCallee());
   if (!CalleeType)
     return {};
@@ -187,7 +257,7 @@ HeuristicResolver::resolveTypeOfCallExpr(const CallExpr *CE) const {
 }
 
 std::vector<const NamedDecl *>
-HeuristicResolver::resolveCalleeOfCallExpr(const CallExpr *CE) const {
+HeuristicResolverImpl::resolveCalleeOfCallExpr(const CallExpr *CE) {
   if (const auto *ND = dyn_cast_or_null<NamedDecl>(CE->getCalleeDecl())) {
     return {ND};
   }
@@ -195,29 +265,31 @@ HeuristicResolver::resolveCalleeOfCallExpr(const CallExpr *CE) const {
   return resolveExprToDecls(CE->getCallee());
 }
 
-std::vector<const NamedDecl *> HeuristicResolver::resolveUsingValueDecl(
-    const UnresolvedUsingValueDecl *UUVD) const {
+std::vector<const NamedDecl *> HeuristicResolverImpl::resolveUsingValueDecl(
+    const UnresolvedUsingValueDecl *UUVD) {
   return resolveDependentMember(UUVD->getQualifier()->getAsType(),
                                 UUVD->getNameInfo().getName(), ValueFilter);
 }
 
-std::vector<const NamedDecl *> HeuristicResolver::resolveDependentNameType(
-    const DependentNameType *DNT) const {
+std::vector<const NamedDecl *>
+HeuristicResolverImpl::resolveDependentNameType(const DependentNameType *DNT) {
+  if (auto [_, inserted] = SeenDependentNameTypes.insert(DNT); !inserted)
+    return {};
   return resolveDependentMember(
       resolveNestedNameSpecifierToType(DNT->getQualifier()),
       DNT->getIdentifier(), TypeFilter);
 }
 
 std::vector<const NamedDecl *>
-HeuristicResolver::resolveTemplateSpecializationType(
-    const DependentTemplateSpecializationType *DTST) const {
+HeuristicResolverImpl::resolveTemplateSpecializationType(
+    const DependentTemplateSpecializationType *DTST) {
   return resolveDependentMember(
       resolveNestedNameSpecifierToType(DTST->getQualifier()),
       DTST->getIdentifier(), TemplateFilter);
 }
 
 std::vector<const NamedDecl *>
-HeuristicResolver::resolveExprToDecls(const Expr *E) const {
+HeuristicResolverImpl::resolveExprToDecls(const Expr *E) {
   if (const auto *ME = dyn_cast<CXXDependentScopeMemberExpr>(E)) {
     return resolveMemberExpr(ME);
   }
@@ -236,7 +308,7 @@ HeuristicResolver::resolveExprToDecls(const Expr *E) const {
   return {};
 }
 
-const Type *HeuristicResolver::resolveExprToType(const Expr *E) const {
+const Type *HeuristicResolverImpl::resolveExprToType(const Expr *E) {
   std::vector<const NamedDecl *> Decls = resolveExprToDecls(E);
   if (!Decls.empty())
     return resolveDeclsToType(Decls, Ctx);
@@ -244,8 +316,8 @@ const Type *HeuristicResolver::resolveExprToType(const Expr *E) const {
   return E->getType().getTypePtr();
 }
 
-const Type *HeuristicResolver::resolveNestedNameSpecifierToType(
-    const NestedNameSpecifier *NNS) const {
+const Type *HeuristicResolverImpl::resolveNestedNameSpecifierToType(
+    const NestedNameSpecifier *NNS) {
   if (!NNS)
     return nullptr;
 
@@ -270,8 +342,6 @@ const Type *HeuristicResolver::resolveNestedNameSpecifierToType(
   return nullptr;
 }
 
-namespace {
-
 bool isOrdinaryMember(const NamedDecl *ND) {
   return ND->isInIdentifierNamespace(Decl::IDNS_Ordinary | Decl::IDNS_Tag |
                                      Decl::IDNS_Member);
@@ -287,11 +357,9 @@ bool findOrdinaryMember(const CXXRecordDecl *RD, CXXBasePath &Path,
   return false;
 }
 
-} // namespace
-
-bool HeuristicResolver::findOrdinaryMemberInDependentClasses(
+bool HeuristicResolverImpl::findOrdinaryMemberInDependentClasses(
     const CXXBaseSpecifier *Specifier, CXXBasePath &Path,
-    DeclarationName Name) const {
+    DeclarationName Name) {
   CXXRecordDecl *RD =
       resolveTypeToRecordDecl(Specifier->getType().getTypePtr());
   if (!RD)
@@ -299,9 +367,9 @@ bool HeuristicResolver::findOrdinaryMemberInDependentClasses(
   return findOrdinaryMember(RD, Path, Name);
 }
 
-std::vector<const NamedDecl *> HeuristicResolver::lookupDependentName(
+std::vector<const NamedDecl *> HeuristicResolverImpl::lookupDependentName(
     CXXRecordDecl *RD, DeclarationName Name,
-    llvm::function_ref<bool(const NamedDecl *ND)> Filter) const {
+    llvm::function_ref<bool(const NamedDecl *ND)> Filter) {
   std::vector<const NamedDecl *> Results;
 
   // Lookup in the class.
@@ -332,9 +400,9 @@ std::vector<const NamedDecl *> HeuristicResolver::lookupDependentName(
   return Results;
 }
 
-std::vector<const NamedDecl *> HeuristicResolver::resolveDependentMember(
+std::vector<const NamedDecl *> HeuristicResolverImpl::resolveDependentMember(
     const Type *T, DeclarationName Name,
-    llvm::function_ref<bool(const NamedDecl *ND)> Filter) const {
+    llvm::function_ref<bool(const NamedDecl *ND)> Filter) {
   if (!T)
     return {};
   if (auto *ET = T->getAs<EnumType>()) {
@@ -349,6 +417,44 @@ std::vector<const NamedDecl *> HeuristicResolver::resolveDependentMember(
   }
   return {};
 }
+} // namespace
+
+std::vector<const NamedDecl *> HeuristicResolver::resolveMemberExpr(
+    const CXXDependentScopeMemberExpr *ME) const {
+  return HeuristicResolverImpl(Ctx).resolveMemberExpr(ME);
+}
+std::vector<const NamedDecl *> HeuristicResolver::resolveDeclRefExpr(
+    const DependentScopeDeclRefExpr *RE) const {
+  return HeuristicResolverImpl(Ctx).resolveDeclRefExpr(RE);
+}
+std::vector<const NamedDecl *>
+HeuristicResolver::resolveTypeOfCallExpr(const CallExpr *CE) const {
+  return HeuristicResolverImpl(Ctx).resolveTypeOfCallExpr(CE);
+}
+std::vector<const NamedDecl *>
+HeuristicResolver::resolveCalleeOfCallExpr(const CallExpr *CE) const {
+  return HeuristicResolverImpl(Ctx).resolveCalleeOfCallExpr(CE);
+}
+std::vector<const NamedDecl *> HeuristicResolver::resolveUsingValueDecl(
+    const UnresolvedUsingValueDecl *UUVD) const {
+  return HeuristicResolverImpl(Ctx).resolveUsingValueDecl(UUVD);
+}
+std::vector<const NamedDecl *> HeuristicResolver::resolveDependentNameType(
+    const DependentNameType *DNT) const {
+  return HeuristicResolverImpl(Ctx).resolveDependentNameType(DNT);
+}
+std::vector<const NamedDecl *>
+HeuristicResolver::resolveTemplateSpecializationType(
+    const DependentTemplateSpecializationType *DTST) const {
+  return HeuristicResolverImpl(Ctx).resolveTemplateSpecializationType(DTST);
+}
+const Type *HeuristicResolver::resolveNestedNameSpecifierToType(
+    const NestedNameSpecifier *NNS) const {
+  return HeuristicResolverImpl(Ctx).resolveNestedNameSpecifierToType(NNS);
+}
+const Type *HeuristicResolver::getPointeeType(const Type *T) const {
+  return HeuristicResolverImpl(Ctx).getPointeeType(T);
+}
 
 } // namespace clangd
 } // namespace clang
diff --git a/clang-tools-extra/clangd/HeuristicResolver.h b/clang-tools-extra/clangd/HeuristicResolver.h
index dc04123d37593c..dcc063bbc4adc0 100644
--- a/clang-tools-extra/clangd/HeuristicResolver.h
+++ b/clang-tools-extra/clangd/HeuristicResolver.h
@@ -77,43 +77,6 @@ class HeuristicResolver {
 
 private:
   ASTContext &Ctx;
-
-  // Given a tag-decl type and a member name, heuristically resolve the
-  // name to one or more declarations.
-  // The current heuristic is simply to look up the name in the primary
-  // template. This is a heuristic because the template could potentially
-  // have specializations that declare different members.
-  // Multiple declarations could be returned if the name is overloaded
-  // (e.g. an overloaded method in the primary template).
-  // This heuristic will give the desired answer in many cases, e.g.
-  // for a call to vector<T>::size().
-  std::vector<const NamedDecl *> resolveDependentMember(
-      const Type *T, DeclarationName Name,
-      llvm::function_ref<bool(const NamedDecl *ND)> Filter) const;
-
-  // Try to heuristically resolve the type of a possibly-dependent expression
-  // `E`.
-  const Type *resolveExprToType(const Expr *E) const;
-  std::vector<const NamedDecl *> resolveExprToDecls(const Expr *E) const;
-
-  // Helper function for HeuristicResolver::resolveDependentMember()
-  // which takes a possibly-dependent type `T` and heuristically
-  // resolves it to a CXXRecordDecl in which we can try name lookup.
-  CXXRecordDecl *resolveTypeToRecordDecl(const Type *T) const;
-
-  // This is a reimplementation of CXXRecordDecl::lookupDependentName()
-  // so that the implementation can call into other HeuristicResolver helpers.
-  // FIXME: Once HeuristicResolver is upstreamed to the clang libraries
-  // (https://github.com/clangd/clangd/discussions/1662),
-  // CXXRecordDecl::lookupDepenedentName() can be removed, and its call sites
-  // can be modified to benefit from the more comprehensive heuristics offered
-  // by HeuristicResolver instead.
-  std::vector<const NamedDecl *> lookupDependentName(
-      CXXRecordDecl *RD, DeclarationName Name,
-      llvm::function_ref<bool(const NamedDecl *ND)> Filter) const;
-  bool findOrdinaryMemberInDependentClasses(const CXXBaseSpecifier *Specifier,
-                                            CXXBasePath &Path,
-                                            DeclarationName Name) const;
 };
 
 } // namespace clangd
diff --git a/clang-tools-extra/clangd/unittests/FindTargetTests.cpp b/clang-tools-extra/clangd/unittests/FindTargetTests.cpp
index 29cff68cf03b2e..0af6036734ba53 100644
--- a/clang-tools-extra/clangd/unittests/FindTargetTests.cpp
+++ b/clang-tools-extra/clangd/unittests/FindTargetTests.cpp
@@ -1009,6 +1009,33 @@ TEST_F(TargetDeclTest, DependentTypes) {
       )cpp";
   EXPECT_DECLS("DependentTemplateSpecializationTypeLoc",
                "template <typename> struct B");
+
+  // Dependent name with recursive definition. We don't expect a
+  // result, but we shouldn't get into a stack overflow either.
+  Code = R"cpp(
+        template <int N>
+        struct waldo {
+          typedef typename waldo<N - 1>::type::[[next]] type;
+        };
+  )cpp";
+  EXPECT_DECLS("DependentNameTypeLoc");
+
+  // Similar to above but using mutually recursive templates.
+  Code = R"cpp(
+        template <int N>
+        struct odd;
+
+        template <int N>
+        struct even {
+          using type = typename odd<N - 1>::type::next;
+        };
+
+        template <int N>
+        struct odd {
+          using type = typename even<N - 1>::type::[[next]];
+        };
+  )cpp";
+  EXPECT_DECLS("DependentNameTypeLoc");
 }
 
 TEST_F(TargetDeclTest, TypedefCascade) {



More information about the cfe-commits mailing list