[clang-tools-extra] 71a082f - [clangd] Implement textDocument/typeDefinition

Sam McCall via cfe-commits cfe-commits at lists.llvm.org
Thu Jan 13 13:15:16 PST 2022


Author: Sam McCall
Date: 2022-01-13T22:15:10+01:00
New Revision: 71a082f72674ceba0797748543dcef3620b22822

URL: https://github.com/llvm/llvm-project/commit/71a082f72674ceba0797748543dcef3620b22822
DIFF: https://github.com/llvm/llvm-project/commit/71a082f72674ceba0797748543dcef3620b22822.diff

LOG: [clangd] Implement textDocument/typeDefinition

This reuses the type=>decl mapping from go-to-definition on auto.
(Which could stand some improvement, but that can happen later).

Fixes https://github.com/clangd/clangd/issues/367

Differential Revision: https://reviews.llvm.org/D116443

Added: 
    clang-tools-extra/clangd/test/type-definition.test

Modified: 
    clang-tools-extra/clangd/ClangdLSPServer.cpp
    clang-tools-extra/clangd/ClangdLSPServer.h
    clang-tools-extra/clangd/ClangdServer.cpp
    clang-tools-extra/clangd/ClangdServer.h
    clang-tools-extra/clangd/XRefs.cpp
    clang-tools-extra/clangd/XRefs.h
    clang-tools-extra/clangd/test/initialize-params.test
    clang-tools-extra/clangd/unittests/XRefsTests.cpp

Removed: 
    


