[clang] [WIP] Improve HeuristicResolver further so it can replace most of getApproximateType() in SemaCodeComplete (PR #151643)

Nathan Ridge via cfe-commits cfe-commits at lists.llvm.org
Thu Jul 31 22:59:53 PDT 2025


https://github.com/HighCommander4 created https://github.com/llvm/llvm-project/pull/151643

None

>From 4e0c823ce353e8dbfc441e4fb48dcddb37677b53 Mon Sep 17 00:00:00 2001
From: Nathan Ridge <zeratul976 at hotmail.com>
Date: Mon, 7 Jul 2025 09:01:12 -0400
Subject: [PATCH] [WIP] Improve HeuristicResolver further so it can replace
 most of getApproximateType() in SemaCodeComplete

---
 clang/include/clang/Sema/HeuristicResolver.h  |  6 +-
 clang/lib/Sema/HeuristicResolver.cpp          | 93 +++++++++++++-----
 clang/lib/Sema/SemaCodeComplete.cpp           | 97 ++----------------
 clang/test/CodeCompletion/member-access.cpp   |  2 +-
 .../unittests/Sema/HeuristicResolverTest.cpp  | 98 ++++++++++++++++++-
 5 files changed, 175 insertions(+), 121 deletions(-)

diff --git a/clang/include/clang/Sema/HeuristicResolver.h b/clang/include/clang/Sema/HeuristicResolver.h
index e193c0bc14cd9..ef26f7b43ce76 100644
--- a/clang/include/clang/Sema/HeuristicResolver.h
+++ b/clang/include/clang/Sema/HeuristicResolver.h
@@ -54,8 +54,6 @@ class HeuristicResolver {
   std::vector<const NamedDecl *>
   resolveDeclRefExpr(const DependentScopeDeclRefExpr *RE) const;
   std::vector<const NamedDecl *>
-  resolveTypeOfCallExpr(const CallExpr *CE) const;
-  std::vector<const NamedDecl *>
   resolveCalleeOfCallExpr(const CallExpr *CE) const;
   std::vector<const NamedDecl *>
   resolveUsingValueDecl(const UnresolvedUsingValueDecl *UUVD) const;
@@ -94,6 +92,10 @@ class HeuristicResolver {
   // during simplification, and the operation fails if no pointer type is found.
   QualType simplifyType(QualType Type, const Expr *E, bool UnwrapPointer);
 
+  // Try to heuristically resolve the type of a possibly-dependent expression
+  // `E`.
+  QualType resolveExprToType(const Expr *E) const;
+
   // Given an expression `Fn` representing the callee in a function call,
   // if the call is through a function pointer, try to find the declaration of
   // the corresponding function pointer type, so that we can recover argument
diff --git a/clang/lib/Sema/HeuristicResolver.cpp b/clang/lib/Sema/HeuristicResolver.cpp
index 6874d30516f8f..bf9374b7bcf1d 100644
--- a/clang/lib/Sema/HeuristicResolver.cpp
+++ b/clang/lib/Sema/HeuristicResolver.cpp
@@ -36,7 +36,6 @@ class HeuristicResolverImpl {
   resolveMemberExpr(const CXXDependentScopeMemberExpr *ME);
   std::vector<const NamedDecl *>
   resolveDeclRefExpr(const DependentScopeDeclRefExpr *RE);
-  std::vector<const NamedDecl *> resolveTypeOfCallExpr(const CallExpr *CE);
   std::vector<const NamedDecl *> resolveCalleeOfCallExpr(const CallExpr *CE);
   std::vector<const NamedDecl *>
   resolveUsingValueDecl(const UnresolvedUsingValueDecl *UUVD);
@@ -51,6 +50,7 @@ class HeuristicResolverImpl {
                       llvm::function_ref<bool(const NamedDecl *ND)> Filter);
   TagDecl *resolveTypeToTagDecl(QualType T);
   QualType simplifyType(QualType Type, const Expr *E, bool UnwrapPointer);
+  QualType resolveExprToType(const Expr *E);
   FunctionProtoTypeLoc getFunctionProtoTypeLoc(const Expr *Fn);
 
 private:
@@ -72,10 +72,8 @@ class HeuristicResolverImpl {
   resolveDependentMember(QualType T, DeclarationName Name,
                          llvm::function_ref<bool(const NamedDecl *ND)> Filter);
 
-  // Try to heuristically resolve the type of a possibly-dependent expression
-  // `E`.
-  QualType resolveExprToType(const Expr *E);
   std::vector<const NamedDecl *> resolveExprToDecls(const Expr *E);
+  QualType resolveTypeOfCallExpr(const CallExpr *CE);
 
   bool findOrdinaryMemberInDependentClasses(const CXXBaseSpecifier *Specifier,
                                             CXXBasePath &Path,
@@ -97,19 +95,26 @@ const auto TemplateFilter = [](const NamedDecl *D) {
   return isa<TemplateDecl>(D);
 };
 
-QualType resolveDeclsToType(const std::vector<const NamedDecl *> &Decls,
-                            ASTContext &Ctx) {
-  if (Decls.size() != 1) // Names an overload set -- just bail.
-    return QualType();
-  if (const auto *TD = dyn_cast<TypeDecl>(Decls[0])) {
+QualType resolveDeclToType(const NamedDecl *D, ASTContext &Ctx) {
+  if (const auto *TempD = dyn_cast<TemplateDecl>(D)) {
+    D = TempD->getTemplatedDecl();
+  }
+  if (const auto *TD = dyn_cast<TypeDecl>(D)) {
     return Ctx.getTypeDeclType(TD);
   }
-  if (const auto *VD = dyn_cast<ValueDecl>(Decls[0])) {
+  if (const auto *VD = dyn_cast<ValueDecl>(D)) {
     return VD->getType();
   }
   return QualType();
 }
 
+QualType resolveDeclsToType(const std::vector<const NamedDecl *> &Decls,
+                            ASTContext &Ctx) {
+  if (Decls.size() != 1) // Names an overload set -- just bail.
+    return QualType();
+  return resolveDeclToType(Decls[0], Ctx);
+}
+
 TemplateName getReferencedTemplateName(const Type *T) {
   if (const auto *TST = T->getAs<TemplateSpecializationType>()) {
     return TST->getTemplateName();
@@ -322,19 +327,29 @@ HeuristicResolverImpl::resolveDeclRefExpr(const DependentScopeDeclRefExpr *RE) {
   return resolveDependentMember(Qualifier, RE->getDeclName(), StaticFilter);
 }
 
-std::vector<const NamedDecl *>
-HeuristicResolverImpl::resolveTypeOfCallExpr(const CallExpr *CE) {
-  QualType CalleeType = resolveExprToType(CE->getCallee());
-  if (CalleeType.isNull())
-    return {};
-  if (const auto *FnTypePtr = CalleeType->getAs<PointerType>())
-    CalleeType = FnTypePtr->getPointeeType();
-  if (const FunctionType *FnType = CalleeType->getAs<FunctionType>()) {
-    if (const auto *D = resolveTypeToTagDecl(FnType->getReturnType())) {
-      return {D};
+QualType HeuristicResolverImpl::resolveTypeOfCallExpr(const CallExpr *CE) {
+  // resolveExprToType(CE->getCallee()) would bail in the case of multiple
+  // overloads, as it can't produce a single type for them. We can be more
+  // permissive here, and allow multiple overloads with a common return type.
+  std::vector<const NamedDecl *> CalleeDecls =
+      resolveExprToDecls(CE->getCallee());
+  QualType CommonReturnType;
+  for (const NamedDecl *CalleeDecl : CalleeDecls) {
+    QualType CalleeType = resolveDeclToType(CalleeDecl, Ctx);
+    if (CalleeType.isNull())
+      continue;
+    if (const auto *FnTypePtr = CalleeType->getAs<PointerType>())
+      CalleeType = FnTypePtr->getPointeeType();
+    if (const FunctionType *FnType = CalleeType->getAs<FunctionType>()) {
+      QualType ReturnType =
+          simplifyType(FnType->getReturnType(), nullptr, false);
+      if (!CommonReturnType.isNull() && CommonReturnType != ReturnType) {
+        return {}; // conflicting return types
+      }
+      CommonReturnType = ReturnType;
     }
   }
-  return {};
+  return CommonReturnType;
 }
 
 std::vector<const NamedDecl *>
@@ -382,15 +397,41 @@ HeuristicResolverImpl::resolveExprToDecls(const Expr *E) {
     return {OE->decls_begin(), OE->decls_end()};
   }
   if (const auto *CE = dyn_cast<CallExpr>(E)) {
-    return resolveTypeOfCallExpr(CE);
+    QualType T = resolveTypeOfCallExpr(CE);
+    if (const auto *D = resolveTypeToTagDecl(T)) {
+      return {D};
+    }
+    return {};
   }
   if (const auto *ME = dyn_cast<MemberExpr>(E))
     return {ME->getMemberDecl()};
+  if (const auto *DRE = dyn_cast<DeclRefExpr>(E))
+    return {DRE->getDecl()};
 
   return {};
 }
 
 QualType HeuristicResolverImpl::resolveExprToType(const Expr *E) {
+  // resolveExprToDecls on a CallExpr only succeeds if the return type is
+  // a TagDecl, but we may want the type of a call in other cases as well.
+  // (FIXME: There are probably other cases where we can do something more
+  // flexible than resolveExprToDecls + resolveDeclsToType, e.g. in the case
+  // of OverloadExpr we can probably accept overloads with a common type).
+  if (const auto *CE = dyn_cast<CallExpr>(E)) {
+    if (QualType Resolved = resolveTypeOfCallExpr(CE); !Resolved.isNull())
+      return Resolved;
+  }
+  // Similarly, unwrapping a unary dereference operation does not work via
+  // resolveExprToDecls.
+  if (const auto *UO = dyn_cast<UnaryOperator>(E->IgnoreParenCasts())) {
+    if (UO->getOpcode() == UnaryOperatorKind::UO_Deref) {
+      if (auto Pointee = getPointeeType(resolveExprToType(UO->getSubExpr()));
+          !Pointee.isNull()) {
+        return Pointee;
+      }
+    }
+  }
+
   std::vector<const NamedDecl *> Decls = resolveExprToDecls(E);
   if (!Decls.empty())
     return resolveDeclsToType(Decls, Ctx);
@@ -569,10 +610,6 @@ std::vector<const NamedDecl *> HeuristicResolver::resolveDeclRefExpr(
   return HeuristicResolverImpl(Ctx).resolveDeclRefExpr(RE);
 }
 std::vector<const NamedDecl *>
-HeuristicResolver::resolveTypeOfCallExpr(const CallExpr *CE) const {
-  return HeuristicResolverImpl(Ctx).resolveTypeOfCallExpr(CE);
-}
-std::vector<const NamedDecl *>
 HeuristicResolver::resolveCalleeOfCallExpr(const CallExpr *CE) const {
   return HeuristicResolverImpl(Ctx).resolveCalleeOfCallExpr(CE);
 }
@@ -608,7 +645,9 @@ QualType HeuristicResolver::simplifyType(QualType Type, const Expr *E,
                                          bool UnwrapPointer) {
   return HeuristicResolverImpl(Ctx).simplifyType(Type, E, UnwrapPointer);
 }
-
+QualType HeuristicResolver::resolveExprToType(const Expr *E) const {
+  return HeuristicResolverImpl(Ctx).resolveExprToType(E);
+}
 FunctionProtoTypeLoc
 HeuristicResolver::getFunctionProtoTypeLoc(const Expr *Fn) const {
   return HeuristicResolverImpl(Ctx).getFunctionProtoTypeLoc(Fn);
diff --git a/clang/lib/Sema/SemaCodeComplete.cpp b/clang/lib/Sema/SemaCodeComplete.cpp
index a43ac9eb7610d..229fa8e1c734c 100644
--- a/clang/lib/Sema/SemaCodeComplete.cpp
+++ b/clang/lib/Sema/SemaCodeComplete.cpp
@@ -5757,96 +5757,13 @@ class ConceptInfo {
 // We accept some lossiness (like dropping parameters).
 // We only try to handle common expressions on the LHS of MemberExpr.
 QualType getApproximateType(const Expr *E, HeuristicResolver &Resolver) {
-  if (E->getType().isNull())
-    return QualType();
-  // Don't drop implicit cast if it's an array decay.
-  if (auto *ICE = dyn_cast<ImplicitCastExpr>(E);
-      !ICE || ICE->getCastKind() != CK_ArrayToPointerDecay)
-    E = E->IgnoreParenImpCasts();
-  QualType Unresolved = E->getType();
-  // Resolve DependentNameType
-  if (const auto *DNT = Unresolved->getAs<DependentNameType>()) {
-    if (auto Decls = Resolver.resolveDependentNameType(DNT);
-        Decls.size() == 1) {
-      if (const auto *TD = dyn_cast<TypeDecl>(Decls[0]))
-        return QualType(TD->getTypeForDecl(), 0);
-    }
-  }
-  // We only resolve DependentTy, or undeduced autos (including auto* etc).
-  if (!Unresolved->isSpecificBuiltinType(BuiltinType::Dependent)) {
-    AutoType *Auto = Unresolved->getContainedAutoType();
-    if (!Auto || !Auto->isUndeducedAutoType())
-      return Unresolved;
-  }
-  // A call: approximate-resolve callee to a function type, get its return type
-  if (const CallExpr *CE = llvm::dyn_cast<CallExpr>(E)) {
-    QualType Callee = getApproximateType(CE->getCallee(), Resolver);
-    if (Callee.isNull() ||
-        Callee->isSpecificPlaceholderType(BuiltinType::BoundMember))
-      Callee = Expr::findBoundMemberType(CE->getCallee());
-    if (Callee.isNull())
-      return Unresolved;
-
-    if (const auto *FnTypePtr = Callee->getAs<PointerType>()) {
-      Callee = FnTypePtr->getPointeeType();
-    } else if (const auto *BPT = Callee->getAs<BlockPointerType>()) {
-      Callee = BPT->getPointeeType();
-    }
-    if (const FunctionType *FnType = Callee->getAs<FunctionType>())
-      return FnType->getReturnType().getNonReferenceType();
-
-    // Unresolved call: try to guess the return type.
-    if (const auto *OE = llvm::dyn_cast<OverloadExpr>(CE->getCallee())) {
-      // If all candidates have the same approximate return type, use it.
-      // Discard references and const to allow more to be "the same".
-      // (In particular, if there's one candidate + ADL, resolve it).
-      const Type *Common = nullptr;
-      for (const auto *D : OE->decls()) {
-        QualType ReturnType;
-        if (const auto *FD = llvm::dyn_cast<FunctionDecl>(D))
-          ReturnType = FD->getReturnType();
-        else if (const auto *FTD = llvm::dyn_cast<FunctionTemplateDecl>(D))
-          ReturnType = FTD->getTemplatedDecl()->getReturnType();
-        if (ReturnType.isNull())
-          continue;
-        const Type *Candidate =
-            ReturnType.getNonReferenceType().getCanonicalType().getTypePtr();
-        if (Common && Common != Candidate)
-          return Unresolved; // Multiple candidates.
-        Common = Candidate;
-      }
-      if (Common != nullptr)
-        return QualType(Common, 0);
-    }
-  }
-  // A dependent member: resolve using HeuristicResolver.
-  if (const auto *CDSME = llvm::dyn_cast<CXXDependentScopeMemberExpr>(E)) {
-    for (const auto *Member : Resolver.resolveMemberExpr(CDSME)) {
-      if (const auto *VD = dyn_cast<ValueDecl>(Member)) {
-        return VD->getType().getNonReferenceType();
-      }
-    }
-  }
-  // A reference to an `auto` variable: approximate-resolve its initializer.
-  if (const auto *DRE = llvm::dyn_cast<DeclRefExpr>(E)) {
-    if (const auto *VD = llvm::dyn_cast<VarDecl>(DRE->getDecl())) {
-      if (VD->hasInit())
-        return getApproximateType(VD->getInit(), Resolver);
-    }
-  }
-  if (const auto *UO = llvm::dyn_cast<UnaryOperator>(E)) {
-    if (UO->getOpcode() == UnaryOperatorKind::UO_Deref) {
-      // We recurse into the subexpression because it could be of dependent
-      // type.
-      if (auto Pointee =
-              getApproximateType(UO->getSubExpr(), Resolver)->getPointeeType();
-          !Pointee.isNull())
-        return Pointee;
-      // Our caller expects a non-null result, even though the SubType is
-      // supposed to have a pointee. Fall through to Unresolved anyway.
-    }
-  }
-  return Unresolved;
+  QualType Result = Resolver.resolveExprToType(E);
+  if (Result.isNull())
+    return Result;
+  Result = Resolver.simplifyType(Result.getNonReferenceType(), E, false);
+  if (Result.isNull())
+    return Result;
+  return Result.getNonReferenceType();
 }
 
 // If \p Base is ParenListExpr, assume a chain of comma operators and pick the
diff --git a/clang/test/CodeCompletion/member-access.cpp b/clang/test/CodeCompletion/member-access.cpp
index 8526ed7273474..d97ddad78607c 100644
--- a/clang/test/CodeCompletion/member-access.cpp
+++ b/clang/test/CodeCompletion/member-access.cpp
@@ -28,7 +28,7 @@ class Proxy {
 };
 
 void test(const Proxy &p) {
-  p->
+  p-> 
 }
 
 struct Test1 {
diff --git a/clang/unittests/Sema/HeuristicResolverTest.cpp b/clang/unittests/Sema/HeuristicResolverTest.cpp
index ee434f7a1d43a..fe975a68be13c 100644
--- a/clang/unittests/Sema/HeuristicResolverTest.cpp
+++ b/clang/unittests/Sema/HeuristicResolverTest.cpp
@@ -198,7 +198,6 @@ TEST(HeuristicResolver, MemberExpr_AutoTypeDeduction2) {
     struct B {
       int waldo;
     };
-
     template <typename T>
     struct A {
       B b;
@@ -233,6 +232,103 @@ TEST(HeuristicResolver, MemberExpr_Chained) {
       cxxMethodDecl(hasName("foo")).bind("output"));
 }
 
+TEST(HeuristicResolver, MemberExpr_Chained_ReferenceType) {
+  std::string Code = R"cpp(
+    struct B {
+      int waldo;
+    };
+    template <typename T>
+    struct A {
+      B &foo();
+    };
+    template <typename T>
+    void bar(A<T> a) {
+      a.foo().waldo;
+    }
+  )cpp";
+  // Test resolution of "waldo" in "a.foo().waldo"
+  expectResolution(
+      Code, &HeuristicResolver::resolveMemberExpr,
+      cxxDependentScopeMemberExpr(hasMemberName("waldo")).bind("input"),
+      fieldDecl(hasName("waldo")).bind("output"));
+}
+
+TEST(HeuristicResolver, MemberExpr_Chained_PointerArrow) {
+  std::string Code = R"cpp(
+    struct B {
+      int waldo;
+    };
+    template <typename T>
+    B* foo(T);
+    template <class T>
+    void bar(T t) {
+      foo(t)->waldo;
+    }
+  )cpp";
+  // Test resolution of "waldo" in "foo(t)->waldo"
+  expectResolution(
+      Code, &HeuristicResolver::resolveMemberExpr,
+      cxxDependentScopeMemberExpr(hasMemberName("waldo")).bind("input"),
+      fieldDecl(hasName("waldo")).bind("output"));
+}
+
+TEST(HeuristicResolver, MemberExpr_Chained_PointerDeref) {
+  std::string Code = R"cpp(
+    struct B {
+      int waldo;
+    };
+    template <typename T>
+    B* foo(T);
+    template <class T>
+    void bar(T t) {
+      (*foo(t)).waldo;
+    }
+  )cpp";
+  // Test resolution of "waldo" in "(*foo(t)).waldo"
+  expectResolution(
+      Code, &HeuristicResolver::resolveMemberExpr,
+      cxxDependentScopeMemberExpr(hasMemberName("waldo")).bind("input"),
+      fieldDecl(hasName("waldo")).bind("output"));
+}
+
+TEST(HeuristicResolver, MemberExpr_Chained_Overload) {
+  std::string Code = R"cpp(
+    struct B {
+      int waldo;
+    };
+    B overloaded(int);
+    B overloaded(double);
+    template <typename T>
+    void foo(T t) {
+      overloaded(t).waldo;
+    }
+  )cpp";
+  // Test resolution of "waldo" in "overloaded(t).waldo"
+  expectResolution(
+      Code, &HeuristicResolver::resolveMemberExpr,
+      cxxDependentScopeMemberExpr(hasMemberName("waldo")).bind("input"),
+      fieldDecl(hasName("waldo")).bind("output"));
+}
+
+TEST(HeuristicResolver, MemberExpr_CallToFunctionTemplate) {
+  std::string Code = R"cpp(
+    struct B {
+      int waldo;
+    };
+    template <typename T>
+    B bar(T);
+    template <typename T>
+    void foo(T t) {
+      bar(t).waldo;
+    }
+  )cpp";
+  // Test resolution of "waldo" in "bar(t).waldo"
+  expectResolution(
+      Code, &HeuristicResolver::resolveMemberExpr,
+      cxxDependentScopeMemberExpr(hasMemberName("waldo")).bind("input"),
+      fieldDecl(hasName("waldo")).bind("output"));
+}
+
 TEST(HeuristicResolver, MemberExpr_ReferenceType) {
   std::string Code = R"cpp(
     struct B {



More information about the cfe-commits mailing list