[clang] 51f1ae5 - [clangd] Add new IncludeDirective to IncludeHeaderWithReferences

David Goldman via cfe-commits cfe-commits at lists.llvm.org
Tue Dec 6 10:48:06 PST 2022


Author: David Goldman
Date: 2022-12-06T13:47:07-05:00
New Revision: 51f1ae52b0c92a9783e7df328d05b1f95dca74d1

URL: https://github.com/llvm/llvm-project/commit/51f1ae52b0c92a9783e7df328d05b1f95dca74d1
DIFF: https://github.com/llvm/llvm-project/commit/51f1ae52b0c92a9783e7df328d05b1f95dca74d1.diff

LOG: [clangd] Add new IncludeDirective to IncludeHeaderWithReferences

The IncludeDirective contains both Include (the current behavior) and Import,
which we can use in the future to provide #import suggestions for
Objective-C files/symbols.

Differential Revision: https://reviews.llvm.org/D128457

Added: 
    

Modified: 
    clang-tools-extra/clangd/CodeComplete.cpp
    clang-tools-extra/clangd/Headers.cpp
    clang-tools-extra/clangd/Headers.h
    clang-tools-extra/clangd/IncludeFixer.cpp
    clang-tools-extra/clangd/index/Merge.cpp
    clang-tools-extra/clangd/index/Serialization.cpp
    clang-tools-extra/clangd/index/Symbol.h
    clang-tools-extra/clangd/index/SymbolCollector.cpp
    clang-tools-extra/clangd/index/SymbolCollector.h
    clang-tools-extra/clangd/index/YAMLSerialization.cpp
    clang-tools-extra/clangd/index/remote/Index.proto
    clang-tools-extra/clangd/index/remote/marshalling/Marshalling.cpp
    clang-tools-extra/clangd/test/index-serialization/Inputs/sample.idx
    clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp
    clang-tools-extra/clangd/unittests/DiagnosticsTests.cpp
    clang-tools-extra/clangd/unittests/IndexTests.cpp
    clang-tools-extra/clangd/unittests/SerializationTests.cpp
    clang/include/clang/Tooling/Inclusions/HeaderAnalysis.h
    clang/lib/Tooling/Inclusions/HeaderAnalysis.cpp
    clang/unittests/Tooling/HeaderAnalysisTest.cpp

Removed: 
    


################################################################################
diff  --git a/clang-tools-extra/clangd/CodeComplete.cpp b/clang-tools-extra/clangd/CodeComplete.cpp
index 33088e08387fb..32a80ac054cbe 100644
--- a/clang-tools-extra/clangd/CodeComplete.cpp
+++ b/clang-tools-extra/clangd/CodeComplete.cpp
@@ -191,7 +191,7 @@ struct CompletionCandidate {
   const CodeCompletionResult *SemaResult = nullptr;
   const Symbol *IndexResult = nullptr;
   const RawIdentifier *IdentifierResult = nullptr;
-  llvm::SmallVector<llvm::StringRef, 1> RankedIncludeHeaders;
+  llvm::SmallVector<SymbolInclude, 1> RankedIncludeHeaders;
 
   // Returns a token identifying the overload set this is part of.
   // 0 indicates it's not part of any overload set.
@@ -267,7 +267,11 @@ struct CompletionCandidate {
         if (SM.isInMainFile(SM.getExpansionLoc(RD->getBeginLoc())))
           return std::nullopt;
     }
-    return RankedIncludeHeaders[0];
+    for (const auto &Inc : RankedIncludeHeaders)
+      // FIXME: We should support #import directives here.
+      if ((Inc.Directive & clang::clangd::Symbol::Include) != 0)
+        return Inc.Header;
+    return None;
   }
 
   using Bundle = llvm::SmallVector<CompletionCandidate, 4>;
