[clang] 805f7a4 - [clang] Add `ObjCProtocolLoc` to represent protocol references

David Goldman via cfe-commits cfe-commits at lists.llvm.org
Fri Feb 18 12:25:15 PST 2022


Author: David Goldman
Date: 2022-02-18T15:24:00-05:00
New Revision: 805f7a4fa4ce97277c3b73d0c204fc3aa4b072e1

URL: https://github.com/llvm/llvm-project/commit/805f7a4fa4ce97277c3b73d0c204fc3aa4b072e1
DIFF: https://github.com/llvm/llvm-project/commit/805f7a4fa4ce97277c3b73d0c204fc3aa4b072e1.diff

LOG: [clang] Add `ObjCProtocolLoc` to represent protocol references

Add `ObjCProtocolLoc`, which behaves like `TypeLoc` but for
`ObjCProtocolDecl` references.

RecursiveASTVisitor now synthesizes `ObjCProtocolLoc` nodes during traversal,
and an `ObjCProtocolLoc` can be stored in a `DynTypedNode`.

In a follow-up patch, I'll update clangd to make use of this
to properly support protocol references for hover and go-to-definition.

Differential Revision: https://reviews.llvm.org/D119363

Added: 
    

Modified: 
    clang/include/clang/AST/ASTFwd.h
    clang/include/clang/AST/ASTTypeTraits.h
    clang/include/clang/AST/RecursiveASTVisitor.h
    clang/include/clang/AST/TypeLoc.h
    clang/lib/AST/ASTTypeTraits.cpp
    clang/lib/AST/ParentMapContext.cpp
    clang/unittests/AST/RecursiveASTVisitorTest.cpp

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/AST/ASTFwd.h b/clang/include/clang/AST/ASTFwd.h
index fdbd603ce5d04..f84b3238e32b5 100644
--- a/clang/include/clang/AST/ASTFwd.h
+++ b/clang/include/clang/AST/ASTFwd.h
@@ -33,6 +33,7 @@ class OMPClause;
 class Attr;
 #define ATTR(A) class A##Attr;
 #include "clang/Basic/AttrList.inc"
+class ObjCProtocolLoc; // Forward declaration only; the class itself is defined in clang/AST/TypeLoc.h.
 
 } // end namespace clang
 

diff  --git a/clang/include/clang/AST/ASTTypeTraits.h b/clang/include/clang/AST/ASTTypeTraits.h
index 6d96146a4d455..cd6b5143bf790 100644
--- a/clang/include/clang/AST/ASTTypeTraits.h
+++ b/clang/include/clang/AST/ASTTypeTraits.h
@@ -160,6 +160,7 @@ class ASTNodeKind {
     NKI_Attr,
 #define ATTR(A) NKI_##A##Attr,
 #include "clang/Basic/AttrList.inc"
+    NKI_ObjCProtocolLoc, // Keep in the same position as the "ObjCProtocolLoc" entry of AllKindInfo in ASTTypeTraits.cpp.
     NKI_NumberOfKinds
   };
 
@@ -213,6 +214,7 @@ KIND_TO_KIND_ID(Stmt)
 KIND_TO_KIND_ID(Type)
 KIND_TO_KIND_ID(OMPClause)
 KIND_TO_KIND_ID(Attr)
+KIND_TO_KIND_ID(ObjCProtocolLoc)
 KIND_TO_KIND_ID(CXXBaseSpecifier)
 #define DECL(DERIVED, BASE) KIND_TO_KIND_ID(DERIVED##Decl)
 #include "clang/AST/DeclNodes.inc"
@@ -499,7 +501,7 @@ class DynTypedNode {
   /// have storage or unique pointers and thus need to be stored by value.
   llvm::AlignedCharArrayUnion<const void *, TemplateArgument,
                               TemplateArgumentLoc, NestedNameSpecifierLoc,
-                              QualType, TypeLoc>
+                              QualType, TypeLoc, ObjCProtocolLoc> // ObjCProtocolLoc is held inline, by value.
       Storage;
 };
 
@@ -570,6 +572,10 @@ template <>
 struct DynTypedNode::BaseConverter<CXXBaseSpecifier, void>
     : public PtrConverter<CXXBaseSpecifier> {};
 
