[clang-tools-extra] 749c6a7 - [include-cleaner] Ranking of providers based on hints

Kadir Cetinkaya via cfe-commits cfe-commits at lists.llvm.org
Mon Jan 23 06:22:59 PST 2023


Author: Kadir Cetinkaya
Date: 2023-01-23T15:22:47+01:00
New Revision: 749c6a708340f772f72e1d33594cdb51bb28e190

URL: https://github.com/llvm/llvm-project/commit/749c6a708340f772f72e1d33594cdb51bb28e190
DIFF: https://github.com/llvm/llvm-project/commit/749c6a708340f772f72e1d33594cdb51bb28e190.diff

LOG: [include-cleaner] Ranking of providers based on hints

Introduce signals to rank providers of a symbol.

Differential Revision: https://reviews.llvm.org/D139921

Added: 
    clang-tools-extra/include-cleaner/lib/TypesInternal.h

Modified: 
    clang-tools-extra/include-cleaner/include/clang-include-cleaner/Types.h
    clang-tools-extra/include-cleaner/lib/Analysis.cpp
    clang-tools-extra/include-cleaner/lib/AnalysisInternal.h
    clang-tools-extra/include-cleaner/lib/FindHeaders.cpp
    clang-tools-extra/include-cleaner/lib/HTMLReport.cpp
    clang-tools-extra/include-cleaner/lib/LocateSymbol.cpp
    clang-tools-extra/include-cleaner/lib/Types.cpp
    clang-tools-extra/include-cleaner/unittests/AnalysisTest.cpp
    clang-tools-extra/include-cleaner/unittests/FindHeadersTest.cpp
    clang-tools-extra/include-cleaner/unittests/LocateSymbolTest.cpp

Removed: 
    


################################################################################
diff  --git a/clang-tools-extra/include-cleaner/include/clang-include-cleaner/Types.h b/clang-tools-extra/include-cleaner/include/clang-include-cleaner/Types.h
index 102f5bc21a84c..05cb96ebec1ff 100644
--- a/clang-tools-extra/include-cleaner/include/clang-include-cleaner/Types.h
+++ b/clang-tools-extra/include-cleaner/include/clang-include-cleaner/Types.h
@@ -29,6 +29,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringMap.h"
 #include <memory>