@@ -383,7 +387,11 @@ struct CodeCompletionBuilder {
     bool ShouldInsert = C.headerToInsertIfAllowed(Opts).has_value();
     // Calculate include paths and edits for all possible headers.
     for (const auto &Inc : C.RankedIncludeHeaders) {
-      if (auto ToInclude = Inserted(Inc)) {
+      // FIXME: We should support #import directives here.
+      if ((Inc.Directive & clang::clangd::Symbol::Include) == 0)
+        continue;
+
+      if (auto ToInclude = Inserted(Inc.Header)) {
         CodeCompletion::IncludeCandidate Include;
         Include.Header = ToInclude->first;
         if (ToInclude->second && ShouldInsert)
@@ -392,7 +400,7 @@ struct CodeCompletionBuilder {
       } else
         log("Failed to generate include insertion edits for adding header "
             "(FileURI='{0}', IncludeHeader='{1}') into {2}: {3}",
-            C.IndexResult->CanonicalDeclaration.FileURI, Inc, FileName,
+            C.IndexResult->CanonicalDeclaration.FileURI, Inc.Header, FileName,
             ToInclude.takeError());
     }
     // Prefer includes that do not need edits (i.e. already exist).

diff  --git a/clang-tools-extra/clangd/Headers.cpp b/clang-tools-extra/clangd/Headers.cpp
index a531197290702..95813894d5aa3 100644
--- a/clang-tools-extra/clangd/Headers.cpp
+++ b/clang-tools-extra/clangd/Headers.cpp
@@ -211,7 +211,7 @@ llvm::Expected<HeaderFile> toHeaderFile(llvm::StringRef Header,
   return HeaderFile{std::move(*Resolved), /*Verbatim=*/false};
 }
 
-llvm::SmallVector<llvm::StringRef, 1> getRankedIncludes(const Symbol &Sym) {
+llvm::SmallVector<SymbolInclude, 1> getRankedIncludes(const Symbol &Sym) {
   auto Includes = Sym.IncludeHeaders;
   // Sort in descending order by reference count and header length.
   llvm::sort(Includes, [](const Symbol::IncludeHeaderWithReferences &LHS,
@@ -220,9 +220,9 @@ llvm::SmallVector<llvm::StringRef, 1> getRankedIncludes(const Symbol &Sym) {
       return LHS.IncludeHeader.size() < RHS.IncludeHeader.size();
     return LHS.References > RHS.References;
   });
-  llvm::SmallVector<llvm::StringRef, 1> Headers;
+  llvm::SmallVector<SymbolInclude, 1> Headers;
   for (const auto &Include : Includes)
-    Headers.push_back(Include.IncludeHeader);
+    Headers.push_back({Include.IncludeHeader, Include.supportedDirectives()});
   return Headers;
 }
 

diff  --git a/clang-tools-extra/clangd/Headers.h b/clang-tools-extra/clangd/Headers.h
index 72a75de5a0370..0901a23f4dd83 100644
--- a/clang-tools-extra/clangd/Headers.h
+++ b/clang-tools-extra/clangd/Headers.h
@@ -45,6 +45,15 @@ struct HeaderFile {
   bool valid() const;
 };
 
+/// A header and directives as stored in a Symbol.
+struct SymbolInclude {
+  /// The header to include. This is either a URI or a verbatim include which is
+  /// quoted with <> or "".
+  llvm::StringRef Header;
+  /// The include directive to use, e.g. #import or #include.
+  Symbol::IncludeDirective Directive;
+};
+
 /// Creates a `HeaderFile` from \p Header which can be either a URI or a literal
 /// include.
 llvm::Expected<HeaderFile> toHeaderFile(llvm::StringRef Header,
@@ -52,7 +61,7 @@ llvm::Expected<HeaderFile> toHeaderFile(llvm::StringRef Header,
 
 // Returns include headers for \p Sym sorted by popularity. If two headers are
 // equally popular, prefer the shorter one.
-llvm::SmallVector<llvm::StringRef, 1> getRankedIncludes(const Symbol &Sym);
+llvm::SmallVector<SymbolInclude, 1> getRankedIncludes(const Symbol &Sym);
 
 // An #include directive that we found in the main file.
 struct Inclusion {

diff  --git a/clang-tools-extra/clangd/IncludeFixer.cpp b/clang-tools-extra/clangd/IncludeFixer.cpp
index 4332768eff83b..5ba5a4e45ba04 100644
--- a/clang-tools-extra/clangd/IncludeFixer.cpp
+++ b/clang-tools-extra/clangd/IncludeFixer.cpp
@@ -317,7 +317,10 @@ std::vector<Fix> IncludeFixer::fixesForSymbols(const SymbolSlab &Syms) const {
   llvm::StringSet<> InsertedHeaders;
   for (const auto &Sym : Syms) {
     for (const auto &Inc : getRankedIncludes(Sym)) {
-      if (auto ToInclude = Inserted(Sym, Inc)) {
+      // FIXME: We should support #import directives here.
+      if ((Inc.Directive & clang::clangd::Symbol::Include) == 0)
+        continue;
+      if (auto ToInclude = Inserted(Sym, Inc.Header)) {
         if (ToInclude->second) {
           if (!InsertedHeaders.try_emplace(ToInclude->first).second)
             continue;
@@ -326,8 +329,8 @@ std::vector<Fix> IncludeFixer::fixesForSymbols(const SymbolSlab &Syms) const {
             Fixes.push_back(std::move(*Fix));
         }
       } else {
-        vlog("Failed to calculate include insertion for {0} into {1}: {2}", Inc,
-             File, ToInclude.takeError());
+        vlog("Failed to calculate include insertion for {0} into {1}: {2}",
+             Inc.Header, File, ToInclude.takeError());
       }
     }
   }

diff  --git a/clang-tools-extra/clangd/index/Merge.cpp b/clang-tools-extra/clangd/index/Merge.cpp
index 0d15dfcb1f252..9687b36252e12 100644
--- a/clang-tools-extra/clangd/index/Merge.cpp
+++ b/clang-tools-extra/clangd/index/Merge.cpp
@@ -248,11 +248,13 @@ Symbol mergeSymbol(const Symbol &L, const Symbol &R) {
       if (SI.IncludeHeader == OI.IncludeHeader) {
         Found = true;
         SI.References += OI.References;
+        SI.SupportedDirectives |= OI.SupportedDirectives;
         break;
       }
     }
     if (!Found && MergeIncludes)
-      S.IncludeHeaders.emplace_back(OI.IncludeHeader, OI.References);
+      S.IncludeHeaders.emplace_back(OI.IncludeHeader, OI.References,
+                                    OI.supportedDirectives());
   }
 
   S.Origin |= O.Origin | SymbolOrigin::Merge;

diff  --git a/clang-tools-extra/clangd/index/Serialization.cpp b/clang-tools-extra/clangd/index/Serialization.cpp
index 5a00fbf40d8a6..aaa1d517370fb 100644
--- a/clang-tools-extra/clangd/index/Serialization.cpp
+++ b/clang-tools-extra/clangd/index/Serialization.cpp
@@ -331,7 +331,7 @@ void writeSymbol(const Symbol &Sym, const StringTableOut &Strings,
 
   auto WriteInclude = [&](const Symbol::IncludeHeaderWithReferences &Include) {
     writeVar(Strings.index(Include.IncludeHeader), OS);
-    writeVar(Include.References, OS);
+    writeVar((Include.References << 2) | Include.SupportedDirectives, OS);
   };
   writeVar(Sym.IncludeHeaders.size(), OS);
   for (const auto &Include : Sym.IncludeHeaders)
@@ -361,7 +361,9 @@ Symbol readSymbol(Reader &Data, llvm::ArrayRef<llvm::StringRef> Strings,
     return Sym;
   for (auto &I : Sym.IncludeHeaders) {
     I.IncludeHeader = Data.consumeString(Strings);
-    I.References = Data.consumeVar();
+    uint32_t RefsWithDirectives = Data.consumeVar();
+    I.References = RefsWithDirectives >> 2;
+    I.SupportedDirectives = RefsWithDirectives & 0x3;
   }
   return Sym;
 }
@@ -455,7 +457,7 @@ readCompileCommand(Reader CmdReader, llvm::ArrayRef<llvm::StringRef> Strings) {
 // The current versioning scheme is simple - non-current versions are rejected.
 // If you make a breaking change, bump this version number to invalidate stored
 // data. Later we may want to support some backward compatibility.
-constexpr static uint32_t Version = 17;
+constexpr static uint32_t Version = 18;
 
 llvm::Expected<IndexFileIn> readRIFF(llvm::StringRef Data,
                                      SymbolOrigin Origin) {

diff  --git a/clang-tools-extra/clangd/index/Symbol.h b/clang-tools-extra/clangd/index/Symbol.h
index a8333a8fe1358..1aa5265299231 100644
--- a/clang-tools-extra/clangd/index/Symbol.h
+++ b/clang-tools-extra/clangd/index/Symbol.h
@@ -13,12 +13,15 @@
 #include "index/SymbolLocation.h"
 #include "index/SymbolOrigin.h"
 #include "clang/Index/IndexSymbol.h"
+#include "llvm/ADT/BitmaskEnum.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/StringSaver.h"
 
 namespace clang {
 namespace clangd {
 
+LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
+
 /// The class presents a C++ symbol, e.g. class, function.
 ///
 /// WARNING: Symbols do not own much of their underlying data - typically
@@ -84,12 +87,24 @@ struct Symbol {
   /// Only set when the symbol is indexed for completion.
   llvm::StringRef Type;
 
+  enum IncludeDirective : uint8_t {
+    Invalid = 0,
+    /// `#include "header.h"`
+    Include = 1,
+    /// `#import "header.h"`
+    Import = 2,
+
+    LLVM_MARK_AS_BITMASK_ENUM(Import)
+  };
+
   struct IncludeHeaderWithReferences {
     IncludeHeaderWithReferences() = default;
 
     IncludeHeaderWithReferences(llvm::StringRef IncludeHeader,
-                                unsigned References)
-        : IncludeHeader(IncludeHeader), References(References) {}
+                                uint32_t References,
+                                IncludeDirective SupportedDirectives)
+        : IncludeHeader(IncludeHeader), References(References),
+          SupportedDirectives(SupportedDirectives) {}
 
     /// This can be either a URI of the header to be #include'd
     /// for this symbol, or a literal header quoted with <> or "" that is
@@ -101,7 +116,14 @@ struct Symbol {
     llvm::StringRef IncludeHeader = "";
     /// The number of translation units that reference this symbol and include
     /// this header. This number is only meaningful if aggregated in an index.
-    unsigned References = 0;
+    uint32_t References : 30;
+    /// Bitfield of supported directives (IncludeDirective) that can be used
+    /// when including this header.
+    uint32_t SupportedDirectives : 2;
+
+    IncludeDirective supportedDirectives() const {
+      return static_cast<IncludeDirective>(SupportedDirectives);
+    }
   };
   /// One Symbol can potentially be included via different headers.
   ///   - If we haven't seen a definition, this covers all declarations.

diff  --git a/clang-tools-extra/clangd/index/SymbolCollector.cpp b/clang-tools-extra/clangd/index/SymbolCollector.cpp
index 54dfde4342175..362959d4ce90b 100644
--- a/clang-tools-extra/clangd/index/SymbolCollector.cpp
+++ b/clang-tools-extra/clangd/index/SymbolCollector.cpp
@@ -76,7 +76,7 @@ bool isPrivateProtoDecl(const NamedDecl &ND) {
 // We only collect #include paths for symbols that are suitable for global code
 // completion, except for namespaces since #include path for a namespace is hard
 // to define.
-bool shouldCollectIncludePath(index::SymbolKind Kind) {
+Symbol::IncludeDirective shouldCollectIncludePath(index::SymbolKind Kind) {
   using SK = index::SymbolKind;
   switch (Kind) {
   case SK::Macro:
@@ -90,9 +90,11 @@ bool shouldCollectIncludePath(index::SymbolKind Kind) {
   case SK::Variable:
   case SK::EnumConstant:
   case SK::Concept:
-    return true;
+    return Symbol::Include | Symbol::Import;
+  case SK::Protocol:
+    return Symbol::Import;
   default:
-    return false;
+    return Symbol::Invalid;
   }
 }
 
@@ -805,12 +807,12 @@ void SymbolCollector::processRelations(
 }
 
 void SymbolCollector::setIncludeLocation(const Symbol &S, SourceLocation Loc) {
-  if (Opts.CollectIncludePath)
-    if (shouldCollectIncludePath(S.SymInfo.Kind))
-      // Use the expansion location to get the #include header since this is
-      // where the symbol is exposed.
-      IncludeFiles[S.ID] =
-          PP->getSourceManager().getDecomposedExpansionLoc(Loc).first;
+  if (Opts.CollectIncludePath &&
+      shouldCollectIncludePath(S.SymInfo.Kind) != Symbol::Invalid)
+    // Use the expansion location to get the #include header since this is
+    // where the symbol is exposed.
+    IncludeFiles[S.ID] =
+        PP->getSourceManager().getDecomposedExpansionLoc(Loc).first;
 }
 
 void SymbolCollector::finish() {
@@ -835,11 +837,12 @@ void SymbolCollector::finish() {
             Symbols.erase(ID);
     }
   }
+  llvm::DenseMap<FileID, bool> FileToContainsImportsOrObjC;
   // Fill in IncludeHeaders.
   // We delay this until end of TU so header guards are all resolved.
   llvm::SmallString<128> QName;
-  for (const auto &Entry : IncludeFiles) {
-    if (const Symbol *S = Symbols.find(Entry.first)) {
+  for (const auto &[SID, FID] : IncludeFiles) {
+    if (const Symbol *S = Symbols.find(SID)) {
       llvm::StringRef IncludeHeader;
       // Look for an overridden include header for this symbol specifically.
       if (Opts.Includes) {
@@ -856,19 +859,36 @@ void SymbolCollector::finish() {
       }
       // Otherwise find the approprate include header for the defining file.
       if (IncludeHeader.empty())
-        IncludeHeader = HeaderFileURIs->getIncludeHeader(Entry.second);
+        IncludeHeader = HeaderFileURIs->getIncludeHeader(FID);
 
       // Symbols in slabs aren't mutable, insert() has to walk all the strings
       if (!IncludeHeader.empty()) {
-        Symbol NewSym = *S;
-        NewSym.IncludeHeaders.push_back({IncludeHeader, 1});
-        Symbols.insert(NewSym);
+        Symbol::IncludeDirective Directives = Symbol::Invalid;
+        auto CollectDirectives = shouldCollectIncludePath(S->SymInfo.Kind);
+        if ((CollectDirectives & Symbol::Include) != 0)
+          Directives |= Symbol::Include;
+        // Only allow #import for symbols from ObjC-like files.
+        if ((CollectDirectives & Symbol::Import) != 0) {
+          auto [It, Inserted] = FileToContainsImportsOrObjC.try_emplace(FID);
+          if (Inserted)
+            It->second = FilesWithObjCConstructs.contains(FID) ||
+                         tooling::codeContainsImports(
+                             ASTCtx->getSourceManager().getBufferData(FID));
+          if (It->second)
+            Directives |= Symbol::Import;
+        }
+        if (Directives != Symbol::Invalid) {
+          Symbol NewSym = *S;
+          NewSym.IncludeHeaders.push_back({IncludeHeader, 1, Directives});
+          Symbols.insert(NewSym);
+        }
       }
     }
   }
 
   ReferencedSymbols.clear();
   IncludeFiles.clear();
+  FilesWithObjCConstructs.clear();
 }
 
 const Symbol *SymbolCollector::addDeclaration(const NamedDecl &ND, SymbolID ID,
@@ -896,7 +916,8 @@ const Symbol *SymbolCollector::addDeclaration(const NamedDecl &ND, SymbolID ID,
   auto Loc = nameLocation(ND, SM);
   assert(Loc.isValid() && "Invalid source location for NamedDecl");
   // FIXME: use the result to filter out symbols.
-  shouldIndexFile(SM.getFileID(Loc));
+  auto FID = SM.getFileID(Loc);
+  shouldIndexFile(FID);
   if (auto DeclLoc = getTokenLocation(Loc))
     S.CanonicalDeclaration = *DeclLoc;
 
@@ -940,6 +961,8 @@ const Symbol *SymbolCollector::addDeclaration(const NamedDecl &ND, SymbolID ID,
 
   Symbols.insert(S);
   setIncludeLocation(S, ND.getLocation());
+  if (S.SymInfo.Lang == index::SymbolLanguage::ObjC)
+    FilesWithObjCConstructs.insert(FID);
   return Symbols.find(S.ID);
 }
 

diff  --git a/clang-tools-extra/clangd/index/SymbolCollector.h b/clang-tools-extra/clangd/index/SymbolCollector.h
index 1c6205a4022ca..6e2998f2d035d 100644
--- a/clang-tools-extra/clangd/index/SymbolCollector.h
+++ b/clang-tools-extra/clangd/index/SymbolCollector.h
@@ -153,6 +153,9 @@ class SymbolCollector : public index::IndexDataConsumer {
   // File IDs for Symbol.IncludeHeaders.
   // The final spelling is calculated in finish().
   llvm::DenseMap<SymbolID, FileID> IncludeFiles;
+  // Files which contain ObjC symbols.
+  // This is finalized and used in finish().
+  llvm::DenseSet<FileID> FilesWithObjCConstructs;
   void setIncludeLocation(const Symbol &S, SourceLocation);
   // Indexed macros, to be erased if they turned out to be include guards.
   llvm::DenseSet<const IdentifierInfo *> IndexedMacros;

diff  --git a/clang-tools-extra/clangd/index/YAMLSerialization.cpp b/clang-tools-extra/clangd/index/YAMLSerialization.cpp
index 1ac74338298a8..593ce30501b04 100644
--- a/clang-tools-extra/clangd/index/YAMLSerialization.cpp
+++ b/clang-tools-extra/clangd/index/YAMLSerialization.cpp
@@ -28,8 +28,13 @@
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
 
+namespace {
+struct YIncludeHeaderWithReferences;
+}
+
 LLVM_YAML_IS_SEQUENCE_VECTOR(clang::clangd::Symbol::IncludeHeaderWithReferences)
 LLVM_YAML_IS_SEQUENCE_VECTOR(clang::clangd::Ref)
+LLVM_YAML_IS_SEQUENCE_VECTOR(YIncludeHeaderWithReferences)
 
 namespace {
 using RefBundle =
@@ -48,6 +53,21 @@ struct YPosition {
   uint32_t Line;
   uint32_t Column;
 };
+// A class helps YAML to serialize the IncludeHeaderWithReferences as YAMLIO
+// can't directly map bitfields.
+struct YIncludeHeaderWithReferences {
+  llvm::StringRef IncludeHeader;
+  uint32_t References;
+  clang::clangd::Symbol::IncludeDirective SupportedDirectives;
+
+  YIncludeHeaderWithReferences() = default;
+
+  YIncludeHeaderWithReferences(
+      llvm::StringRef IncludeHeader, uint32_t References,
+      clang::clangd::Symbol::IncludeDirective SupportedDirectives)
+      : IncludeHeader(IncludeHeader), References(References),
+        SupportedDirectives(SupportedDirectives) {}
+};
 
 // avoid ODR violation of specialization for non-owned CompileCommand
 struct CompileCommandYAML : clang::tooling::CompileCommand {};
@@ -165,13 +185,40 @@ template <> struct MappingTraits<SymbolInfo> {
   }
 };
 
-template <>
-struct MappingTraits<clang::clangd::Symbol::IncludeHeaderWithReferences> {
-  static void mapping(IO &IO,
-                      clang::clangd::Symbol::IncludeHeaderWithReferences &Inc) {
+template <> struct ScalarBitSetTraits<clang::clangd::Symbol::IncludeDirective> {
+  static void bitset(IO &IO, clang::clangd::Symbol::IncludeDirective &Value) {
+    IO.bitSetCase(Value, "Include", clang::clangd::Symbol::Include);
+    IO.bitSetCase(Value, "Import", clang::clangd::Symbol::Import);
+  }
+};
+
+template <> struct MappingTraits<YIncludeHeaderWithReferences> {
+  static void mapping(IO &IO, YIncludeHeaderWithReferences &Inc) {
     IO.mapRequired("Header", Inc.IncludeHeader);
     IO.mapRequired("References", Inc.References);
+    IO.mapOptional("Directives", Inc.SupportedDirectives,
+                   clang::clangd::Symbol::Include);
+  }
+};
+
+struct NormalizedIncludeHeaders {
+  using IncludeHeader = clang::clangd::Symbol::IncludeHeaderWithReferences;
+  NormalizedIncludeHeaders(IO &) {}
+  NormalizedIncludeHeaders(
+      IO &, const llvm::SmallVector<IncludeHeader, 1> &IncludeHeaders) {
+    for (auto &I : IncludeHeaders) {
+      Headers.emplace_back(I.IncludeHeader, I.References,
+                           I.supportedDirectives());
+    }
+  }
+
+  llvm::SmallVector<IncludeHeader, 1> denormalize(IO &) {
+    llvm::SmallVector<IncludeHeader, 1> Result;
+    for (auto &H : Headers)
+      Result.emplace_back(H.IncludeHeader, H.References, H.SupportedDirectives);
+    return Result;
   }
+  llvm::SmallVector<YIncludeHeaderWithReferences, 1> Headers;
 };
 
 template <> struct MappingTraits<Symbol> {
@@ -179,6 +226,10 @@ template <> struct MappingTraits<Symbol> {
     MappingNormalization<NormalizedSymbolID, SymbolID> NSymbolID(IO, Sym.ID);
     MappingNormalization<NormalizedSymbolFlag, Symbol::SymbolFlag> NSymbolFlag(
         IO, Sym.Flags);
+    MappingNormalization<
+        NormalizedIncludeHeaders,
+        llvm::SmallVector<Symbol::IncludeHeaderWithReferences, 1>>
+        NIncludeHeaders(IO, Sym.IncludeHeaders);
     IO.mapRequired("ID", NSymbolID->HexString);
     IO.mapRequired("Name", Sym.Name);
     IO.mapRequired("Scope", Sym.Scope);
@@ -195,7 +246,7 @@ template <> struct MappingTraits<Symbol> {
     IO.mapOptional("Documentation", Sym.Documentation);
     IO.mapOptional("ReturnType", Sym.ReturnType);
     IO.mapOptional("Type", Sym.Type);
-    IO.mapOptional("IncludeHeaders", Sym.IncludeHeaders);
+    IO.mapOptional("IncludeHeaders", NIncludeHeaders->Headers);
   }
 };
 

diff  --git a/clang-tools-extra/clangd/index/remote/Index.proto b/clang-tools-extra/clangd/index/remote/Index.proto
index 062feafeb7355..3072299d8f345 100644
--- a/clang-tools-extra/clangd/index/remote/Index.proto
+++ b/clang-tools-extra/clangd/index/remote/Index.proto
@@ -107,6 +107,7 @@ message Position {
 message HeaderWithReferences {
   optional string header = 1;
   optional uint32 references = 2;
+  optional uint32 supported_directives = 3;
 }
 
 message RelationsRequest {

diff  --git a/clang-tools-extra/clangd/index/remote/marshalling/Marshalling.cpp b/clang-tools-extra/clangd/index/remote/marshalling/Marshalling.cpp
index 4271fd79db02c..7e31ada18a657 100644
--- a/clang-tools-extra/clangd/index/remote/marshalling/Marshalling.cpp
+++ b/clang-tools-extra/clangd/index/remote/marshalling/Marshalling.cpp
@@ -406,6 +406,7 @@ llvm::Expected<HeaderWithReferences> Marshaller::toProtobuf(
     const clangd::Symbol::IncludeHeaderWithReferences &IncludeHeader) {
   HeaderWithReferences Result;
   Result.set_references(IncludeHeader.References);
+  Result.set_supported_directives(IncludeHeader.SupportedDirectives);
   const std::string Header = IncludeHeader.IncludeHeader.str();
   if (isLiteralInclude(Header)) {
     Result.set_header(Header);
@@ -427,8 +428,12 @@ Marshaller::fromProtobuf(const HeaderWithReferences &Message) {
       return URIString.takeError();
     Header = *URIString;
   }
-  return clangd::Symbol::IncludeHeaderWithReferences{Strings.save(Header),
-                                                     Message.references()};
+  auto Directives = clangd::Symbol::IncludeDirective::Include;
+  if (Message.has_supported_directives())
+    Directives = static_cast<clangd::Symbol::IncludeDirective>(
+        Message.supported_directives());
+  return clangd::Symbol::IncludeHeaderWithReferences{
+      Strings.save(Header), Message.references(), Directives};
 }
 
 } // namespace remote

diff  --git a/clang-tools-extra/clangd/test/index-serialization/Inputs/sample.idx b/clang-tools-extra/clangd/test/index-serialization/Inputs/sample.idx
index 91e9e5f5dbb11..b59849472d57b 100644
Binary files a/clang-tools-extra/clangd/test/index-serialization/Inputs/sample.idx and b/clang-tools-extra/clangd/test/index-serialization/Inputs/sample.idx differ

diff  --git a/clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp b/clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp
index a7ae1112c533f..7c1a2d39ccbff 100644
--- a/clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp
+++ b/clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp
@@ -79,6 +79,10 @@ MATCHER_P(insertInclude, IncludeHeader, "") {
   return !arg.Includes.empty() && arg.Includes[0].Header == IncludeHeader &&
          bool(arg.Includes[0].Insertion);
 }
+MATCHER_P(insertIncludeText, InsertedText, "") {
+  return !arg.Includes.empty() && arg.Includes[0].Insertion &&
+         arg.Includes[0].Insertion->newText == InsertedText;
+}
 MATCHER(insertInclude, "") {
   return !arg.Includes.empty() && bool(arg.Includes[0].Insertion);
 }
@@ -812,7 +816,7 @@ TEST(CompletionTest, IncludeInsertionPreprocessorIntegrationTests) {
 
   Symbol Sym = cls("ns::X");
   Sym.CanonicalDeclaration.FileURI = BarURI.c_str();
-  Sym.IncludeHeaders.emplace_back(BarURI, 1);
+  Sym.IncludeHeaders.emplace_back(BarURI, 1, Symbol::Include);
   // Shorten include path based on search directory and insert.
   Annotations Test("int main() { ns::^ }");
   TU.Code = Test.code().str();
@@ -843,8 +847,8 @@ TEST(CompletionTest, NoIncludeInsertionWhenDeclFoundInFile) {
   auto BarURI = URI::create(BarHeader).toString();
   SymX.CanonicalDeclaration.FileURI = BarURI.c_str();
   SymY.CanonicalDeclaration.FileURI = BarURI.c_str();
-  SymX.IncludeHeaders.emplace_back("<bar>", 1);
-  SymY.IncludeHeaders.emplace_back("<bar>", 1);
+  SymX.IncludeHeaders.emplace_back("<bar>", 1, Symbol::Include);
+  SymY.IncludeHeaders.emplace_back("<bar>", 1, Symbol::Include);
   // Shorten include path based on search directory and insert.
   auto Results = completions(R"cpp(
           namespace ns {
@@ -1867,7 +1871,7 @@ TEST(CompletionTest, OverloadBundling) {
   // Differences in header-to-insert suppress bundling.
   std::string DeclFile = URI::create(testPath("foo")).toString();
   NoArgsGFunc.CanonicalDeclaration.FileURI = DeclFile.c_str();
-  NoArgsGFunc.IncludeHeaders.emplace_back("<foo>", 1);
+  NoArgsGFunc.IncludeHeaders.emplace_back("<foo>", 1, Symbol::Include);
   EXPECT_THAT(
       completions(Context + "int y = GFunc^", {NoArgsGFunc}, Opts).Completions,
       UnorderedElementsAre(AllOf(named("GFuncC"), insertInclude("<foo>")),
@@ -1901,8 +1905,8 @@ TEST(CompletionTest, OverloadBundlingSameFileDifferentURI) {
   SymX.CanonicalDeclaration.FileURI = BarURI.c_str();
   SymY.CanonicalDeclaration.FileURI = BarURI.c_str();
   // The include header is different, but really it's the same file.
-  SymX.IncludeHeaders.emplace_back("\"bar.h\"", 1);
-  SymY.IncludeHeaders.emplace_back(BarURI.c_str(), 1);
+  SymX.IncludeHeaders.emplace_back("\"bar.h\"", 1, Symbol::Include);
+  SymY.IncludeHeaders.emplace_back(BarURI.c_str(), 1, Symbol::Include);
 
   auto Results = completions("void f() { ::ns::^ }", {SymX, SymY}, Opts);
   // Expect both results are bundled, despite the different-but-same
@@ -2699,8 +2703,8 @@ TEST(CompletionTest, InsertTheMostPopularHeader) {
   std::string DeclFile = URI::create(testPath("foo")).toString();
   Symbol Sym = func("Func");
   Sym.CanonicalDeclaration.FileURI = DeclFile.c_str();
-  Sym.IncludeHeaders.emplace_back("\"foo.h\"", 2);
-  Sym.IncludeHeaders.emplace_back("\"bar.h\"", 1000);
+  Sym.IncludeHeaders.emplace_back("\"foo.h\"", 2, Symbol::Include);
+  Sym.IncludeHeaders.emplace_back("\"bar.h\"", 1000, Symbol::Include);
 
   auto Results = completions("Fun^", {Sym}).Completions;
   assert(!Results.empty());
@@ -2708,6 +2712,30 @@ TEST(CompletionTest, InsertTheMostPopularHeader) {
   EXPECT_EQ(Results[0].Includes.size(), 2u);
 }
 
+TEST(CompletionTest, InsertIncludeOrImport) {
+  std::string DeclFile = URI::create(testPath("foo")).toString();
+  Symbol Sym = func("Func");
+  Sym.CanonicalDeclaration.FileURI = DeclFile.c_str();
+  Sym.IncludeHeaders.emplace_back("\"bar.h\"", 1000,
+                                  Symbol::Include | Symbol::Import);
+
+  auto Results = completions("Fun^", {Sym}).Completions;
+  assert(!Results.empty());
+  EXPECT_THAT(Results[0],
+              AllOf(named("Func"), insertIncludeText("#include \"bar.h\"\n")));
+
+  Results = completions("Fun^", {Sym}, {}, "Foo.m").Completions;
+  assert(!Results.empty());
+  // TODO: Once #import integration support is done this should be #import.
+  EXPECT_THAT(Results[0],
+              AllOf(named("Func"), insertIncludeText("#include \"bar.h\"\n")));
+
+  Sym.IncludeHeaders[0].SupportedDirectives = Symbol::Import;
+  Results = completions("Fun^", {Sym}).Completions;
+  assert(!Results.empty());
+  EXPECT_THAT(Results[0], AllOf(named("Func"), Not(insertInclude())));
+}
+
 TEST(CompletionTest, NoInsertIncludeIfOnePresent) {
   Annotations Test(R"cpp(
     #include "foo.h"
@@ -2719,8 +2747,8 @@ TEST(CompletionTest, NoInsertIncludeIfOnePresent) {
   std::string DeclFile = URI::create(testPath("foo")).toString();
   Symbol Sym = func("Func");
   Sym.CanonicalDeclaration.FileURI = DeclFile.c_str();
-  Sym.IncludeHeaders.emplace_back("\"foo.h\"", 2);
-  Sym.IncludeHeaders.emplace_back("\"bar.h\"", 1000);
+  Sym.IncludeHeaders.emplace_back("\"foo.h\"", 2, Symbol::Include);
+  Sym.IncludeHeaders.emplace_back("\"bar.h\"", 1000, Symbol::Include);
 
   EXPECT_THAT(completions(TU, Test.point(), {Sym}).Completions,
               UnorderedElementsAre(AllOf(named("Func"), hasInclude("\"foo.h\""),

diff  --git a/clang-tools-extra/clangd/unittests/DiagnosticsTests.cpp b/clang-tools-extra/clangd/unittests/DiagnosticsTests.cpp
index 0727fff464303..2b527142f0d26 100644
--- a/clang-tools-extra/clangd/unittests/DiagnosticsTests.cpp
+++ b/clang-tools-extra/clangd/unittests/DiagnosticsTests.cpp
@@ -1071,7 +1071,7 @@ buildIndexWithSymbol(llvm::ArrayRef<SymbolWithHeader> Syms) {
     Sym.Flags |= Symbol::IndexedForCodeCompletion;
     Sym.CanonicalDeclaration.FileURI = S.DeclaringFile.c_str();
     Sym.Definition.FileURI = S.DeclaringFile.c_str();
-    Sym.IncludeHeaders.emplace_back(S.IncludeHeader, 1);
+    Sym.IncludeHeaders.emplace_back(S.IncludeHeader, 1, Symbol::Include);
     Slab.insert(Sym);
   }
   return MemIndex::build(std::move(Slab).build(), RefSlab(), RelationSlab());
@@ -1129,7 +1129,7 @@ TEST(IncludeFixerTest, IncompleteEnum) {
   Symbol Sym = enm("X");
   Sym.Flags |= Symbol::IndexedForCodeCompletion;
   Sym.CanonicalDeclaration.FileURI = Sym.Definition.FileURI = "unittest:///x.h";
-  Sym.IncludeHeaders.emplace_back("\"x.h\"", 1);
+  Sym.IncludeHeaders.emplace_back("\"x.h\"", 1, Symbol::Include);
   SymbolSlab::Builder Slab;
   Slab.insert(Sym);
   auto Index =
@@ -1172,7 +1172,7 @@ int main() {
   Sym.Flags |= Symbol::IndexedForCodeCompletion;
   Sym.CanonicalDeclaration.FileURI = "unittest:///x.h";
   Sym.Definition.FileURI = "unittest:///x.cc";
-  Sym.IncludeHeaders.emplace_back("\"x.h\"", 1);
+  Sym.IncludeHeaders.emplace_back("\"x.h\"", 1, Symbol::Include);
 
   SymbolSlab::Builder Slab;
   Slab.insert(Sym);
@@ -1503,7 +1503,7 @@ TEST(IncludeFixerTest, CImplicitFunctionDecl) {
   Symbol Sym = func("foo");
   Sym.Flags |= Symbol::IndexedForCodeCompletion;
   Sym.CanonicalDeclaration.FileURI = "unittest:///foo.h";
-  Sym.IncludeHeaders.emplace_back("\"foo.h\"", 1);
+  Sym.IncludeHeaders.emplace_back("\"foo.h\"", 1, Symbol::Include);
 
   SymbolSlab::Builder Slab;
   Slab.insert(Sym);

diff  --git a/clang-tools-extra/clangd/unittests/IndexTests.cpp b/clang-tools-extra/clangd/unittests/IndexTests.cpp
index 5b9b3606779a2..658b4e200004e 100644
--- a/clang-tools-extra/clangd/unittests/IndexTests.cpp
+++ b/clang-tools-extra/clangd/unittests/IndexTests.cpp
@@ -567,9 +567,9 @@ TEST(MergeTest, MergeIncludesOnDifferentDefinitions) {
   L.Name = "left";
   R.Name = "right";
   L.ID = R.ID = SymbolID("hello");
-  L.IncludeHeaders.emplace_back("common", 1);
-  R.IncludeHeaders.emplace_back("common", 1);
-  R.IncludeHeaders.emplace_back("new", 1);
+  L.IncludeHeaders.emplace_back("common", 1, Symbol::Include);
+  R.IncludeHeaders.emplace_back("common", 1, Symbol::Include);
+  R.IncludeHeaders.emplace_back("new", 1, Symbol::Include);
 
   // Both have no definition.
   Symbol M = mergeSymbol(L, R);
@@ -615,7 +615,7 @@ TEST(MergeIndexTest, IncludeHeadersMerged) {
                     std::move(DynData), DynSize);
 
   SymbolSlab::Builder StaticB;
-  S.IncludeHeaders.push_back({"<header>", 0});
+  S.IncludeHeaders.push_back({"<header>", 0, Symbol::Include});
   StaticB.insert(S);
   auto StaticIndex =
       MemIndex::build(std::move(StaticB).build(), RefSlab(), RelationSlab());

diff  --git a/clang-tools-extra/clangd/unittests/SerializationTests.cpp b/clang-tools-extra/clangd/unittests/SerializationTests.cpp
index ae1914f303310..38e8612e3803e 100644
--- a/clang-tools-extra/clangd/unittests/SerializationTests.cpp
+++ b/clang-tools-extra/clangd/unittests/SerializationTests.cpp
@@ -53,8 +53,16 @@ ReturnType:    'int'
 IncludeHeaders:
   - Header:    'include1'
     References:    7
+    Directives:      [ Include ]
   - Header:    'include2'
     References:    3
+    Directives:      [ Import ]
+  - Header:    'include3'
+    References:    2
+    Directives:      [ Include, Import ]
+  - Header:    'include4'
+    References:    1
+    Directives:      [ ]
 ...
 ---
 !Symbol
@@ -114,8 +122,11 @@ Digest:          EED8F5EAF25C453C
 
 MATCHER_P(id, I, "") { return arg.ID == cantFail(SymbolID::fromStr(I)); }
 MATCHER_P(qName, Name, "") { return (arg.Scope + arg.Name).str() == Name; }
-MATCHER_P2(IncludeHeaderWithRef, IncludeHeader, References, "") {
-  return (arg.IncludeHeader == IncludeHeader) && (arg.References == References);
+MATCHER_P3(IncludeHeaderWithRefAndDirectives, IncludeHeader, References,
+           SupportedDirectives, "") {
+  return (arg.IncludeHeader == IncludeHeader) &&
+         (arg.References == References) &&
+         (arg.SupportedDirectives == SupportedDirectives);
 }
 
 auto readIndexFile(llvm::StringRef Text) {
@@ -148,9 +159,14 @@ TEST(SerializationTest, YAMLConversions) {
   EXPECT_EQ(static_cast<uint8_t>(Sym1.Flags), 129);
   EXPECT_TRUE(Sym1.Flags & Symbol::IndexedForCodeCompletion);
   EXPECT_FALSE(Sym1.Flags & Symbol::Deprecated);
-  EXPECT_THAT(Sym1.IncludeHeaders,
-              UnorderedElementsAre(IncludeHeaderWithRef("include1", 7u),
-                                   IncludeHeaderWithRef("include2", 3u)));
+  EXPECT_THAT(
+      Sym1.IncludeHeaders,
+      UnorderedElementsAre(
+          IncludeHeaderWithRefAndDirectives("include1", 7u, Symbol::Include),
+          IncludeHeaderWithRefAndDirectives("include2", 3u, Symbol::Import),
+          IncludeHeaderWithRefAndDirectives("include3", 2u,
+                                            Symbol::Include | Symbol::Import),
+          IncludeHeaderWithRefAndDirectives("include4", 1u, Symbol::Invalid)));
 
   EXPECT_THAT(Sym2, qName("clang::Foo2"));
   EXPECT_EQ(Sym2.Signature, "-sig");

diff  --git a/clang/include/clang/Tooling/Inclusions/HeaderAnalysis.h b/clang/include/clang/Tooling/Inclusions/HeaderAnalysis.h
index 3d47829296e91..31854ff6f59da 100644
--- a/clang/include/clang/Tooling/Inclusions/HeaderAnalysis.h
+++ b/clang/include/clang/Tooling/Inclusions/HeaderAnalysis.h
@@ -21,7 +21,7 @@ namespace tooling {
 /// Returns true if the given physical file is a self-contained header.
 ///
 /// A header is considered self-contained if
-//   - it has a proper header guard or has been #imported
+//   - it is header-guarded or #imported, or is the main file and uses #import
 //   - *and* it doesn't have a dont-include-me pattern.
 ///
 /// This function can be expensive as it may scan the source code to find out
@@ -29,6 +29,9 @@ namespace tooling {
 bool isSelfContainedHeader(const FileEntry *FE, const SourceManager &SM,
                            HeaderSearch &HeaderInfo);
 
+/// Heuristically checks the first 100 lines (10000 bytes) of Code for #import.
+bool codeContainsImports(llvm::StringRef Code);
+
 /// If Text begins an Include-What-You-Use directive, returns it.
 /// Given "// IWYU pragma: keep", returns "keep".
 /// Input is a null-terminated char* as provided by SM.getCharacterData().

diff  --git a/clang/lib/Tooling/Inclusions/HeaderAnalysis.cpp b/clang/lib/Tooling/Inclusions/HeaderAnalysis.cpp
index ea9cfacc206e7..f2a15c2a568cf 100644
--- a/clang/lib/Tooling/Inclusions/HeaderAnalysis.cpp
+++ b/clang/lib/Tooling/Inclusions/HeaderAnalysis.cpp
@@ -37,8 +37,7 @@ bool isErrorAboutInclude(llvm::StringRef Line) {
 }
 
 // Heuristically headers that only want to be included via an umbrella.
-bool isDontIncludeMeHeader(llvm::MemoryBufferRef Buffer) {
-  StringRef Content = Buffer.getBuffer();
+bool isDontIncludeMeHeader(StringRef Content) {
   llvm::StringRef Line;
   // Only sniff up to 100 lines or 10KB.
   Content = Content.take_front(100 * 100);
@@ -50,19 +49,48 @@ bool isDontIncludeMeHeader(llvm::MemoryBufferRef Buffer) {
   return false;
 }
 
+bool isImportLine(llvm::StringRef Line) {
+  Line = Line.ltrim();
+  if (!Line.consume_front("#"))
+    return false;
+  Line = Line.ltrim();
+  return Line.startswith("import"); // Accepts "#import" and "  # import".
+}
+
+llvm::StringRef getFileContents(const FileEntry *FE, const SourceManager &SM) {
+  return const_cast<SourceManager &>(SM)
+      .getMemoryBufferForFileOrNone(FE)
+      .value_or(llvm::MemoryBufferRef()) // No buffer: empty contents.
+      .getBuffer();
+}
+
 } // namespace
 
 bool isSelfContainedHeader(const FileEntry *FE, const SourceManager &SM,
                            HeaderSearch &HeaderInfo) {
   assert(FE);
   if (!HeaderInfo.isFileMultipleIncludeGuarded(FE) &&
-      !HeaderInfo.hasFileBeenImported(FE))
+      !HeaderInfo.hasFileBeenImported(FE) &&
+      // A header using #import is itself expected to be #import'd (caught
+      // by hasFileBeenImported), so only the main file needs scanning.
+      (SM.getFileEntryForID(SM.getMainFileID()) != FE ||
+       !codeContainsImports(getFileContents(FE, SM))))
     return false;
   // This pattern indicates that a header can't be used without
   // particular preprocessor state, usually set up by another header.
-  return !isDontIncludeMeHeader(
-      const_cast<SourceManager &>(SM).getMemoryBufferForFileOrNone(FE).value_or(
-          llvm::MemoryBufferRef()));
+  return !isDontIncludeMeHeader(getFileContents(FE, SM));
+}
+
+bool codeContainsImports(llvm::StringRef Code) {
+  // Only sniff up to 100 lines or 10000 bytes (100 * 100), not 10KiB.
+  Code = Code.take_front(100 * 100);
+  llvm::StringRef Line;
+  for (unsigned I = 0; I < 100 && !Code.empty(); ++I) {
+    std::tie(Line, Code) = Code.split('\n');
+    if (isImportLine(Line))
+      return true;
+  }
+  return false;
 }
 
 llvm::Optional<StringRef> parseIWYUPragma(const char *Text) {

diff  --git a/clang/unittests/Tooling/HeaderAnalysisTest.cpp b/clang/unittests/Tooling/HeaderAnalysisTest.cpp
index a2096d00d7e50..186eb87b062d2 100644
--- a/clang/unittests/Tooling/HeaderAnalysisTest.cpp
+++ b/clang/unittests/Tooling/HeaderAnalysisTest.cpp
@@ -56,11 +56,42 @@ TEST(HeaderAnalysisTest, IsSelfContained) {
   EXPECT_TRUE(isSelfContainedHeader(FM.getFile("headerguard.h").get(), SM, HI));
   EXPECT_TRUE(isSelfContainedHeader(FM.getFile("pragmaonce.h").get(), SM, HI));
   EXPECT_TRUE(isSelfContainedHeader(FM.getFile("imported.h").get(), SM, HI));
+  EXPECT_TRUE(
+      isSelfContainedHeader(SM.getFileEntryForID(SM.getMainFileID()), SM, HI));
 
   EXPECT_FALSE(isSelfContainedHeader(FM.getFile("unguarded.h").get(), SM, HI));
   EXPECT_FALSE(isSelfContainedHeader(FM.getFile("bad.h").get(), SM, HI));
 }
 
+TEST(HeaderAnalysisTest, CodeContainsImports) {
+  EXPECT_TRUE(codeContainsImports(R"cpp(
+  #include "foo.h"
+  #import "NSFoo.h"
+
+  int main() {
+    foo();
+  }
+  )cpp"));
+
+  EXPECT_TRUE(codeContainsImports(R"cpp(
+  #include "foo.h"
+
+  int main() {
+    foo();
+  }
+
+  #import "NSFoo.h"
+  )cpp"));
+
+  EXPECT_FALSE(codeContainsImports(R"cpp(
+  #include "foo.h"
+
+  int main() {
+    foo();
+  }
+  )cpp"));
+}
+
 TEST(HeaderAnalysisTest, ParseIWYUPragma) {
   EXPECT_THAT(parseIWYUPragma("// IWYU pragma: keep"), ValueIs(Eq("keep")));
   EXPECT_THAT(parseIWYUPragma("// IWYU pragma:   keep  me\netc"),


        


More information about the cfe-commits mailing list