################################################################################
diff  --git a/clang-tools-extra/clangd/ClangdLSPServer.cpp b/clang-tools-extra/clangd/ClangdLSPServer.cpp
index 47bde7d0edcfd..5b1d046872655 100644
--- a/clang-tools-extra/clangd/ClangdLSPServer.cpp
+++ b/clang-tools-extra/clangd/ClangdLSPServer.cpp
@@ -560,6 +560,7 @@ void ClangdLSPServer::onInitialize(const InitializeParams &Params,
       {"declarationProvider", true},
       {"definitionProvider", true},
       {"implementationProvider", true},
+      {"typeDefinitionProvider", true},
       {"documentHighlightProvider", true},
       {"documentLinkProvider",
        llvm::json::Object{
@@ -1278,6 +1279,21 @@ void ClangdLSPServer::onReference(const ReferenceParams &Params,
       });
 }
 
+// Handles textDocument/typeDefinition: asks the server for the types
+// referenced at the requested position and replies with one location per
+// located symbol. If the lookup failed, the error is forwarded to the client.
+void ClangdLSPServer::onGoToType(const TextDocumentPositionParams &Params,
+                                 Callback<std::vector<Location>> Reply) {
+  Server->findType(
+      Params.textDocument.uri.file(), Params.position,
+      [Reply = std::move(Reply)](
+          llvm::Expected<std::vector<LocatedSymbol>> Types) mutable {
+        if (!Types)
+          return Reply(Types.takeError());
+        std::vector<Location> Response;
+        Response.reserve(Types->size());
+        // This is a go-to-type-*definition* request: use the definition when
+        // one is known, and only fall back to the preferred declaration.
+        for (const LocatedSymbol &Sym : *Types)
+          Response.push_back(Sym.Definition ? *Sym.Definition
+                                            : Sym.PreferredDeclaration);
+        return Reply(std::move(Response));
+      });
+}
+
 void ClangdLSPServer::onGoToImplementation(
     const TextDocumentPositionParams &Params,
     Callback<std::vector<Location>> Reply) {
@@ -1448,6 +1464,7 @@ void ClangdLSPServer::bindMethods(LSPBinder &Bind,
   Bind.method("textDocument/signatureHelp", this, &ClangdLSPServer::onSignatureHelp);
   Bind.method("textDocument/definition", this, &ClangdLSPServer::onGoToDefinition);
   Bind.method("textDocument/declaration", this, &ClangdLSPServer::onGoToDeclaration);
+  Bind.method("textDocument/typeDefinition", this, &ClangdLSPServer::onGoToType);
   Bind.method("textDocument/implementation", this, &ClangdLSPServer::onGoToImplementation);
   Bind.method("textDocument/references", this, &ClangdLSPServer::onReference);
   Bind.method("textDocument/switchSourceHeader", this, &ClangdLSPServer::onSwitchSourceHeader);

diff  --git a/clang-tools-extra/clangd/ClangdLSPServer.h b/clang-tools-extra/clangd/ClangdLSPServer.h
index d2b447e75dbad..02c2a5c721e1d 100644
--- a/clang-tools-extra/clangd/ClangdLSPServer.h
+++ b/clang-tools-extra/clangd/ClangdLSPServer.h
@@ -121,6 +121,8 @@ class ClangdLSPServer : private ClangdServer::Callbacks,
                          Callback<std::vector<Location>>);
   void onGoToDefinition(const TextDocumentPositionParams &,
                         Callback<std::vector<Location>>);
+  void onGoToType(const TextDocumentPositionParams &,
+                  Callback<std::vector<Location>>);
   void onGoToImplementation(const TextDocumentPositionParams &,
                             Callback<std::vector<Location>>);
   void onReference(const ReferenceParams &, Callback<std::vector<Location>>);

diff  --git a/clang-tools-extra/clangd/ClangdServer.cpp b/clang-tools-extra/clangd/ClangdServer.cpp
index 69f37662c4ab7..29d5d9dfe1d1d 100644
--- a/clang-tools-extra/clangd/ClangdServer.cpp
+++ b/clang-tools-extra/clangd/ClangdServer.cpp
@@ -810,6 +810,17 @@ void ClangdServer::foldingRanges(llvm::StringRef File,
                             Transient);
 }
 
+// Schedules a "FindType" action that needs the AST of \p File.
+// When the AST is available, \p CB receives the result of clangd::findType()
+// at \p Pos; if the AST could not be provided, \p CB receives that error.
+void ClangdServer::findType(PathRef File, Position Pos,
+                            Callback<std::vector<LocatedSymbol>> CB) {
+  auto Action =
+      [Pos, CB = std::move(CB)](llvm::Expected<InputsAndAST> InpAST) mutable {
+        if (!InpAST)
+          return CB(InpAST.takeError());
+        CB(clangd::findType(InpAST->AST, Pos));
+      };
+  WorkScheduler->runWithAST("FindType", File, std::move(Action));
+}
+
 void ClangdServer::findImplementations(
     PathRef File, Position Pos, Callback<std::vector<LocatedSymbol>> CB) {
   auto Action = [Pos, CB = std::move(CB),

diff  --git a/clang-tools-extra/clangd/ClangdServer.h b/clang-tools-extra/clangd/ClangdServer.h
index 14be307e6fda8..4bef653cbf82d 100644
--- a/clang-tools-extra/clangd/ClangdServer.h
+++ b/clang-tools-extra/clangd/ClangdServer.h
@@ -282,6 +282,10 @@ class ClangdServer {
   void findImplementations(PathRef File, Position Pos,
                            Callback<std::vector<LocatedSymbol>> CB);
 
+  /// Retrieve symbols for types referenced at \p Pos.
+  void findType(PathRef File, Position Pos,
+                Callback<std::vector<LocatedSymbol>> CB);
+
   /// Retrieve locations for symbol references.
   void findReferences(PathRef File, Position Pos, uint32_t Limit,
                       Callback<ReferencesResult> CB);

diff  --git a/clang-tools-extra/clangd/XRefs.cpp b/clang-tools-extra/clangd/XRefs.cpp
index 28ee2096f3fb0..fbe60849daa4a 100644
--- a/clang-tools-extra/clangd/XRefs.cpp
+++ b/clang-tools-extra/clangd/XRefs.cpp
@@ -29,11 +29,13 @@
 #include "clang/AST/DeclCXX.h"
 #include "clang/AST/DeclObjC.h"
 #include "clang/AST/DeclTemplate.h"
+#include "clang/AST/DeclVisitor.h"
 #include "clang/AST/ExprCXX.h"
 #include "clang/AST/ExternalASTSource.h"
 #include "clang/AST/RecursiveASTVisitor.h"
 #include "clang/AST/Stmt.h"
 #include "clang/AST/StmtCXX.h"
+#include "clang/AST/StmtVisitor.h"
 #include "clang/AST/Type.h"
 #include "clang/Basic/CharInfo.h"
 #include "clang/Basic/LLVM.h"
@@ -503,6 +505,8 @@ std::vector<LocatedSymbol> locateSymbolForType(const ParsedAST &AST,
     return {};
   }
 
+  // FIXME: this sends unique_ptr<Foo> to unique_ptr<T>.
+  // Likely it would be better to send it to Foo (heuristically) or to both.
   auto Decls = targetDecl(DynTypedNode::create(Type.getNonReferenceType()),
                           DeclRelation::TemplatePattern | DeclRelation::Alias,
                           AST.getHeuristicResolver());
@@ -1785,6 +1789,182 @@ const CXXRecordDecl *findRecordTypeAt(ParsedAST &AST, Position Pos) {
   return Result;
 }
 
+// Return the type most associated with an AST node.
+// This isn't precisely defined: we want "go to type" to do something useful.
+//
+// \p N may be null. A null QualType is returned when no type can be
+// associated with the node.
+static QualType typeForNode(const SelectionTree::Node *N) {
+  // If we're looking at a namespace qualifier, walk up to what it's qualifying.
+  // (If we're pointing at a *class* inside a NNS, N will be a TypeLoc).
+  while (N && N->ASTNode.get<NestedNameSpecifierLoc>())
+    N = N->Parent;
+  if (!N)
+    return QualType();
+
+  // If we're pointing at a type => return it.
+  if (const TypeLoc *TL = N->ASTNode.get<TypeLoc>()) {
+    // Deduced types (e.g. `auto`): prefer what was actually deduced, if known.
+    if (llvm::isa<DeducedType>(TL->getTypePtr()))
+      if (auto Deduced = getDeducedType(
+              N->getDeclContext().getParentASTContext(), TL->getBeginLoc()))
+        return *Deduced;
+    // Exception: an alias => underlying type.
+    if (llvm::isa<TypedefType>(TL->getTypePtr()))
+      return TL->getTypePtr()->getLocallyUnqualifiedSingleStepDesugaredType();
+    return TL->getType();
+  }
+
+  // Constructor initializers => the type of thing being initialized.
+  if (const auto *CCI = N->ASTNode.get<CXXCtorInitializer>()) {
+    if (const FieldDecl *FD = CCI->getAnyMember())
+      return FD->getType();
+    if (const Type *Base = CCI->getBaseClass())
+      return QualType(Base, 0);
+  }
+
+  // Base specifier => the base type.
+  if (const auto *CBS = N->ASTNode.get<CXXBaseSpecifier>())
+    return CBS->getType();
+
+  if (const Decl *D = N->ASTNode.get<Decl>()) {
+    struct Visitor : ConstDeclVisitor<Visitor, QualType> {
+      QualType VisitValueDecl(const ValueDecl *D) { return D->getType(); }
+      // Declaration of a type => that type.
+      QualType VisitTypeDecl(const TypeDecl *D) {
+        return QualType(D->getTypeForDecl(), 0);
+      }
+      // Exception: alias declaration => the underlying type, not the alias.
+      QualType VisitTypedefNameDecl(const TypedefNameDecl *D) {
+        return D->getUnderlyingType();
+      }
+      // Look inside templates.
+      QualType VisitTemplateDecl(const TemplateDecl *D) {
+        return Visit(D->getTemplatedDecl());
+      }
+    } V;
+    return V.Visit(D);
+  }
+
+  if (const Stmt *S = N->ASTNode.get<Stmt>()) {
+    struct Visitor : ConstStmtVisitor<Visitor, QualType> {
+      // Null-safe version of visit simplifies recursive calls below.
+      QualType type(const Stmt *S) { return S ? Visit(S) : QualType(); }
+
+      // In general, expressions => type of expression.
+      QualType VisitExpr(const Expr *S) {
+        return S->IgnoreImplicitAsWritten()->getType();
+      }
+      // Exceptions for void expressions that operate on a type in some way.
+      QualType VisitCXXDeleteExpr(const CXXDeleteExpr *S) {
+        return S->getDestroyedType();
+      }
+      QualType VisitCXXPseudoDestructorExpr(const CXXPseudoDestructorExpr *S) {
+        return S->getDestroyedType();
+      }
+      QualType VisitCXXThrowExpr(const CXXThrowExpr *S) {
+        // A bare `throw;` (rethrow) has no operand, so there is no type.
+        const Expr *Thrown = S->getSubExpr();
+        return Thrown ? Thrown->getType() : QualType();
+      }
+      // The visitor dispatches on the node's class name, so this hook must be
+      // spelled VisitCoyieldExpr to ever be called for a CoyieldExpr.
+      QualType VisitCoyieldExpr(const CoyieldExpr *S) {
+        return type(S->getOperand());
+      }
+      // Treat a designated initializer like a reference to the field.
+      QualType VisitDesignatedInitExpr(const DesignatedInitExpr *S) {
+        // In .foo.bar we want to jump to bar's type, so find *last* field.
+        for (auto &D : llvm::reverse(S->designators()))
+          if (D.isFieldDesignator())
+            if (const auto *FD = D.getField())
+              return FD->getType();
+        return QualType();
+      }
+
+      // Control flow statements that operate on data: use the data type.
+      QualType VisitSwitchStmt(const SwitchStmt *S) {
+        return type(S->getCond());
+      }
+      QualType VisitWhileStmt(const WhileStmt *S) { return type(S->getCond()); }
+      QualType VisitDoStmt(const DoStmt *S) { return type(S->getCond()); }
+      QualType VisitIfStmt(const IfStmt *S) { return type(S->getCond()); }
+      QualType VisitCaseStmt(const CaseStmt *S) { return type(S->getLHS()); }
+      QualType VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
+        return S->getLoopVariable()->getType();
+      }
+      QualType VisitReturnStmt(const ReturnStmt *S) {
+        return type(S->getRetValue());
+      }
+      QualType VisitCoreturnStmt(const CoreturnStmt *S) {
+        return type(S->getOperand());
+      }
+      QualType VisitCXXCatchStmt(const CXXCatchStmt *S) {
+        return S->getCaughtType();
+      }
+      QualType VisitObjCAtThrowStmt(const ObjCAtThrowStmt *S) {
+        return type(S->getThrowExpr());
+      }
+      QualType VisitObjCAtCatchStmt(const ObjCAtCatchStmt *S) {
+        return S->getCatchParamDecl() ? S->getCatchParamDecl()->getType()
+                                      : QualType();
+      }
+    } V;
+    return V.Visit(S);
+  }
+
+  return QualType();
+}
+
+// Maps the type found under the cursor to the type "go to type" should
+// actually jump to: wrapper layers (pointers, references, arrays) are peeled
+// off and callables are replaced by their result type. A type alias is kept
+// as-is. A null type is returned unchanged.
+static QualType unwrapFindType(QualType T) {
+  // Each pass peels exactly one layer; leave the loop once nothing applies.
+  while (!T.isNull()) {
+    // If there's a specific type alias, point at that rather than unwrapping.
+    if (const auto *Alias = T->getAs<TypedefType>())
+      return QualType(Alias, 0);
+
+    // Pointers etc => pointee type.
+    if (const auto *Ptr = T->getAs<PointerType>()) {
+      T = Ptr->getPointeeType();
+      continue;
+    }
+    if (const auto *Ref = T->getAs<ReferenceType>()) {
+      T = Ref->getPointeeType();
+      continue;
+    }
+    if (const auto *Arr = T->getAsArrayTypeUnsafe()) {
+      T = Arr->getElementType();
+      continue;
+    }
+    // FIXME: use HeuristicResolver to unwrap smart pointers?
+
+    // Function type => return type.
+    if (const auto *Fn = T->getAs<FunctionType>()) {
+      T = Fn->getReturnType();
+      continue;
+    }
+    // Lambda => return type of its call operator.
+    const auto *Record = T->getAsCXXRecordDecl();
+    if (Record && Record->isLambda()) {
+      T = Record->getLambdaCallOperator()->getReturnType();
+      continue;
+    }
+    // FIXME: more cases we'd prefer the return type of the call operator?
+    //        std::function etc?
+
+    break;
+  }
+  return T;
+}
+
+// Locates the declarations of the type associated with whatever is at \p Pos
+// in the main file. Returns an empty vector if \p Pos cannot be converted to
+// an offset (the failure is logged) or if no type could be located.
+std::vector<LocatedSymbol> findType(ParsedAST &AST, Position Pos) {
+  const SourceManager &SM = AST.getSourceManager();
+  auto Offset = positionToOffset(SM.getBufferData(SM.getMainFileID()), Pos);
+  if (!Offset) {
+    elog("failed to convert position {0} for findTypes: {1}", Pos,
+         Offset.takeError());
+    return {};
+  }
+
+  // The general scheme is: position -> AST node -> type -> declaration.
+  // Each candidate selection tree is tried in turn; returning true from the
+  // callback stops the search at the first one that yields any symbol.
+  std::vector<LocatedSymbol> Found;
+  auto TryTree = [&AST, &Found](SelectionTree ST) {
+    QualType Target = unwrapFindType(typeForNode(ST.commonAncestor()));
+    if (Target.isNull())
+      Found.clear();
+    else
+      Found = locateSymbolForType(AST, Target);
+    return !Found.empty();
+  };
+  SelectionTree::createEach(AST.getASTContext(), AST.getTokens(), *Offset,
+                            *Offset, TryTree);
+  return Found;
+}
+
 std::vector<const CXXRecordDecl *> typeParents(const CXXRecordDecl *CXXRD) {
   std::vector<const CXXRecordDecl *> Result;
 

diff  --git a/clang-tools-extra/clangd/XRefs.h b/clang-tools-extra/clangd/XRefs.h
index aa0eaa7454c04..71feb82eda19b 100644
--- a/clang-tools-extra/clangd/XRefs.h
+++ b/clang-tools-extra/clangd/XRefs.h
@@ -105,6 +105,12 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &,
 std::vector<LocatedSymbol> findImplementations(ParsedAST &AST, Position Pos,
                                                const SymbolIndex *Index);
 
+/// Returns symbols for types referenced at \p Pos.
+///
+/// For example, given `b^ar()` where bar returns Foo, this function returns
+/// the definition of class Foo.
+std::vector<LocatedSymbol> findType(ParsedAST &AST, Position Pos);
+
 /// Returns references of the symbol at a specified \p Pos.
 /// \p Limit limits the number of results returned (0 means no limit).
 ReferencesResult findReferences(ParsedAST &AST, Position Pos, uint32_t Limit,

diff  --git a/clang-tools-extra/clangd/test/initialize-params.test b/clang-tools-extra/clangd/test/initialize-params.test
index e5701876684e3..14d7895cfa35b 100644
--- a/clang-tools-extra/clangd/test/initialize-params.test
+++ b/clang-tools-extra/clangd/test/initialize-params.test
@@ -121,6 +121,7 @@
 # CHECK-NEXT:        "openClose": true,
 # CHECK-NEXT:        "save": true
 # CHECK-NEXT:      },
+# CHECK-NEXT:      "typeDefinitionProvider": true,
 # CHECK-NEXT:      "typeHierarchyProvider": true
 # CHECK-NEXT:      "workspaceSymbolProvider": true
 # CHECK-NEXT:    },

diff  --git a/clang-tools-extra/clangd/test/type-definition.test b/clang-tools-extra/clangd/test/type-definition.test
new file mode 100644
index 0000000000000..226fff15c14b5
--- /dev/null
+++ b/clang-tools-extra/clangd/test/type-definition.test
@@ -0,0 +1,32 @@
+# RUN: clangd -lit-test < %s | FileCheck -strict-whitespace %s
+{"jsonrpc":"2.0","id":0,"method":"initialize","params":{}}
+---
+{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{"uri":"test:///main.cpp","languageId":"cpp","version":1,
+  "text":"class X {};\nauto x = X{};"
+}}}
+---
+{"jsonrpc":"2.0","id":1,"method":"textDocument/typeDefinition","params":{
+  "textDocument":{"uri":"test:///main.cpp"},
+  "position":{"line":1,"character":5}
+}}
+#      CHECK:  "id": 1
+# CHECK-NEXT:  "jsonrpc": "2.0",
+# CHECK-NEXT:  "result": [
+# CHECK-NEXT:    {
+# CHECK-NEXT:      "range": {
+# CHECK-NEXT:        "end": {
+# CHECK-NEXT:          "character": 7,
+# CHECK-NEXT:          "line": 0
+# CHECK-NEXT:        },
+# CHECK-NEXT:        "start": {
+# CHECK-NEXT:          "character": 6,
+# CHECK-NEXT:          "line": 0
+# CHECK-NEXT:        }
+# CHECK-NEXT:      },
+# CHECK-NEXT:      "uri": "file://{{.*}}/clangd-test/main.cpp"
+# CHECK-NEXT:    }
+# CHECK-NEXT:  ]
+---
+{"jsonrpc":"2.0","id":3,"method":"shutdown"}
+---
+{"jsonrpc":"2.0","method":"exit"}

diff  --git a/clang-tools-extra/clangd/unittests/XRefsTests.cpp b/clang-tools-extra/clangd/unittests/XRefsTests.cpp
index f85035a6c2a15..89076835975a6 100644
--- a/clang-tools-extra/clangd/unittests/XRefsTests.cpp
+++ b/clang-tools-extra/clangd/unittests/XRefsTests.cpp
@@ -38,10 +38,12 @@ namespace clangd {
 namespace {
 
 using ::testing::AllOf;
+using ::testing::Contains;
 using ::testing::ElementsAre;
 using ::testing::Eq;
 using ::testing::IsEmpty;
 using ::testing::Matcher;
+using ::testing::Not;
 using ::testing::UnorderedElementsAre;
 using ::testing::UnorderedElementsAreArray;
 using ::testing::UnorderedPointwise;
@@ -1781,6 +1783,69 @@ TEST(FindImplementations, CaptureDefintion) {
       << Test;
 }
 
+// Runs findType() over a table of snippets that share one header. At every
+// marked point the single expected result is the `Target` struct from that
+// header. A second table pins down cases that are known not to work yet.
+TEST(FindType, All) {
+  Annotations Header(R"cpp(
+    struct [[Target]] { operator int() const; };
+    struct Aggregate { Target a, b; };
+    Target t;
+
+    template <typename T> class smart_ptr {
+      T& operator*();
+      T* operator->();
+      T* get();
+    };
+  )cpp");
+  auto TU = TestTU::withHeaderCode(Header.code());
+
+  // Snippets in which every marked point must resolve to exactly Target.
+  const llvm::StringRef Supported[] = {
+      "str^uct Target;",
+      "T^arget x;",
+      "Target ^x;",
+      "a^uto x = Target{};",
+      "namespace m { Target tgt; } auto x = m^::tgt;",
+      "Target funcCall(); auto x = ^funcCall();",
+      "Aggregate a = { {}, ^{} };",
+      "Aggregate a = { ^.a=t, };",
+      "struct X { Target a; X() : ^a() {} };",
+      "^using T = Target; ^T foo();",
+      "^template <int> Target foo();",
+      "void x() { try {} ^catch(Target e) {} }",
+      "void x() { ^throw t; }",
+      "int x() { ^return t; }",
+      "void x() { ^switch(t) {} }",
+      "void x() { ^delete (Target*)nullptr; }",
+      "Target& ^tref = t;",
+      "void x() { ^if (t) {} }",
+      "void x() { ^while (t) {} }",
+      "void x() { ^do { } while (t); }",
+      "^auto x = []() { return t; };",
+      "Target* ^tptr = &t;",
+      "Target ^tarray[3];",
+  };
+  for (const llvm::StringRef Snippet : Supported) {
+    Annotations Marked(Snippet);
+    TU.Code = Marked.code().str();
+    ParsedAST AST = TU.build();
+
+    // A snippet without any marked point would silently test nothing.
+    const auto Points = Marked.points();
+    ASSERT_GT(Points.size(), 0u) << Snippet;
+    for (const auto &Point : Points)
+      EXPECT_THAT(findType(AST, Point),
+                  ElementsAre(Sym("Target", Header.range(), Header.range())))
+          << Snippet;
+  }
+
+  // FIXME: We'd like these cases to work. Fix them and move above.
+  const llvm::StringRef Unsupported[] = {
+      "smart_ptr<Target> ^tsmart;",
+  };
+  for (const llvm::StringRef Snippet : Unsupported) {
+    Annotations Marked(Snippet);
+    TU.Code = Marked.code().str();
+    ParsedAST AST = TU.build();
+
+    EXPECT_THAT(findType(AST, Marked.point()),
+                Not(Contains(Sym("Target", Header.range(), Header.range()))))
+        << Snippet;
+  }
+}
+
 void checkFindRefs(llvm::StringRef Test, bool UseIndex = false) {
   Annotations T(Test);
   auto TU = TestTU::withCode(T.code());


        


More information about the cfe-commits mailing list