+#include <utility>
 #include <vector>
 
 namespace llvm {
@@ -71,7 +72,11 @@ struct Symbol {
   // Order must match Kind enum!
   std::variant<const Decl *, struct Macro> Storage;
 
-  Symbol(decltype(Storage) Sentinel) : Storage(std::move(Sentinel)) {}
+  // Disambiguation tag to make sure we can call the right constructor from
+  // DenseMapInfo methods.
+  struct SentinelTag {};
+  Symbol(SentinelTag, decltype(Storage) Sentinel)
+      : Storage(std::move(Sentinel)) {}
   friend llvm::DenseMapInfo<Symbol>;
 };
 llvm::raw_ostream &operator<<(llvm::raw_ostream &, const Symbol &);
@@ -117,6 +122,7 @@ struct Header {
 
   Kind kind() const { return static_cast<Kind>(Storage.index()); }
   bool operator==(const Header &RHS) const { return Storage == RHS.Storage; }
+  bool operator<(const Header &RHS) const;
 
   const FileEntry *physical() const { return std::get<Physical>(Storage); }
   tooling::stdlib::Header standard() const {
@@ -127,6 +133,13 @@ struct Header {
 private:
   // Order must match Kind enum!
   std::variant<const FileEntry *, tooling::stdlib::Header, StringRef> Storage;
+
+  // Disambiguation tag to make sure we can call the right constructor from
+  // DenseMapInfo methods.
+  struct SentinelTag {};
+  Header(SentinelTag, decltype(Storage) Sentinel)
+      : Storage(std::move(Sentinel)) {}
+  friend llvm::DenseMapInfo<Header>;
 };
 llvm::raw_ostream &operator<<(llvm::raw_ostream &, const Header &);
 
@@ -178,8 +191,12 @@ template <> struct DenseMapInfo<clang::include_cleaner::Symbol> {
   using Outer = clang::include_cleaner::Symbol;
   using Base = DenseMapInfo<decltype(Outer::Storage)>;
 
-  static inline Outer getEmptyKey() { return {Base::getEmptyKey()}; }
-  static inline Outer getTombstoneKey() { return {Base::getTombstoneKey()}; }
+  static inline Outer getEmptyKey() {
+    return {Outer::SentinelTag{}, Base::getEmptyKey()};
+  }
+  static inline Outer getTombstoneKey() {
+    return {Outer::SentinelTag{}, Base::getTombstoneKey()};
+  }
   static unsigned getHashValue(const Outer &Val) {
     return Base::getHashValue(Val.Storage);
   }
@@ -202,6 +219,23 @@ template <> struct DenseMapInfo<clang::include_cleaner::Macro> {
     return Base::isEqual(LHS.Definition, RHS.Definition);
   }
 };
+template <> struct DenseMapInfo<clang::include_cleaner::Header> {
+  using Outer = clang::include_cleaner::Header;
+  using Base = DenseMapInfo<decltype(Outer::Storage)>;
+
+  // Sentinel keys reuse the sentinels of the wrapped variant; the tag selects
+  // Header's private sentinel constructor (accessible as we are a friend).
+  static inline Outer getEmptyKey() {
+    return Outer(Outer::SentinelTag(), Base::getEmptyKey());
+  }
+  static inline Outer getTombstoneKey() {
+    return Outer(Outer::SentinelTag(), Base::getTombstoneKey());
+  }
+  // Hashing and equality simply delegate to the wrapped variant.
+  static unsigned getHashValue(const Outer &Val) {
+    return Base::getHashValue(Val.Storage);
+  }
+  static bool isEqual(const Outer &LHS, const Outer &RHS) {
+    return Base::isEqual(LHS.Storage, RHS.Storage);
+  }
+};
 } // namespace llvm
 
 #endif

diff  --git a/clang-tools-extra/include-cleaner/lib/Analysis.cpp b/clang-tools-extra/include-cleaner/lib/Analysis.cpp
index 1c7ac33cfa080..c5559db57e14c 100644
--- a/clang-tools-extra/include-cleaner/lib/Analysis.cpp
+++ b/clang-tools-extra/include-cleaner/lib/Analysis.cpp
@@ -12,29 +12,19 @@
 #include "clang-include-cleaner/Types.h"
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/Decl.h"
+#include "clang/AST/DeclBase.h"
 #include "clang/Basic/SourceManager.h"
 #include "clang/Format/Format.h"
 #include "clang/Lex/HeaderSearch.h"
 #include "clang/Tooling/Core/Replacement.h"
-#include "clang/Tooling/Inclusions/HeaderIncludes.h"
 #include "clang/Tooling/Inclusions/StandardLibrary.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
 
 namespace clang::include_cleaner {
 
-namespace {
-// Gets all the providers for a symbol by tarversing each location.
-llvm::SmallVector<Header> headersForSymbol(const Symbol &S,
-                                           const SourceManager &SM,
-                                           const PragmaIncludes *PI) {
-  llvm::SmallVector<Header> Headers;
-  for (auto &Loc : locateSymbol(S))
-    Headers.append(findHeaders(Loc, SM, PI));
-  return Headers;
-}
-} // namespace
-
 void walkUsed(llvm::ArrayRef<Decl *> ASTRoots,
               llvm::ArrayRef<SymbolReference> MacroRefs,
               const PragmaIncludes *PI, const SourceManager &SM,
@@ -55,7 +45,7 @@ void walkUsed(llvm::ArrayRef<Decl *> ASTRoots,
     assert(MacroRef.Target.kind() == Symbol::Macro);
     if (!SM.isWrittenInMainFile(SM.getSpellingLoc(MacroRef.RefLocation)))
       continue;
-    CB(MacroRef, findHeaders(MacroRef.Target.macro().Definition, SM, PI));
+    CB(MacroRef, headersForSymbol(MacroRef.Target, SM, PI));
   }
 }
 

diff  --git a/clang-tools-extra/include-cleaner/lib/AnalysisInternal.h b/clang-tools-extra/include-cleaner/lib/AnalysisInternal.h
index 026bb8dd20501..acf462919344b 100644
--- a/clang-tools-extra/include-cleaner/lib/AnalysisInternal.h
+++ b/clang-tools-extra/include-cleaner/lib/AnalysisInternal.h
@@ -21,12 +21,10 @@
 #ifndef CLANG_INCLUDE_CLEANER_ANALYSISINTERNAL_H
 #define CLANG_INCLUDE_CLEANER_ANALYSISINTERNAL_H
 
+#include "TypesInternal.h"
 #include "clang-include-cleaner/Record.h"
 #include "clang-include-cleaner/Types.h"
-#include "clang/Basic/SourceLocation.h"
-#include "clang/Tooling/Inclusions/StandardLibrary.h"
 #include "llvm/ADT/STLFunctionalExtras.h"
-#include <variant>
 #include <vector>
 
 namespace clang {
@@ -34,6 +32,7 @@ class ASTContext;
 class Decl;
 class HeaderSearch;
 class NamedDecl;
+class SourceLocation;
 namespace include_cleaner {
 
 /// Traverses part of the AST from \p Root, finding uses of symbols.
@@ -51,40 +50,20 @@ namespace include_cleaner {
 void walkAST(Decl &Root,
              llvm::function_ref<void(SourceLocation, NamedDecl &, RefType)>);
 
-/// A place where a symbol can be provided.
-/// It is either a physical file of the TU (SourceLocation) or a logical
-/// location in the standard library (stdlib::Symbol).
-struct SymbolLocation {
-  enum Kind {
-    /// A position within a source file (or macro expansion) parsed by clang.
-    Physical,
-    /// A recognized standard library symbol, like std::string.
-    Standard,
-  };
-
-  SymbolLocation(SourceLocation S) : Storage(S) {}
-  SymbolLocation(tooling::stdlib::Symbol S) : Storage(S) {}
+/// Finds the headers that provide the symbol location.
+llvm::SmallVector<Hinted<Header>> findHeaders(const SymbolLocation &Loc,
+                                              const SourceManager &SM,
+                                              const PragmaIncludes *PI);
 
-  Kind kind() const { return static_cast<Kind>(Storage.index()); }
-  bool operator==(const SymbolLocation &RHS) const {
-    return Storage == RHS.Storage;
-  }
-  SourceLocation physical() const { return std::get<Physical>(Storage); }
-  tooling::stdlib::Symbol standard() const {
-    return std::get<Standard>(Storage);
-  }
+/// A set of locations that provides the declaration.
+std::vector<Hinted<SymbolLocation>> locateSymbol(const Symbol &S);
 
-private:
-  // Order must match Kind enum!
-  std::variant<SourceLocation, tooling::stdlib::Symbol> Storage;
-};
-llvm::raw_ostream &operator<<(llvm::raw_ostream &, const SymbolLocation &);
-
-/// Finds the headers that provide the symbol location.
-// FIXME: expose signals
-llvm::SmallVector<Header> findHeaders(const SymbolLocation &Loc,
-                                      const SourceManager &SM,
-                                      const PragmaIncludes *PI);
+/// Gets all the providers for a symbol by traversing each location.
+/// Returned headers are sorted by relevance, first element is the most
+/// likely provider for the symbol.
+llvm::SmallVector<Header> headersForSymbol(const Symbol &S,
+                                           const SourceManager &SM,
+                                           const PragmaIncludes *PI);
 
 /// Write an HTML summary of the analysis to the given stream.
 void writeHTMLReport(FileID File, const Includes &,
@@ -93,9 +72,6 @@ void writeHTMLReport(FileID File, const Includes &,
                      HeaderSearch &HS, PragmaIncludes *PI,
                      llvm::raw_ostream &OS);
 
-/// A set of locations that provides the declaration.
-std::vector<SymbolLocation> locateSymbol(const Symbol &S);
-
 } // namespace include_cleaner
 } // namespace clang
 

diff  --git a/clang-tools-extra/include-cleaner/lib/FindHeaders.cpp b/clang-tools-extra/include-cleaner/lib/FindHeaders.cpp
index fccba48056eed..030ce61701494 100644
--- a/clang-tools-extra/include-cleaner/lib/FindHeaders.cpp
+++ b/clang-tools-extra/include-cleaner/lib/FindHeaders.cpp
@@ -7,32 +7,110 @@
 //===----------------------------------------------------------------------===//
 
 #include "AnalysisInternal.h"
+#include "TypesInternal.h"
 #include "clang-include-cleaner/Record.h"
+#include "clang-include-cleaner/Types.h"
+#include "clang/AST/Decl.h"
+#include "clang/AST/DeclBase.h"
+#include "clang/Basic/FileEntry.h"
+#include "clang/Basic/SourceLocation.h"
 #include "clang/Basic/SourceManager.h"
+#include "clang/Tooling/Inclusions/StandardLibrary.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/ErrorHandling.h"
+#include <string>
+#include <utility>
 
 namespace clang::include_cleaner {
+namespace {
+// Merges \p H into the hints of every header in \p Headers, e.g. to carry the
+// hints of a symbol location over to all headers providing that location.
+llvm::SmallVector<Hinted<Header>>
+applyHints(llvm::SmallVector<Hinted<Header>> Headers, Hints H) {
+  for (Hinted<Header> &Provider : Headers)
+    Provider.Hint = Provider.Hint | H;
+  return Headers;
+}
+
+// Orders \p Headers from most to least relevant hint and strips the hints.
+// The sort is stable, so headers with equal hints keep their relative order.
+llvm::SmallVector<Header> ranked(llvm::SmallVector<Hinted<Header>> Headers) {
+  auto MoreRelevant = [](const Hinted<Header> &LHS,
+                         const Hinted<Header> &RHS) { return RHS < LHS; };
+  llvm::stable_sort(Headers, MoreRelevant);
+  return llvm::SmallVector<Header>(Headers.begin(), Headers.end());
+}
+
+// Return the basename of a verbatim header spelling or a file path, leaving
+// only the file name without any extension:
+//   <foo/bar.h>   -> bar
+//   "foo.cu.h"    -> foo
+//   C:\dir\baz.h  -> baz
+// Both '/' and '\' are treated as directory separators, since physical file
+// names may be spelled with native Windows separators.
+llvm::StringRef basename(llvm::StringRef Header) {
+  Header = Header.trim("<>\"");
+  if (auto LastSep = Header.find_last_of("/\\"); LastSep != Header.npos)
+    Header = Header.drop_front(LastSep + 1);
+  // Drop everything after the first `.` (dot).
+  // foo.h -> foo
+  // foo.cu.h -> foo
+  Header = Header.substr(0, Header.find('.'));
+  return Header;
+}
 
-llvm::SmallVector<Header> findHeaders(const SymbolLocation &Loc,
-                                      const SourceManager &SM,
-                                      const PragmaIncludes *PI) {
-  llvm::SmallVector<Header> Results;
+// Check if spelling of \p H matches \p DeclName (case-insensitively, comparing
+// only the extension-less file name).
+// An empty \p DeclName (unnamed decls such as operators or anonymous structs)
+// never matches; otherwise any header whose basename is empty, e.g. a dot-file
+// like "/dir/.inc", would be reported as a name match for every such symbol.
+bool nameMatch(llvm::StringRef DeclName, Header H) {
+  if (DeclName.empty())
+    return false;
+  switch (H.kind()) {
+  case Header::Physical:
+    return basename(H.physical()->getName()).equals_insensitive(DeclName);
+  case Header::Standard:
+    return basename(H.standard().name()).equals_insensitive(DeclName);
+  case Header::Verbatim:
+    return basename(H.verbatim()).equals_insensitive(DeclName);
+  }
+  llvm_unreachable("unhandled Header kind!");
+}
+
+// Returns the plain identifier naming \p S, or an empty string when the
+// symbol has none.
+llvm::StringRef symbolName(const Symbol &S) {
+  switch (S.kind()) {
+  case Symbol::Macro:
+    return S.macro().Name->getName();
+  case Symbol::Declaration: {
+    // Unnamed decls like operators and anonymous structs won't get any name
+    // match.
+    const auto *ND = llvm::dyn_cast<NamedDecl>(&S.declaration());
+    const auto *II = ND ? ND->getIdentifier() : nullptr;
+    return II ? II->getName() : llvm::StringRef("");
+  }
+  }
+  llvm_unreachable("unhandled Symbol kind!");
+}
+
+} // namespace
+
+llvm::SmallVector<Hinted<Header>> findHeaders(const SymbolLocation &Loc,
+                                              const SourceManager &SM,
+                                              const PragmaIncludes *PI) {
+  auto IsPublicHeader = [&PI](const FileEntry *FE) {
+    return (PI->isPrivate(FE) || !PI->isSelfContained(FE))
+               ? Hints::None
+               : Hints::PublicHeader;
+  };
+  llvm::SmallVector<Hinted<Header>> Results;
   switch (Loc.kind()) {
   case SymbolLocation::Physical: {
     FileID FID = SM.getFileID(SM.getExpansionLoc(Loc.physical()));
     const FileEntry *FE = SM.getFileEntryForID(FID);
-    if (!PI) {
-      return FE ? llvm::SmallVector<Header>{Header(FE)}
-                : llvm::SmallVector<Header>();
-    }
+    if (!FE)
+      return {};
+    if (!PI)
+      return {{FE, Hints::PublicHeader}};
     while (FE) {
-      Results.push_back(Header(FE));
+      Hints CurrentHints = IsPublicHeader(FE);
+      Results.emplace_back(FE, CurrentHints);
       // FIXME: compute transitive exporter headers.
       for (const auto *Export : PI->getExporters(FE, SM.getFileManager()))
-        Results.push_back(Header(Export));
+        Results.emplace_back(Export, IsPublicHeader(Export));
 
-      llvm::StringRef VerbatimSpelling = PI->getPublic(FE);
-      if (!VerbatimSpelling.empty()) {
-        Results.push_back(VerbatimSpelling);
+      if (auto Verbatim = PI->getPublic(FE); !Verbatim.empty()) {
+        Results.emplace_back(Verbatim,
+                             Hints::PublicHeader | Hints::PreferredHeader);
         break;
       }
       if (PI->isSelfContained(FE) || FID == SM.getMainFileID())
@@ -46,14 +124,54 @@ llvm::SmallVector<Header> findHeaders(const SymbolLocation &Loc,
   }
   case SymbolLocation::Standard: {
     for (const auto &H : Loc.standard().headers()) {
-      Results.push_back(H);
+      Results.emplace_back(H, Hints::PublicHeader);
       for (const auto *Export : PI->getExporters(H, SM.getFileManager()))
-        Results.push_back(Header(Export));
+        Results.emplace_back(Header(Export), IsPublicHeader(Export));
     }
+    // StandardLibrary returns headers in preference order, so only mark the
+    // first.
+    if (!Results.empty())
+      Results.front().Hint |= Hints::PreferredHeader;
     return Results;
   }
   }
   llvm_unreachable("unhandled SymbolLocation kind!");
 }
 
+// Collects the providers of \p S from every location returned by
+// locateSymbol(), merges duplicate headers, adds name-match hints and returns
+// the headers in the order produced by ranked().
+llvm::SmallVector<Header> headersForSymbol(const Symbol &S,
+                                           const SourceManager &SM,
+                                           const PragmaIncludes *PI) {
+  // Get headers for all the locations providing Symbol. Same header can be
+  // reached through different traversals, deduplicate those into a single
+  // Header by merging their hints.
+  llvm::SmallVector<Hinted<Header>> Headers;
+  for (auto &Loc : locateSymbol(S))
+    Headers.append(applyHints(findHeaders(Loc, SM, PI), Loc.Hint));
+  // If two Headers probably refer to the same file (e.g. Verbatim(foo.h) and
+  // Physical(/path/to/foo.h)), we won't deduplicate them or merge their hints.
+  // Sorting by Header brings equal headers next to each other (assuming
+  // distinct physical headers have distinct names, since Header::operator<
+  // compares physical headers by name) so they can be merged in one pass.
+  llvm::stable_sort(
+      Headers, [](const Hinted<Header> &LHS, const Hinted<Header> &RHS) {
+        return static_cast<Header>(LHS) < static_cast<Header>(RHS);
+      });
+  // Compact in place: Write is the slot of the current unique header, Read
+  // scans ahead and ORs the hints of every equal neighbour into it.
+  auto *Write = Headers.begin();
+  for (auto *Read = Headers.begin(); Read != Headers.end(); ++Write) {
+    *Write = *Read++;
+    while (Read != Headers.end() &&
+           static_cast<Header>(*Write) == static_cast<Header>(*Read)) {
+      Write->Hint |= Read->Hint;
+      ++Read;
+    }
+  }
+  // Everything from Write onwards is a merged-away duplicate.
+  Headers.erase(Write, Headers.end());
+
+  // Add name match hints to deduplicated providers.
+  llvm::StringRef SymbolName = symbolName(S);
+  for (auto &H : Headers) {
+    if (nameMatch(SymbolName, H))
+      H.Hint |= Hints::PreferredHeader;
+  }
+
+  // FIXME: Introduce a MainFile header kind or signal and boost it.
+  return ranked(std::move(Headers));
+}
 } // namespace clang::include_cleaner

diff  --git a/clang-tools-extra/include-cleaner/lib/HTMLReport.cpp b/clang-tools-extra/include-cleaner/lib/HTMLReport.cpp
index be78d0b4aa778..c1d1982d4f487 100644
--- a/clang-tools-extra/include-cleaner/lib/HTMLReport.cpp
+++ b/clang-tools-extra/include-cleaner/lib/HTMLReport.cpp
@@ -187,8 +187,7 @@ class Reporter {
     // Duplicates logic from walkUsed(), which doesn't expose SymbolLocations.
     for (auto &Loc : locateSymbol(R.Sym))
       R.Locations.push_back(Loc);
-    for (const auto &Loc : R.Locations)
-      R.Headers.append(findHeaders(Loc, SM, PI));
+    R.Headers = headersForSymbol(R.Sym, SM, PI);
 
     for (const auto &H : R.Headers) {
       R.Includes.append(Includes.match(H));
@@ -205,7 +204,6 @@ class Reporter {
                      R.Includes.end());
 
     if (!R.Headers.empty())
-      // FIXME: library should tell us which header to use.
       R.Insert = spellHeader(R.Headers.front());
   }
 

diff  --git a/clang-tools-extra/include-cleaner/lib/LocateSymbol.cpp b/clang-tools-extra/include-cleaner/lib/LocateSymbol.cpp
index c8b6cc77eb148..60b18c2fe94fb 100644
--- a/clang-tools-extra/include-cleaner/lib/LocateSymbol.cpp
+++ b/clang-tools-extra/include-cleaner/lib/LocateSymbol.cpp
@@ -7,51 +7,60 @@
 //===----------------------------------------------------------------------===//
 
 #include "AnalysisInternal.h"
+#include "clang/AST/Decl.h"
 #include "clang/AST/DeclBase.h"
-#include "clang/Basic/SourceLocation.h"
+#include "clang/AST/DeclCXX.h"
+#include "clang/AST/DeclTemplate.h"
 #include "clang/Tooling/Inclusions/StandardLibrary.h"
-#include "llvm/ADT/StringExtras.h"
-#include "llvm/Support/raw_ostream.h"
+#include "llvm/Support/Casting.h"
 #include <utility>
 #include <vector>
 
 namespace clang::include_cleaner {
 namespace {
 
-std::vector<SymbolLocation> locateDecl(const Decl &D) {
-  std::vector<SymbolLocation> Result;
+// Yields CompleteSymbol only when \p D is itself the definition.
+template <typename T> Hints completeIfDefinition(T *D) {
+  if (!D->isThisDeclarationADefinition())
+    return Hints::None;
+  return Hints::CompleteSymbol;
+}
+
+// Computes the hints contributed by the declaration \p D itself.
+Hints declHints(const Decl *D) {
+  // Definition is only needed for classes and templates for completeness.
+  if (const auto *Tag = llvm::dyn_cast<TagDecl>(D))
+    return completeIfDefinition(Tag);
+  if (const auto *ClassTemplate = llvm::dyn_cast<ClassTemplateDecl>(D))
+    return completeIfDefinition(ClassTemplate);
+  if (const auto *FuncTemplate = llvm::dyn_cast<FunctionTemplateDecl>(D))
+    return completeIfDefinition(FuncTemplate);
+  // Any other declaration is assumed usable.
+  return Hints::CompleteSymbol;
+}
+
+std::vector<Hinted<SymbolLocation>> locateDecl(const Decl &D) {
+  std::vector<Hinted<SymbolLocation>> Result;
   // FIXME: Should we also provide physical locations?
   if (auto SS = tooling::stdlib::Recognizer()(&D))
-    return {SymbolLocation(*SS)};
+    return {{*SS, Hints::CompleteSymbol}};
+  // FIXME: Signal foreign decls, e.g. a forward declaration not owned by a
+  // library. Some useful signals could be derived by checking the DeclContext.
+  // Most incidental forward decls look like:
+  //   namespace clang {
+  //   class SourceManager; // likely an incidental forward decl.
+  //   namespace my_own_ns {}
+  //   }
   for (auto *Redecl : D.redecls())
-    Result.push_back(Redecl->getLocation());
+    Result.push_back({Redecl->getLocation(), declHints(Redecl)});
   return Result;
 }
 
 } // namespace
 
-llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const SymbolLocation &S) {
-  switch (S.kind()) {
-  case SymbolLocation::Physical:
-    // We can't decode the Location without SourceManager. Its raw
-    // representation isn't completely useless (and distinguishes
-    // SymbolReference from Symbol).
-    return OS << "@0x"
-              << llvm::utohexstr(
-                     S.physical().getRawEncoding(), /*LowerCase=*/false,
-                     /*Width=*/CHAR_BIT * sizeof(SourceLocation::UIntTy));
-  case SymbolLocation::Standard:
-    return OS << S.standard().scope() << S.standard().name();
-  }
-  llvm_unreachable("Unhandled Symbol kind");
-}
-
-std::vector<SymbolLocation> locateSymbol(const Symbol &S) {
+std::vector<Hinted<SymbolLocation>> locateSymbol(const Symbol &S) {
   switch (S.kind()) {
   case Symbol::Declaration:
     return locateDecl(S.declaration());
   case Symbol::Macro:
-    return {SymbolLocation(S.macro().Definition)};
+    return {{S.macro().Definition, Hints::CompleteSymbol}};
   }
   llvm_unreachable("Unknown Symbol::Kind enum");
 }

diff  --git a/clang-tools-extra/include-cleaner/lib/Types.cpp b/clang-tools-extra/include-cleaner/lib/Types.cpp
index 68311546f724c..d7b040717a39e 100644
--- a/clang-tools-extra/include-cleaner/lib/Types.cpp
+++ b/clang-tools-extra/include-cleaner/lib/Types.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "clang-include-cleaner/Types.h"
+#include "TypesInternal.h"
 #include "clang/AST/Decl.h"
 #include "clang/Basic/FileEntry.h"
 #include "llvm/ADT/StringExtras.h"
@@ -106,4 +107,32 @@ llvm::SmallVector<const Include *> Includes::match(Header H) const {
   return Result;
 }
 
+llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const SymbolLocation &S) {
+  switch (S.kind()) {
+  case SymbolLocation::Physical:
+    // We can't decode the Location without SourceManager. Its raw
+    // representation isn't completely useless (and distinguishes
+    // SymbolReference from Symbol).
+    // utohexstr's Width is a number of hex digits: two per byte of the raw
+    // encoding. (CHAR_BIT * sizeof would be the bit count, 4x too wide.)
+    return OS << "@0x"
+              << llvm::utohexstr(
+                     S.physical().getRawEncoding(), /*LowerCase=*/false,
+                     /*Width=*/2 * sizeof(SourceLocation::UIntTy));
+  case SymbolLocation::Standard:
+    return OS << S.standard().scope() << S.standard().name();
+  }
+  llvm_unreachable("Unhandled Symbol kind");
+}
+
+// Orders headers by kind first, then by spelling within a kind. Physical
+// headers are compared by file name, not by FileEntry identity.
+bool Header::operator<(const Header &RHS) const {
+  if (kind() != RHS.kind())
+    return kind() < RHS.kind();
+  switch (kind()) {
+  case Header::Physical:
+    return physical()->getName() < RHS.physical()->getName();
+  case Header::Standard:
+    return standard().name() < RHS.standard().name();
+  case Header::Verbatim:
+    return verbatim() < RHS.verbatim();
+  }
+  // All kinds are handled above; never fall off the end of a non-void
+  // function.
+  llvm_unreachable("unhandled Header kind!");
+}
 } // namespace clang::include_cleaner

diff  --git a/clang-tools-extra/include-cleaner/lib/TypesInternal.h b/clang-tools-extra/include-cleaner/lib/TypesInternal.h
new file mode 100644
index 0000000000000..09a3933577a94
--- /dev/null
+++ b/clang-tools-extra/include-cleaner/lib/TypesInternal.h
@@ -0,0 +1,93 @@
+//===--- TypesInternal.h - Intermediate structures used for analysis C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef CLANG_INCLUDE_CLEANER_TYPESINTERNAL_H
+#define CLANG_INCLUDE_CLEANER_TYPESINTERNAL_H
+
+#include "clang/Basic/SourceLocation.h"
+#include "clang/Tooling/Inclusions/StandardLibrary.h"
+#include "llvm/ADT/BitmaskEnum.h"
+#include <cstdint>
+#include <utility>
+#include <variant>
+
+namespace llvm {
+class raw_ostream;
+}
+namespace clang::include_cleaner {
+/// A place where a symbol can be provided.
+/// It is either a physical location in the TU (SourceLocation) or a logical
+/// location in the standard library (stdlib::Symbol).
+struct SymbolLocation {
+  enum Kind {
+    /// A position within a source file (or macro expansion) parsed by clang.
+    Physical,
+    /// A recognized standard library symbol, like std::string.
+    Standard,
+  };
+
+  // Non-explicit, so both location kinds convert implicitly, e.g. when
+  // building a Hinted<SymbolLocation> from {Loc, Hints}.
+  SymbolLocation(SourceLocation S) : Storage(S) {}
+  SymbolLocation(tooling::stdlib::Symbol S) : Storage(S) {}
+
+  /// The active alternative; relies on the variant order matching Kind.
+  Kind kind() const { return static_cast<Kind>(Storage.index()); }
+  bool operator==(const SymbolLocation &RHS) const {
+    return Storage == RHS.Storage;
+  }
+  /// Accessors for the stored location; calling the one that does not match
+  /// kind() is an error (std::get fails).
+  SourceLocation physical() const { return std::get<Physical>(Storage); }
+  tooling::stdlib::Symbol standard() const {
+    return std::get<Standard>(Storage);
+  }
+
+private:
+  // Order must match Kind enum!
+  std::variant<SourceLocation, tooling::stdlib::Symbol> Storage;
+};
+llvm::raw_ostream &operator<<(llvm::raw_ostream &, const SymbolLocation &);
+
+/// Represents properties of a symbol provider.
+///
+/// Hints represent the properties of the edges traversed when finding headers
+/// that satisfy an AST node (AST node => symbols => locations => headers).
+///
+/// Since there can be multiple paths from an AST node to the same header, we
+/// need to merge hints. Hints are merged by taking the union of the properties
+/// along all the paths. We choose the boolean sense accordingly, e.g. "Public"
+/// rather than "Private", because a header is good if it provides any public
+/// definition, even if it also provides private ones.
+///
+/// Hints are sorted in ascending order of relevance: rankings compare the
+/// underlying integer values, so the most significant differing hint decides
+/// the order (e.g. PreferredHeader alone outranks CompleteSymbol|PublicHeader).
+enum class Hints : uint8_t {
+  None = 0x00,
+  /// Provides a generally usable declaration of the symbol, e.g. a function
+  /// declaration or a class definition, as opposed to a forward declaration
+  /// of a class or template.
+  CompleteSymbol = 1 << 0,
+  /// Symbol is provided by a public file. Absent only when the file is not
+  /// self-contained or is explicitly marked with an IWYU private pragma.
+  PublicHeader = 1 << 1,
+  /// Header is a preferred provider of the symbol, e.g. it is the header named
+  /// by an IWYU `private, include "..."` pragma, the first standard library
+  /// header listed for the symbol, or its file name matches the symbol name.
+  PreferredHeader = 1 << 2,
+  LLVM_MARK_AS_BITMASK_ENUM(PreferredHeader),
+};
+LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
+/// A wrapper to augment values with hints.
+template <typename T> struct Hinted : public T {
+  Hints Hint;
+  /// Takes \p Wrapped by value so lvalues, rvalues and values implicitly
+  /// convertible to T can all be wrapped; it is then moved into the base.
+  Hinted(T Wrapped, Hints H) : T(std::move(Wrapped)), Hint(H) {}
+
+  /// Since hints are sorted by relevance, compare them directly.
+  bool operator<(const Hinted<T> &Other) const {
+    return static_cast<int>(Hint) < static_cast<int>(Other.Hint);
+  }
+};
+
+} // namespace clang::include_cleaner
+
+#endif

diff  --git a/clang-tools-extra/include-cleaner/unittests/AnalysisTest.cpp b/clang-tools-extra/include-cleaner/unittests/AnalysisTest.cpp
index 04ff428184667..967ae829e778d 100644
--- a/clang-tools-extra/include-cleaner/unittests/AnalysisTest.cpp
+++ b/clang-tools-extra/include-cleaner/unittests/AnalysisTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "clang-include-cleaner/Analysis.h"
+#include "AnalysisInternal.h"
 #include "clang-include-cleaner/Record.h"
 #include "clang-include-cleaner/Types.h"
 #include "clang/AST/ASTContext.h"
@@ -365,7 +366,7 @@ TEST(WalkUsed, FilterRefsNotSpelledInMainFile) {
     FileID MainFID = SM.getMainFileID();
     if (RefLoc.isValid()) {
       EXPECT_THAT(RefLoc, AllOf(expandedAt(MainFID, Main.point("expand"), &SM),
-                                 spelledAt(MainFID, Main.point("spell"), &SM)))
+                                spelledAt(MainFID, Main.point("spell"), &SM)))
           << T.Main;
     } else {
       EXPECT_THAT(Main.points(), testing::IsEmpty());
@@ -373,5 +374,17 @@ TEST(WalkUsed, FilterRefsNotSpelledInMainFile) {
   }
 }
 
+TEST(Hints, Ordering) {
+  struct Tag {};
+  // Wraps an empty payload so only the hints take part in the comparison.
+  auto Make = [](Hints H) { return Hinted<Tag>({}, H); };
+  // Each single hint outranks the ones declared before it.
+  EXPECT_LT(Make(Hints::None), Make(Hints::CompleteSymbol));
+  EXPECT_LT(Make(Hints::CompleteSymbol), Make(Hints::PublicHeader));
+  EXPECT_LT(Make(Hints::PublicHeader), Make(Hints::PreferredHeader));
+  // A more relevant hint also beats a combination of less relevant ones.
+  EXPECT_LT(Make(Hints::CompleteSymbol | Hints::PublicHeader),
+            Make(Hints::PreferredHeader));
+}
+
 } // namespace
 } // namespace clang::include_cleaner

diff  --git a/clang-tools-extra/include-cleaner/unittests/FindHeadersTest.cpp b/clang-tools-extra/include-cleaner/unittests/FindHeadersTest.cpp
index d47b438ce5f5e..f414ff314418d 100644
--- a/clang-tools-extra/include-cleaner/unittests/FindHeadersTest.cpp
+++ b/clang-tools-extra/include-cleaner/unittests/FindHeadersTest.cpp
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "AnalysisInternal.h"
-#include "clang-include-cleaner/Analysis.h"
+#include "TypesInternal.h"
 #include "clang-include-cleaner/Record.h"
 #include "clang-include-cleaner/Types.h"
 #include "clang/AST/RecursiveASTVisitor.h"
@@ -16,15 +16,14 @@
 #include "clang/Frontend/FrontendActions.h"
 #include "clang/Testing/TestAST.h"
 #include "clang/Tooling/Inclusions/StandardLibrary.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/Support/raw_ostream.h"
-#include "llvm/Testing/Annotations/Annotations.h"
+#include "llvm/ADT/SmallVector.h"
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
 #include <memory>
 
 namespace clang::include_cleaner {
 namespace {
+using testing::ElementsAre;
 using testing::UnorderedElementsAre;
 
 std::string guard(llvm::StringRef Code) {
@@ -53,7 +52,7 @@ class FindHeadersTest : public testing::Test {
   }
   void buildAST() { AST = std::make_unique<TestAST>(Inputs); }
 
-  llvm::SmallVector<Header> findHeaders(llvm::StringRef FileName) {
+  llvm::SmallVector<Hinted<Header>> findHeaders(llvm::StringRef FileName) {
     return include_cleaner::findHeaders(
         AST->sourceManager().translateFileLineCol(
             AST->fileManager().getFile(FileName).get(),
@@ -225,12 +224,211 @@ TEST_F(FindHeadersTest, TargetIsExpandedFromMacroInHeader) {
     CustomVisitor Visitor;
     Visitor.TraverseDecl(AST->context().getTranslationUnitDecl());
 
-    llvm::SmallVector<Header> Headers = clang::include_cleaner::findHeaders(
+    auto Headers = clang::include_cleaner::findHeaders(
         Visitor.Out->getLocation(), AST->sourceManager(),
         /*PragmaIncludes=*/nullptr);
     EXPECT_THAT(Headers, UnorderedElementsAre(physicalHeader("declare.h")));
   }
 }
 
+MATCHER_P2(HintedHeader, Header, Hint, "") {
+  return std::tie(arg.Hint, arg) == std::tie(Hint, Header);
+}
+
+TEST_F(FindHeadersTest, PublicHeaderHint) {
+  Inputs.Code = R"cpp(
+    #include "public.h"
+  )cpp";
+  Inputs.ExtraFiles["public.h"] = guard(R"cpp(
+    #include "private.h"
+    #include "private.inc"
+  )cpp");
+  Inputs.ExtraFiles["private.h"] = guard(R"cpp(
+    // IWYU pragma: private
+  )cpp");
+  Inputs.ExtraFiles["private.inc"] = "";
+  buildAST();
+  // Non self-contained files and headers marked with IWYU private pragma
+  // shouldn't have PublicHeader hint.
+  EXPECT_THAT(
+      findHeaders("private.inc"),
+      UnorderedElementsAre(
+          HintedHeader(physicalHeader("private.inc"), Hints::None),
+          HintedHeader(physicalHeader("public.h"), Hints::PublicHeader)));
+  EXPECT_THAT(findHeaders("private.h"),
+              UnorderedElementsAre(
+                  HintedHeader(physicalHeader("private.h"), Hints::None)));
+}
+
+TEST_F(FindHeadersTest, PreferredHeaderHint) {
+  Inputs.Code = R"cpp(
+    #include "private.h"
+  )cpp";
+  Inputs.ExtraFiles["private.h"] = guard(R"cpp(
+    // IWYU pragma: private, include "public.h"
+  )cpp");
+  buildAST();
+  // Headers explicitly marked should've preferred signal.
+  EXPECT_THAT(findHeaders("private.h"),
+              UnorderedElementsAre(
+                  HintedHeader(physicalHeader("private.h"), Hints::None),
+                  HintedHeader(Header("\"public.h\""),
+                               Hints::PreferredHeader | Hints::PublicHeader)));
+}
+
+class HeadersForSymbolTest : public FindHeadersTest {
+protected:
+  llvm::SmallVector<Header> headersForFoo() {
+    struct Visitor : public RecursiveASTVisitor<Visitor> {
+      const NamedDecl *Out = nullptr;
+      bool VisitNamedDecl(const NamedDecl *ND) {
+        if (ND->getName() == "foo") {
+          EXPECT_TRUE(Out == nullptr || Out == ND->getCanonicalDecl())
+              << "Found multiple matches for foo.";
+          Out = cast<NamedDecl>(ND->getCanonicalDecl());
+        }
+        return true;
+      }
+    };
+    Visitor V;
+    V.TraverseDecl(AST->context().getTranslationUnitDecl());
+    if (!V.Out)
+      ADD_FAILURE() << "Couldn't find any decls named foo.";
+    assert(V.Out);
+    return headersForSymbol(*V.Out, AST->sourceManager(), &PI);
+  }
+};
+
+TEST_F(HeadersForSymbolTest, Deduplicates) {
+  Inputs.Code = R"cpp(
+    #include "foo.h"
+  )cpp";
+  Inputs.ExtraFiles["foo.h"] = guard(R"cpp(
+    // IWYU pragma: private, include "foo.h"
+    void foo();
+    void foo();
+  )cpp");
+  buildAST();
+  EXPECT_THAT(
+      headersForFoo(),
+      UnorderedElementsAre(physicalHeader("foo.h"),
+                           // FIXME: de-duplicate across different kinds.
+                           Header("\"foo.h\"")));
+}
+
+TEST_F(HeadersForSymbolTest, RankByName) {
+  Inputs.Code = R"cpp(
+    #include "fox.h"
+    #include "bar.h"
+  )cpp";
+  Inputs.ExtraFiles["fox.h"] = guard(R"cpp(
+    void foo();
+  )cpp");
+  Inputs.ExtraFiles["bar.h"] = guard(R"cpp(
+    void foo();
+  )cpp");
+  buildAST();
+  EXPECT_THAT(headersForFoo(),
+              ElementsAre(physicalHeader("bar.h"), physicalHeader("fox.h")));
+}
+
+TEST_F(HeadersForSymbolTest, Ranking) {
+  // Sorting is done over (canonical, public, complete) triplet.
+  Inputs.Code = R"cpp(
+    #include "private.h"
+    #include "public.h"
+    #include "public_complete.h"
+  )cpp";
+  Inputs.ExtraFiles["public.h"] = guard(R"cpp(
+    struct foo;
+  )cpp");
+  Inputs.ExtraFiles["private.h"] = guard(R"cpp(
+    // IWYU pragma: private, include "canonical.h"
+    struct foo;
+  )cpp");
+  Inputs.ExtraFiles["public_complete.h"] = guard("struct foo {};");
+  buildAST();
+  EXPECT_THAT(headersForFoo(), ElementsAre(Header("\"canonical.h\""),
+                                           physicalHeader("public_complete.h"),
+                                           physicalHeader("public.h"),
+                                           physicalHeader("private.h")));
+}
+
+TEST_F(HeadersForSymbolTest, PreferPublicOverComplete) {
+  Inputs.Code = R"cpp(
+    #include "complete_private.h"
+    #include "public.h"
+  )cpp";
+  Inputs.ExtraFiles["complete_private.h"] = guard(R"cpp(
+    // IWYU pragma: private
+    struct foo {};
+  )cpp");
+  Inputs.ExtraFiles["public.h"] = guard("struct foo;");
+  buildAST();
+  EXPECT_THAT(headersForFoo(),
+              ElementsAre(physicalHeader("public.h"),
+                          physicalHeader("complete_private.h")));
+}
+
+TEST_F(HeadersForSymbolTest, PreferNameMatch) {
+  Inputs.Code = R"cpp(
+    #include "public_complete.h"
+    #include "test/foo.fwd.h"
+  )cpp";
+  Inputs.ExtraFiles["public_complete.h"] = guard(R"cpp(
+    struct foo {};
+  )cpp");
+  Inputs.ExtraFiles["test/foo.fwd.h"] = guard("struct foo;");
+  buildAST();
+  EXPECT_THAT(headersForFoo(),
+              ElementsAre(physicalHeader("test/foo.fwd.h"),
+                          physicalHeader("public_complete.h")));
+}
+
+TEST_F(HeadersForSymbolTest, MainFile) {
+  Inputs.Code = R"cpp(
+    #include "public_complete.h"
+    struct foo;
+  )cpp";
+  Inputs.ExtraFiles["public_complete.h"] = guard(R"cpp(
+    struct foo {};
+  )cpp");
+  buildAST();
+  auto &SM = AST->sourceManager();
+  // FIXME: Symbols provided by main file should be treated specially.
+  EXPECT_THAT(headersForFoo(),
+              ElementsAre(physicalHeader("public_complete.h"),
+                          Header(SM.getFileEntryForID(SM.getMainFileID()))));
+}
+
+TEST_F(HeadersForSymbolTest, PreferExporterOfPrivate) {
+  Inputs.Code = R"cpp(
+    #include "private.h"
+    #include "exporter.h"
+  )cpp";
+  Inputs.ExtraFiles["private.h"] = guard(R"cpp(
+    // IWYU pragma: private
+    struct foo {};
+  )cpp");
+  Inputs.ExtraFiles["exporter.h"] = guard(R"cpp(
+    #include "private.h" // IWYU pragma: export
+  )cpp");
+  buildAST();
+  EXPECT_THAT(headersForFoo(), ElementsAre(physicalHeader("exporter.h"),
+                                           physicalHeader("private.h")));
+}
+
+TEST_F(HeadersForSymbolTest, PreferPublicOverNameMatchOnPrivate) {
+  Inputs.Code = R"cpp(
+    #include "foo.h"
+  )cpp";
+  Inputs.ExtraFiles["foo.h"] = guard(R"cpp(
+    // IWYU pragma: private, include "public.h"
+    struct foo {};
+  )cpp");
+  buildAST();
+  EXPECT_THAT(headersForFoo(), ElementsAre(Header(StringRef("\"public.h\"")),
+                                           physicalHeader("foo.h")));
+}
 } // namespace
 } // namespace clang::include_cleaner

diff  --git a/clang-tools-extra/include-cleaner/unittests/LocateSymbolTest.cpp b/clang-tools-extra/include-cleaner/unittests/LocateSymbolTest.cpp
index 5d2ec3e92bac0..16a185269b5d9 100644
--- a/clang-tools-extra/include-cleaner/unittests/LocateSymbolTest.cpp
+++ b/clang-tools-extra/include-cleaner/unittests/LocateSymbolTest.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 #include "AnalysisInternal.h"
+#include "TypesInternal.h"
 #include "clang-include-cleaner/Types.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/DeclBase.h"
@@ -20,6 +21,7 @@
 #include "gtest/gtest.h"
 #include <cstddef>
 #include <memory>
+#include <tuple>
 #include <unordered_map>
 #include <utility>
 #include <variant>
@@ -27,8 +29,11 @@
 
 namespace clang::include_cleaner {
 namespace {
+using testing::Each;
 using testing::ElementsAre;
 using testing::ElementsAreArray;
+using testing::Eq;
+using testing::Field;
 using testing::Pair;
 using testing::UnorderedElementsAre;
 
@@ -55,6 +60,10 @@ struct LocateExample {
       llvm::StringRef NameToFind;
       const NamedDecl *Out = nullptr;
       bool VisitNamedDecl(const NamedDecl *ND) {
+        // Skip the templated decls, as they have the same name and matches in
+        // this file care about the outer template name.
+        if (auto *TD = ND->getDescribedTemplate())
+          ND = TD;
         if (ND->getName() == NameToFind) {
           EXPECT_TRUE(Out == nullptr || Out == ND->getCanonicalDecl())
               << "Found multiple matches for " << NameToFind;
@@ -125,5 +134,57 @@ TEST(LocateSymbol, Macros) {
               ElementsAreArray(Test.points()));
 }
 
+MATCHER_P2(HintedSymbol, Symbol, Hint, "") {
+  return std::tie(arg.Hint, arg) == std::tie(Hint, Symbol);
+}
+TEST(LocateSymbol, CompleteSymbolHint) {
+  {
+    // stdlib symbols are always complete.
+    LocateExample Test("namespace std { struct vector; }");
+    EXPECT_THAT(locateSymbol(Test.findDecl("vector")),
+                ElementsAre(HintedSymbol(
+                    *tooling::stdlib::Symbol::named("std::", "vector"),
+                    Hints::CompleteSymbol)));
+  }
+  {
+    // macros are always complete.
+    LocateExample Test("#define ^FOO");
+    EXPECT_THAT(locateSymbol(Test.findMacro("FOO")),
+                ElementsAre(HintedSymbol(Test.points().front(),
+                                         Hints::CompleteSymbol)));
+  }
+  {
+    // Completeness is only absent in cases that matters.
+    const llvm::StringLiteral Cases[] = {
+        "struct ^foo; struct ^foo {};",
+        "template <typename> struct ^foo; template <typename> struct ^foo {};",
+        "template <typename> void ^foo(); template <typename> void ^foo() {};",
+    };
+    for (auto &Case : Cases) {
+      SCOPED_TRACE(Case);
+      LocateExample Test(Case);
+      EXPECT_THAT(locateSymbol(Test.findDecl("foo")),
+                  ElementsAre(HintedSymbol(Test.points().front(), Hints::None),
+                              HintedSymbol(Test.points().back(),
+                                           Hints::CompleteSymbol)));
+    }
+  }
+  {
+    // All declarations should be marked as complete in cases that a definition
+    // is not usually needed.
+    const llvm::StringLiteral Cases[] = {
+        "void foo(); void foo() {}",
+        "extern int foo; int foo;",
+    };
+    for (auto &Case : Cases) {
+      SCOPED_TRACE(Case);
+      LocateExample Test(Case);
+      EXPECT_THAT(locateSymbol(Test.findDecl("foo")),
+                  Each(Field(&Hinted<SymbolLocation>::Hint,
+                             Eq(Hints::CompleteSymbol))));
+    }
+  }
+}
+
 } // namespace
 } // namespace clang::include_cleaner


        


More information about the cfe-commits mailing list