+template <>
+struct DynTypedNode::BaseConverter<ObjCProtocolLoc, void>
+    : public ValueConverter<ObjCProtocolLoc> {}; // Stored by value. NOTE(review): value-stored nodes have no unique pointer; confirm DynTypedNode hashing/ordering/equality handle this kind before it is used as a map key.
+
 // The only operation we allow on unsupported types is \c get.
 // This allows to conveniently use \c DynTypedNode when having an arbitrary
 // AST node that is not supported, but prevents misuse - a user cannot create

diff  --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index f62dc36de556e..16da64100d424 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -324,6 +324,12 @@ template <typename Derived> class RecursiveASTVisitor {
   /// \returns false if the visitation was terminated early, true otherwise.
   bool TraverseConceptReference(const ConceptReference &C);
 
+  /// Recursively visit an Objective-C protocol reference with location
+  /// information.
+  ///
+  /// \returns false if the visitation was terminated early, true otherwise.
+  bool TraverseObjCProtocolLoc(ObjCProtocolLoc ProtocolLoc);
+
   // ---- Methods on Attrs ----
 
   // Visit an attribute.
@@ -1340,7 +1346,12 @@ DEF_TRAVERSE_TYPELOC(DependentTemplateSpecializationType, {
 DEF_TRAVERSE_TYPELOC(PackExpansionType,
                      { TRY_TO(TraverseTypeLoc(TL.getPatternLoc())); })
 
-DEF_TRAVERSE_TYPELOC(ObjCTypeParamType, {})
+DEF_TRAVERSE_TYPELOC(ObjCTypeParamType, {
+  for (unsigned I = 0, N = TL.getNumProtocols(); I != N; ++I) { // Synthesize one ObjCProtocolLoc per protocol written on the type parameter.
+    ObjCProtocolLoc ProtocolLoc(TL.getProtocol(I), TL.getProtocolLoc(I));
+    TRY_TO(TraverseObjCProtocolLoc(ProtocolLoc));
+  }
+})
 
 DEF_TRAVERSE_TYPELOC(ObjCInterfaceType, {})
 
@@ -1351,6 +1362,10 @@ DEF_TRAVERSE_TYPELOC(ObjCObjectType, {
     TRY_TO(TraverseTypeLoc(TL.getBaseLoc()));
   for (unsigned i = 0, n = TL.getNumTypeArgs(); i != n; ++i)
     TRY_TO(TraverseTypeLoc(TL.getTypeArgTInfo(i)->getTypeLoc()));
+  for (unsigned I = 0, N = TL.getNumProtocols(); I != N; ++I) { // Protocol qualifiers are visited after the base type and type arguments.
+    ObjCProtocolLoc ProtocolLoc(TL.getProtocol(I), TL.getProtocolLoc(I));
+    TRY_TO(TraverseObjCProtocolLoc(ProtocolLoc));
+  }
 })
 
 DEF_TRAVERSE_TYPELOC(ObjCObjectPointerType,
@@ -1541,12 +1556,16 @@ DEF_TRAVERSE_DECL(
 DEF_TRAVERSE_DECL(ObjCCompatibleAliasDecl, {// FIXME: implement
                                            })
 
-DEF_TRAVERSE_DECL(ObjCCategoryDecl, {// FIXME: implement
+DEF_TRAVERSE_DECL(ObjCCategoryDecl, {
   if (ObjCTypeParamList *typeParamList = D->getTypeParamList()) {
     for (auto typeParam : *typeParamList) {
       TRY_TO(TraverseObjCTypeParamDecl(typeParam));
     }
   }
+  for (auto It : llvm::zip(D->protocols(), D->protocol_locs())) { // Pair each referenced protocol with its written location.
+    ObjCProtocolLoc ProtocolLoc(std::get<0>(It), std::get<1>(It));
+    TRY_TO(TraverseObjCProtocolLoc(ProtocolLoc));
+  }
 })
 
 DEF_TRAVERSE_DECL(ObjCCategoryImplDecl, {// FIXME: implement
@@ -1555,7 +1574,7 @@ DEF_TRAVERSE_DECL(ObjCCategoryImplDecl, {// FIXME: implement
 DEF_TRAVERSE_DECL(ObjCImplementationDecl, {// FIXME: implement
                                           })
 
-DEF_TRAVERSE_DECL(ObjCInterfaceDecl, {// FIXME: implement
+DEF_TRAVERSE_DECL(ObjCInterfaceDecl, {
   if (ObjCTypeParamList *typeParamList = D->getTypeParamListAsWritten()) {
     for (auto typeParam : *typeParamList) {
       TRY_TO(TraverseObjCTypeParamDecl(typeParam));
@@ -1565,10 +1584,22 @@ DEF_TRAVERSE_DECL(ObjCInterfaceDecl, {// FIXME: implement
   if (TypeSourceInfo *superTInfo = D->getSuperClassTInfo()) {
     TRY_TO(TraverseTypeLoc(superTInfo->getTypeLoc()));
   }
+  if (D->isThisDeclarationADefinition()) { // Protocols are visited only from the definition; presumably so forward declarations/redeclarations neither read nor revisit them — confirm.
+    for (auto It : llvm::zip(D->protocols(), D->protocol_locs())) {
+      ObjCProtocolLoc ProtocolLoc(std::get<0>(It), std::get<1>(It));
+      TRY_TO(TraverseObjCProtocolLoc(ProtocolLoc));
+    }
+  }
 })
 
-DEF_TRAVERSE_DECL(ObjCProtocolDecl, {// FIXME: implement
-                                    })
+DEF_TRAVERSE_DECL(ObjCProtocolDecl, {
+  if (D->isThisDeclarationADefinition()) { // Same definition-only guard as for ObjCInterfaceDecl.
+    for (auto It : llvm::zip(D->protocols(), D->protocol_locs())) {
+      ObjCProtocolLoc ProtocolLoc(std::get<0>(It), std::get<1>(It));
+      TRY_TO(TraverseObjCProtocolLoc(ProtocolLoc));
+    }
+  }
+})
 
 DEF_TRAVERSE_DECL(ObjCMethodDecl, {
   if (D->getReturnTypeSourceInfo()) {
@@ -2409,6 +2440,12 @@ bool RecursiveASTVisitor<Derived>::TraverseConceptReference(
   return true;
 }
 
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::TraverseObjCProtocolLoc(
+    ObjCProtocolLoc ProtocolLoc) {
+  return true; // Default: a protocol reference has no children to traverse; derived visitors override this hook.
+}
+
 // If shouldVisitImplicitCode() returns false, this method traverses only the
 // syntactic form of InitListExpr.
 // If shouldVisitImplicitCode() return true, this method is called once for

diff  --git a/clang/include/clang/AST/TypeLoc.h b/clang/include/clang/AST/TypeLoc.h
index 8cfa579a22da7..59dfa8a9a54d8 100644
--- a/clang/include/clang/AST/TypeLoc.h
+++ b/clang/include/clang/AST/TypeLoc.h
@@ -2607,6 +2607,22 @@ class DependentBitIntTypeLoc final
     : public InheritingConcreteTypeLoc<TypeSpecTypeLoc, DependentBitIntTypeLoc,
                                        DependentBitIntType> {};
 
+class ObjCProtocolLoc { // A reference to an ObjCProtocolDecl plus the location where it is named; TypeLoc-like, but not a TypeLoc.
+  ObjCProtocolDecl *Protocol = nullptr;
+  SourceLocation Loc = SourceLocation();
+
+public:
+  ObjCProtocolLoc(ObjCProtocolDecl *protocol, SourceLocation loc) // Only constructor: there is no default (null) ObjCProtocolLoc.
+      : Protocol(protocol), Loc(loc) {}
+  ObjCProtocolDecl *getProtocol() const { return Protocol; } // The referenced protocol declaration.
+  SourceLocation getLocation() const { return Loc; } // Location of the protocol name as written.
+
+  /// The source range is just the protocol name.
+  SourceRange getSourceRange() const LLVM_READONLY {
+    return SourceRange(Loc, Loc);
+  }
+};
+
 } // namespace clang
 
 #endif // LLVM_CLANG_AST_TYPELOC_H

diff  --git a/clang/lib/AST/ASTTypeTraits.cpp b/clang/lib/AST/ASTTypeTraits.cpp
index b333f4618efb8..64823f77e58a1 100644
--- a/clang/lib/AST/ASTTypeTraits.cpp
+++ b/clang/lib/AST/ASTTypeTraits.cpp
@@ -16,6 +16,7 @@
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/Attr.h"
 #include "clang/AST/DeclCXX.h"
+#include "clang/AST/DeclObjC.h" // Complete ObjCProtocolDecl type, needed for the print() call in DynTypedNode::print.
 #include "clang/AST/NestedNameSpecifier.h"
 #include "clang/AST/OpenMPClause.h"
 #include "clang/AST/TypeLoc.h"
@@ -52,6 +53,7 @@ const ASTNodeKind::KindInfo ASTNodeKind::AllKindInfo[] = {
     {NKI_None, "Attr"},
 #define ATTR(A) {NKI_Attr, #A "Attr"},
 #include "clang/Basic/AttrList.inc"
+    {NKI_None, "ObjCProtocolLoc"}, // Root kind (parent NKI_None); placed to mirror NKI_ObjCProtocolLoc's enum position, presumably because this table is indexed by kind ID.
 };
 
 bool ASTNodeKind::isBaseOf(ASTNodeKind Other, unsigned *Distance) const {
@@ -193,6 +195,8 @@ void DynTypedNode::print(llvm::raw_ostream &OS,
     QualType(T, 0).print(OS, PP);
   else if (const Attr *A = get<Attr>())
     A->printPretty(OS, PP);
+  else if (const ObjCProtocolLoc *P = get<ObjCProtocolLoc>())
+    P->getProtocol()->print(OS, PP); // NOTE(review): prints the referenced protocol declaration itself; confirm whether printing just the name is wanted for a reference node.
   else
     OS << "Unable to print values of type " << NodeKind.asStringRef() << "\n";
 }
@@ -228,5 +232,7 @@ SourceRange DynTypedNode::getSourceRange() const {
     return CBS->getSourceRange();
   if (const auto *A = get<Attr>())
     return A->getRange();
+  if (const ObjCProtocolLoc *P = get<ObjCProtocolLoc>())
+    return P->getSourceRange(); // A single-location range covering the protocol name.
   return SourceRange();
 }

diff  --git a/clang/lib/AST/ParentMapContext.cpp b/clang/lib/AST/ParentMapContext.cpp
index d216be5b59e89..e0d4700e4b10b 100644
--- a/clang/lib/AST/ParentMapContext.cpp
+++ b/clang/lib/AST/ParentMapContext.cpp
@@ -330,6 +330,9 @@ template <>
 DynTypedNode createDynTypedNode(const NestedNameSpecifierLoc &Node) {
   return DynTypedNode::create(Node);
 }
+template <> DynTypedNode createDynTypedNode(const ObjCProtocolLoc &Node) { // Used by TraverseNode when pushing the node onto ParentStack.
+  return DynTypedNode::create(Node);
+}
 /// @}
 
 /// A \c RecursiveASTVisitor that builds a map from nodes to their
@@ -398,11 +401,14 @@ class ParentMapContext::ParentMap::ASTVisitor
     }
   }
 
+  template <typename T> static bool isNull(T Node) { return !Node; } // Generic case: the node type must support operator!.
+  static bool isNull(ObjCProtocolLoc Node) { return false; } // ObjCProtocolLoc has no null state, so it is never skipped.
+
   template <typename T, typename MapNodeTy, typename BaseTraverseFn,
             typename MapTy>
   bool TraverseNode(T Node, MapNodeTy MapNode, BaseTraverseFn BaseTraverse,
                     MapTy *Parents) {
-    if (!Node)
+    if (isNull(Node))
       return true;
     addParent(MapNode, Parents);
     ParentStack.push_back(createDynTypedNode(Node));
@@ -433,6 +439,12 @@ class ParentMapContext::ParentMap::ASTVisitor
         AttrNode, AttrNode, [&] { return VisitorBase::TraverseAttr(AttrNode); },
         &Map.PointerParents);
   }
+  bool TraverseObjCProtocolLoc(ObjCProtocolLoc ProtocolLocNode) {
+    return TraverseNode(
+        ProtocolLocNode, DynTypedNode::create(ProtocolLocNode),
+        [&] { return VisitorBase::TraverseObjCProtocolLoc(ProtocolLocNode); },
+        &Map.OtherParents); // NOTE(review): the map key here is the DynTypedNode above, and ObjCProtocolLoc is value-stored without a unique pointer; confirm DynTypedNode hashing/equality support this kind (otherwise lookups may assert or collide).
+  }
 
   // Using generic TraverseNode for Stmt would prevent data-recursion.
   bool dataTraverseStmtPre(Stmt *StmtNode) {

diff  --git a/clang/unittests/AST/RecursiveASTVisitorTest.cpp b/clang/unittests/AST/RecursiveASTVisitorTest.cpp
index f44a5eca18728..9d7ff5947fe53 100644
--- a/clang/unittests/AST/RecursiveASTVisitorTest.cpp
+++ b/clang/unittests/AST/RecursiveASTVisitorTest.cpp
@@ -60,6 +60,12 @@ enum class VisitEvent {
   EndTraverseEnum,
   StartTraverseTypedefType,
   EndTraverseTypedefType,
+  StartTraverseObjCInterface,
+  EndTraverseObjCInterface,
+  StartTraverseObjCProtocol,
+  EndTraverseObjCProtocol,
+  StartTraverseObjCProtocolLoc,
+  EndTraverseObjCProtocolLoc,
 };
 
 class CollectInterestingEvents
@@ -97,18 +103,43 @@ class CollectInterestingEvents
     return Ret;
   }
 
+  bool TraverseObjCInterfaceDecl(ObjCInterfaceDecl *ID) { // Bracket the base traversal with start/end events.
+    Events.push_back(VisitEvent::StartTraverseObjCInterface);
+    bool Ret = RecursiveASTVisitor::TraverseObjCInterfaceDecl(ID);
+    Events.push_back(VisitEvent::EndTraverseObjCInterface);
+
+    return Ret;
+  }
+
+  bool TraverseObjCProtocolDecl(ObjCProtocolDecl *PD) {
+    Events.push_back(VisitEvent::StartTraverseObjCProtocol);
+    bool Ret = RecursiveASTVisitor::TraverseObjCProtocolDecl(PD);
+    Events.push_back(VisitEvent::EndTraverseObjCProtocol);
+
+    return Ret;
+  }
+
+  bool TraverseObjCProtocolLoc(ObjCProtocolLoc ProtocolLoc) {
+    Events.push_back(VisitEvent::StartTraverseObjCProtocolLoc);
+    bool Ret = RecursiveASTVisitor::TraverseObjCProtocolLoc(ProtocolLoc);
+    Events.push_back(VisitEvent::EndTraverseObjCProtocolLoc);
+
+    return Ret;
+  }
+
   std::vector<VisitEvent> takeEvents() && { return std::move(Events); }
 
 private:
   std::vector<VisitEvent> Events;
 };
 
-std::vector<VisitEvent> collectEvents(llvm::StringRef Code) {
+std::vector<VisitEvent> collectEvents(llvm::StringRef Code,
+                                      const Twine &FileName = "input.cc") { // File name passed to runToolOnCode; Objective-C tests pass "input.m".
   CollectInterestingEvents Visitor;
   clang::tooling::runToolOnCode(
       std::make_unique<ProcessASTAction>(
           [&](clang::ASTContext &Ctx) { Visitor.TraverseAST(Ctx); }),
-      Code);
+      Code, FileName);
   return std::move(Visitor).takeEvents();
 }
 } // namespace
@@ -139,3 +170,28 @@ TEST(RecursiveASTVisitorTest, EnumDeclWithBase) {
                           VisitEvent::EndTraverseTypedefType,
                           VisitEvent::EndTraverseEnum));
 }
+
+TEST(RecursiveASTVisitorTest, InterfaceDeclWithProtocols) {
+  // Check interface and its protocols are visited.
+  llvm::StringRef Code = R"cpp(
+  @protocol Foo
+  @end
+  @protocol Bar
+  @end
+
+  @interface SomeObject <Foo, Bar>
+  @end
+  )cpp";
+
+  EXPECT_THAT(collectEvents(Code, "input.m"), // ".m" file name so the snippet is handled as Objective-C.
+              ElementsAre(VisitEvent::StartTraverseObjCProtocol,
+                          VisitEvent::EndTraverseObjCProtocol,
+                          VisitEvent::StartTraverseObjCProtocol,
+                          VisitEvent::EndTraverseObjCProtocol,
+                          VisitEvent::StartTraverseObjCInterface,
+                          VisitEvent::StartTraverseObjCProtocolLoc, // The two <Foo, Bar> references are visited inside the interface traversal.
+                          VisitEvent::EndTraverseObjCProtocolLoc,
+                          VisitEvent::StartTraverseObjCProtocolLoc,
+                          VisitEvent::EndTraverseObjCProtocolLoc,
+                          VisitEvent::EndTraverseObjCInterface));
+}


        


More information about the cfe-commits mailing list